~emersion/soju

c709ebfc912cfca9b9c412bc27bd811d5115ba51 — Simon Ser a month ago bee2001
Add network update command

The user.updateNetwork function is a bit involved because we need to
make sure that the upstream connection is closed before re-connecting
(would otherwise cause "Nick already used" errors) and that the
downstream connections' state is kept in sync.

References: https://todo.sr.ht/~emersion/soju/17
2 files changed, 248 insertions(+), 85 deletions(-)

M service.go
M user.go
M service.go => service.go +126 -42
@@ 118,7 118,7 @@ func init() {
		"network": {
			children: serviceCommandSet{
				"create": {
					usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
					usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
					desc:   "add a new network",
					handle: handleServiceCreateNetwork,
				},


@@ 126,6 126,11 @@ func init() {
					desc:   "show a list of saved networks and their current status",
					handle: handleServiceNetworkStatus,
				},
				"update": {
					usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
					desc:  "update a network",
					handle: handleServiceNetworkUpdate,
				},
				"delete": {
					usage:  "<name>",
					desc:   "delete a network",


@@ 338,65 343,115 @@ func newFlagSet() *flag.FlagSet {
	return fs
}

type stringSliceVar []string
type stringSliceFlag []string

func (v *stringSliceVar) String() string {
func (v *stringSliceFlag) String() string {
	return fmt.Sprint([]string(*v))
}

func (v *stringSliceVar) Set(s string) error {
func (v *stringSliceFlag) Set(s string) error {
	*v = append(*v, s)
	return nil
}

func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
	fs := newFlagSet()
	addr := fs.String("addr", "", "")
	name := fs.String("name", "", "")
	username := fs.String("username", "", "")
	pass := fs.String("pass", "", "")
	realname := fs.String("realname", "", "")
	nick := fs.String("nick", "", "")
	var connectCommands stringSliceVar
	fs.Var(&connectCommands, "connect-command", "")
// stringPtrFlag is a flag value populating a string pointer. This allows to
// disambiguate between a flag that hasn't been set and a flag that has been
// set to an empty string.
type stringPtrFlag struct {
	ptr **string
}

	if err := fs.Parse(params); err != nil {
		return err
	}
	if *addr == "" {
		return fmt.Errorf("flag -addr is required")
func (f stringPtrFlag) String() string {
	if *f.ptr == nil {
		return ""
	}
	return **f.ptr
}

func (f stringPtrFlag) Set(s string) error {
	*f.ptr = &s
	return nil
}

type networkFlagSet struct {
	*flag.FlagSet
	Addr, Name, Nick, Username, Pass, Realname *string
	ConnectCommands []string
}

func newNetworkFlagSet() *networkFlagSet {
	fs := &networkFlagSet{FlagSet: newFlagSet()}
	fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
	fs.Var(stringPtrFlag{&fs.Name}, "name", "")
	fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
	fs.Var(stringPtrFlag{&fs.Username}, "username", "")
	fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
	fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
	fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
	return fs
}

	if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 {
		scheme := addrParts[0]
		switch scheme {
		case "ircs", "irc+insecure":
		default:
			return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
func (fs *networkFlagSet) update(network *Network) error {
	if fs.Addr != nil {
		if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
			scheme := addrParts[0]
			switch scheme {
			case "ircs", "irc+insecure":
			default:
				return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
			}
		}
		network.Addr = *fs.Addr
	}

	for _, command := range connectCommands {
		_, err := irc.ParseMessage(command)
		if err != nil {
			return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
	if fs.Name != nil {
		network.Name = *fs.Name
	}
	if fs.Nick != nil {
		network.Nick = *fs.Nick
	}
	if fs.Username != nil {
		network.Username = *fs.Username
	}
	if fs.Pass != nil {
		network.Pass = *fs.Pass
	}
	if fs.Realname != nil {
		network.Realname = *fs.Realname
	}
	if fs.ConnectCommands != nil {
		if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
			network.ConnectCommands = nil
		} else {
			for _, command := range fs.ConnectCommands {
				_, err := irc.ParseMessage(command)
				if err != nil {
					return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
				}
			}
			network.ConnectCommands = fs.ConnectCommands
		}
	}
	return nil
}

	if *nick == "" {
		*nick = dc.nick
func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
	fs := newNetworkFlagSet()
	if err := fs.Parse(params); err != nil {
		return err
	}
	if fs.Addr == nil {
		return fmt.Errorf("flag -addr is required")
	}

	var err error
	network, err := dc.user.createNetwork(&Network{
		Addr:            *addr,
		Name:            *name,
		Username:        *username,
		Pass:            *pass,
		Realname:        *realname,
		Nick:            *nick,
		ConnectCommands: connectCommands,
	})
	record := &Network{
		Addr: *fs.Addr,
		Nick: dc.nick,
	}
	if err := fs.update(record); err != nil {
		return err
	}

	network, err := dc.user.createNetwork(record)
	if err != nil {
		return fmt.Errorf("could not create network: %v", err)
	}


@@ 441,6 496,35 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
	return nil
}

func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
	if len(params) < 1 {
		return fmt.Errorf("expected exactly one argument")
	}

	fs := newNetworkFlagSet()
	if err := fs.Parse(params[1:]); err != nil {
		return err
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	record := net.Network // copy network record because we'll mutate it
	if err := fs.update(&record); err != nil {
		return err
	}

	network, err := dc.user.updateNetwork(&record)
	if err != nil {
		return fmt.Errorf("could not update network: %v", err)
	}

	sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
	return nil
}

func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")

M user.go => user.go +122 -43
@@ 272,6 272,15 @@ func (u *user) getNetwork(name string) *network {
	return nil
}

func (u *user) getNetworkByID(id int64) *network {
	for _, net := range u.networks {
		if net.ID == id {
			return net
		}
	}
	return nil
}

func (u *user) run() {
	networks, err := u.srv.db.ListNetworks(u.Username)
	if err != nil {


@@ 309,31 318,18 @@ func (u *user) run() {
			})
			uc.network.lastError = nil
		case eventUpstreamDisconnected:
			uc := e.uc

			uc.network.conn = nil

			for _, ml := range uc.messageLoggers {
				if err := ml.Close(); err != nil {
					uc.logger.Printf("failed to close message logger: %v", err)
				}
			}

			uc.endPendingLISTs(true)

			uc.forEachDownstream(func(dc *downstreamConn) {
				dc.updateSupportedCaps()
			})

			if uc.network.lastError == nil {
				uc.forEachDownstream(func(dc *downstreamConn) {
					sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
				})
			}
			u.handleUpstreamDisconnected(e.uc)
		case eventUpstreamConnectionError:
			net := e.net

			if net.lastError == nil || net.lastError.Error() != e.err.Error() {
			stopped := false
			select {
			case <-net.stopped:
				stopped = true
			default:
			}

			if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
				net.forEachDownstream(func(dc *downstreamConn) {
					sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
				})


@@ 425,45 421,128 @@ func (u *user) run() {
	}
}

func (u *user) createNetwork(net *Network) (*network, error) {
	if net.ID != 0 {
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
	uc.network.conn = nil

	for _, ml := range uc.messageLoggers {
		if err := ml.Close(); err != nil {
			uc.logger.Printf("failed to close message logger: %v", err)
		}
	}

	uc.endPendingLISTs(true)

	uc.forEachDownstream(func(dc *downstreamConn) {
		dc.updateSupportedCaps()
	})

	if uc.network.lastError == nil {
		uc.forEachDownstream(func(dc *downstreamConn) {
			sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
		})
	}
}

func (u *user) addNetwork(network *network) {
	u.networks = append(u.networks, network)
	go network.run()
}

func (u *user) removeNetwork(network *network) {
	network.stop()

	u.forEachDownstream(func(dc *downstreamConn) {
		if dc.network != nil && dc.network == network {
			dc.Close()
		}
	})

	for i, net := range u.networks {
		if net == network {
			u.networks = append(u.networks[:i], u.networks[i+1:]...)
			return
		}
	}

	panic("tried to remove a non-existing network")
}

func (u *user) createNetwork(record *Network) (*network, error) {
	if record.ID != 0 {
		panic("tried creating an already-existing network")
	}

	network := newNetwork(u, net, nil)
	network := newNetwork(u, record, nil)
	err := u.srv.db.StoreNetwork(u.Username, &network.Network)
	if err != nil {
		return nil, err
	}

	u.networks = append(u.networks, network)
	u.addNetwork(network)

	go network.run()
	return network, nil
}

func (u *user) deleteNetwork(id int64) error {
	for i, net := range u.networks {
		if net.ID != id {
			continue
		}
func (u *user) updateNetwork(record *Network) (*network, error) {
	if record.ID == 0 {
		panic("tried updating a new network")
	}

	network := u.getNetworkByID(record.ID)
	if network == nil {
		panic("tried updating a non-existing network")
	}

	if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
		return nil, err
	}

	// Most network changes require us to re-connect to the upstream server

	channels := make([]Channel, 0, len(network.channels))
	for _, ch := range network.channels {
		channels = append(channels, *ch)
	}

	updatedNetwork := newNetwork(u, record, channels)

		if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
			return err
	// If we're currently connected, disconnect and perform the necessary
	// bookkeeping
	if network.conn != nil {
		network.stop()
		// Note: this will set network.conn to nil
		u.handleUpstreamDisconnected(network.conn)
	}

	// Patch downstream connections to use our fresh updated network
	u.forEachDownstream(func(dc *downstreamConn) {
		if dc.network != nil && dc.network == network {
			dc.network = updatedNetwork
		}
	})

		u.forEachDownstream(func(dc *downstreamConn) {
			if dc.network != nil && dc.network == net {
				dc.Close()
			}
		})
	// We need to remove the network after patching downstream connections,
	// otherwise they'll get closed
	u.removeNetwork(network)

	// This will re-connect to the upstream server
	u.addNetwork(updatedNetwork)

		net.stop()
		u.networks = append(u.networks[:i], u.networks[i+1:]...)
		return nil
	return updatedNetwork, nil
}

func (u *user) deleteNetwork(id int64) error {
	network := u.getNetworkByID(id)
	if network == nil {
		panic("tried deleting a non-existing network")
	}

	panic("tried deleting a non-existing network")
	if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
		return err
	}

	u.removeNetwork(network)
	return nil
}

func (u *user) updatePassword(hashed string) error {