~emersion/soju

05aafb5edffb9f82a204f3a03f159a680bdbfe4b — Simon Ser 30 days ago af1e578
Add message store abstraction

Introduce a messageStore type, which will allow for multiple
implementations (e.g. in the DB or in-memory instead of on-disk).

The message store is per-user so that we don't need to deal with locking
and it's easier to implement per-user limits.
4 files changed, 80 insertions(+), 79 deletions(-)

M downstream.go
R logger.go => msgstore.go
M upstream.go
M user.go
M downstream.go => downstream.go +6 -6
@@ 863,7 863,7 @@ func (dc *downstreamConn) welcome() error {
				continue
			}

			lastID, err := lastMsgID(net, target, time.Now())
			lastID, err := dc.user.msgStore.LastMsgID(net, target, time.Now())
			if err != nil {
				dc.logger.Printf("failed to get last message ID: %v", err)
				continue


@@ 876,7 876,7 @@ func (dc *downstreamConn) welcome() error {
}

func (dc *downstreamConn) sendNetworkHistory(net *network) {
	if dc.caps["draft/chathistory"] || dc.srv.LogPath == "" {
	if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
		return
	}
	for target, history := range net.history {


@@ 890,7 890,7 @@ func (dc *downstreamConn) sendNetworkHistory(net *network) {
		}

		limit := 4000
		history, err := loadHistoryLatestID(net, target, lastDelivered, limit)
		history, err := dc.user.msgStore.LoadLatestID(net, target, lastDelivered, limit)
		if err != nil {
			dc.logger.Printf("failed to send implicit history for %q: %v", target, err)
			continue


@@ 1601,7 1601,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
			}}
		}

		if dc.srv.LogPath == "" {
		if dc.user.msgStore == nil {
			return ircError{&irc.Message{
				Command: irc.ERR_UNKNOWNCOMMAND,
				Params:  []string{dc.nick, subcommand, "Unknown command"},


@@ 1641,9 1641,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
		var history []*irc.Message
		switch subcommand {
		case "BEFORE":
			history, err = loadHistoryBeforeTime(uc.network, entity, timestamp, limit)
			history, err = dc.user.msgStore.LoadBeforeTime(uc.network, entity, timestamp, limit)
		case "AFTER":
			history, err = loadHistoryAfterTime(uc.network, entity, timestamp, limit)
			history, err = dc.user.msgStore.LoadAfterTime(uc.network, entity, timestamp, limit)
		default:
			// TODO: support LATEST, BETWEEN
			return ircError{&irc.Message{

R logger.go => msgstore.go +52 -50
@@ 12,32 12,28 @@ import (
	"gopkg.in/irc.v3"
)

const messageLoggerMaxTries = 100
const messageStoreMaxTries = 100

type messageLogger struct {
	network *network
	entity  string
var escapeFilename = strings.NewReplacer("/", "-", "\\", "-")

// messageStore is a per-user store for IRC messages.
type messageStore struct {
	root string

	path string
	file *os.File
	files map[string]*os.File // indexed by entity
}

func newMessageLogger(network *network, entity string) *messageLogger {
	return &messageLogger{
		network: network,
		entity:  entity,
func newMessageStore(root, username string) *messageStore {
	return &messageStore{
		root:  filepath.Join(root, escapeFilename.Replace(username)),
		files: make(map[string]*os.File),
	}
}

var escapeFilename = strings.NewReplacer("/", "-", "\\", "-")

func logPath(network *network, entity string, t time.Time) string {
	user := network.user
	srv := user.srv

func (ms *messageStore) logPath(network *network, entity string, t time.Time) string {
	year, month, day := t.Date()
	filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
	return filepath.Join(srv.LogPath, escapeFilename.Replace(user.Username), escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename)
	return filepath.Join(ms.root, escapeFilename.Replace(network.GetName()), escapeFilename.Replace(entity), filename)
}

func parseMsgID(s string) (network, entity string, t time.Time, offset int64, err error) {


@@ 64,11 60,11 @@ func nextMsgID(network *network, entity string, t time.Time, f *os.File) (string
	return formatMsgID(network.GetName(), entity, t, offset), nil
}

// lastMsgID queries the last message ID for the given network, entity and
// LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be
// used in history queries.
func lastMsgID(network *network, entity string, t time.Time) (string, error) {
	p := logPath(network, entity, t)
func (ms *messageStore) LastMsgID(network *network, entity string, t time.Time) (string, error) {
	p := ms.logPath(network, entity, t)
	fi, err := os.Stat(p)
	if os.IsNotExist(err) {
		return formatMsgID(network.GetName(), entity, t, -1), nil


@@ 78,7 74,7 @@ func lastMsgID(network *network, entity string, t time.Time) (string, error) {
	return formatMsgID(network.GetName(), entity, t, fi.Size()-1), nil
}

func (ml *messageLogger) Append(msg *irc.Message) (string, error) {
func (ms *messageStore) Append(network *network, entity string, msg *irc.Message) (string, error) {
	s := formatMessage(msg)
	if s == "" {
		return "", nil


@@ 97,44 93,50 @@ func (ml *messageLogger) Append(msg *irc.Message) (string, error) {
	}

	// TODO: enforce maximum open file handles (LRU cache of file handles)
	f := ms.files[entity]

	// TODO: handle non-monotonic clock behaviour
	path := logPath(ml.network, ml.entity, t)
	if ml.path != path {
		if ml.file != nil {
			ml.file.Close()
	path := ms.logPath(network, entity, t)
	if f == nil || f.Name() != path {
		if f != nil {
			f.Close()
		}

		dir := filepath.Dir(path)
		if err := os.MkdirAll(dir, 0700); err != nil {
			return "", fmt.Errorf("failed to create logs directory %q: %v", dir, err)
			return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
		}

		f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
		var err error
		f, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
		if err != nil {
			return "", fmt.Errorf("failed to open log file %q: %v", path, err)
			return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
		}

		ml.path = path
		ml.file = f
		ms.files[entity] = f
	}

	msgID, err := nextMsgID(ml.network, ml.entity, t, ml.file)
	msgID, err := nextMsgID(network, entity, t, f)
	if err != nil {
		return "", fmt.Errorf("failed to generate message ID: %v", err)
	}

	_, err = fmt.Fprintf(ml.file, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
	_, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
	if err != nil {
		return "", fmt.Errorf("failed to log message to %q: %v", ml.path, err)
		return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
	}

	return msgID, nil
}

func (ml *messageLogger) Close() error {
	if ml.file == nil {
		return nil
func (ms *messageStore) Close() error {
	var closeErr error
	for _, f := range ms.files {
		if err := f.Close(); err != nil {
			closeErr = fmt.Errorf("failed to close message store: %v", err)
		}
	}
	return ml.file.Close()
	return closeErr
}

// formatMessage formats a message log line. It assumes a well-formed IRC


@@ 233,8 235,8 @@ func parseMessage(line, entity string, ref time.Time) (*irc.Message, time.Time, 
	return msg, t, nil
}

func parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) {
	path := logPath(network, entity, ref)
func (ms *messageStore) parseMessagesBefore(network *network, entity string, ref time.Time, limit int, afterOffset int64) ([]*irc.Message, error) {
	path := ms.logPath(network, entity, ref)
	f, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {


@@ 289,8 291,8 @@ func parseMessagesBefore(network *network, entity string, ref time.Time, limit i
	}
}

func parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) {
	path := logPath(network, entity, ref)
func (ms *messageStore) parseMessagesAfter(network *network, entity string, ref time.Time, limit int) ([]*irc.Message, error) {
	path := ms.logPath(network, entity, ref)
	f, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {


@@ 319,12 321,12 @@ func parseMessagesAfter(network *network, entity string, ref time.Time, limit in
	return history, nil
}

func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadBeforeTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
	history := make([]*irc.Message, limit)
	remaining := limit
	tries := 0
	for remaining > 0 && tries < messageLoggerMaxTries {
		buf, err := parseMessagesBefore(network, entity, t, remaining, -1)
	for remaining > 0 && tries < messageStoreMaxTries {
		buf, err := ms.parseMessagesBefore(network, entity, t, remaining, -1)
		if err != nil {
			return nil, err
		}


@@ 342,13 344,13 @@ func loadHistoryBeforeTime(network *network, entity string, t time.Time, limit i
	return history[remaining:], nil
}

func loadHistoryAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadAfterTime(network *network, entity string, t time.Time, limit int) ([]*irc.Message, error) {
	var history []*irc.Message
	remaining := limit
	tries := 0
	now := time.Now()
	for remaining > 0 && tries < messageLoggerMaxTries && t.Before(now) {
		buf, err := parseMessagesAfter(network, entity, t, remaining)
	for remaining > 0 && tries < messageStoreMaxTries && t.Before(now) {
		buf, err := ms.parseMessagesAfter(network, entity, t, remaining)
		if err != nil {
			return nil, err
		}


@@ 370,7 372,7 @@ func truncateDay(t time.Time) time.Time {
	return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
}

func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
func (ms *messageStore) LoadLatestID(network *network, entity, id string, limit int) ([]*irc.Message, error) {
	var afterTime time.Time
	var afterOffset int64
	if id != "" {


@@ 389,13 391,13 @@ func loadHistoryLatestID(network *network, entity, id string, limit int) ([]*irc
	t := time.Now()
	remaining := limit
	tries := 0
	for remaining > 0 && tries < messageLoggerMaxTries && !truncateDay(t).Before(afterTime) {
	for remaining > 0 && tries < messageStoreMaxTries && !truncateDay(t).Before(afterTime) {
		var offset int64 = -1
		if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
			offset = afterOffset
		}

		buf, err := parseMessagesBefore(network, entity, t, remaining, offset)
		buf, err := ms.parseMessagesBefore(network, entity, t, remaining, offset)
		if err != nil {
			return nil, err
		}

M upstream.go => upstream.go +3 -12
@@ 81,8 81,6 @@ type upstreamConn struct {

	// set of LIST commands in progress, per downstream
	pendingLISTDownstreamSet map[uint64]struct{}

	messageLoggers map[string]*messageLogger
}

func connectToUpstream(network *network) (*upstreamConn, error) {


@@ 182,7 180,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
		availableChannelModes:    stdChannelModes,
		availableMemberships:     stdMemberships,
		pendingLISTDownstreamSet: make(map[uint64]struct{}),
		messageLoggers:           make(map[string]*messageLogger),
	}
	return uc, nil
}


@@ 1611,16 1608,10 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
}

func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
	if uc.srv.LogPath == "" {
	if uc.user.msgStore == nil {
		return
	}

	ml, ok := uc.messageLoggers[entity]
	if !ok {
		ml = newMessageLogger(uc.network, entity)
		uc.messageLoggers[entity] = ml
	}

	detached := false
	if ch, ok := uc.network.channels[entity]; ok {
		detached = ch.Detached


@@ 1628,7 1619,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {

	history, ok := uc.network.history[entity]
	if !ok {
		lastID, err := lastMsgID(uc.network, entity, time.Now())
		lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now())
		if err != nil {
			uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
			return


@@ 1652,7 1643,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
		}
	}

	msgID, err := ml.Append(msg)
	msgID, err := uc.user.msgStore.Append(uc.network, entity, msg)
	if err != nil {
		uc.logger.Printf("failed to log message: %v", err)
		return

M user.go => user.go +19 -11
@@ 249,6 249,7 @@ type user struct {

	networks        []*network
	downstreamConns []*downstreamConn
	msgStore        *messageStore

	// LIST commands in progress
	pendingLISTs []pendingLIST


@@ 261,11 262,17 @@ type pendingLIST struct {
}

func newUser(srv *Server, record *User) *user {
	var msgStore *messageStore
	if srv.LogPath != "" {
		msgStore = newMessageStore(srv.LogPath, record.Username)
	}

	return &user{
		User:   *record,
		srv:    srv,
		events: make(chan event, 64),
		done:   make(chan struct{}),
		User:     *record,
		srv:      srv,
		events:   make(chan event, 64),
		done:     make(chan struct{}),
		msgStore: msgStore,
	}
}



@@ 312,7 319,14 @@ func (u *user) getNetworkByID(id int64) *network {
}

func (u *user) run() {
	defer close(u.done)
	defer func() {
		if u.msgStore != nil {
			if err := u.msgStore.Close(); err != nil {
				u.srv.Logger.Printf("failed to close message store for user %q: %v", u.Username, err)
			}
		}
		close(u.done)
	}()

	networks, err := u.srv.db.ListNetworks(u.ID)
	if err != nil {


@@ 459,12 473,6 @@ func (u *user) run() {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
	uc.network.conn = nil

	for _, ml := range uc.messageLoggers {
		if err := ml.Close(); err != nil {
			uc.logger.Printf("failed to close message logger: %v", err)
		}
	}

	uc.endPendingLISTs(true)

	uc.forEachDownstream(func(dc *downstreamConn) {