~whereswaldon/sprout-go

ref: relay sprout-go/cmd/relay/worker.go -rw-r--r-- 7.2 KiB View raw
82082f28Chris Waldon Implement list verb in example relay 8 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package main

import (
	"fmt"
	"log"
	"net"

	forest "git.sr.ht/~whereswaldon/forest-go"
	"git.sr.ht/~whereswaldon/forest-go/fields"
	sprout "git.sr.ht/~whereswaldon/sprout-go"
)

// Worker services a single sprout connection for the relay. It embeds the
// connection, a logger, a Session, and the MessageStore so that their
// methods can be called directly on the Worker.
type Worker struct {
	// Done is polled by Run after each message; once a receive on it
	// succeeds (e.g. because it was closed) Run returns.
	Done <-chan struct{}
	// Conn is the sprout connection. NewWorker installs the Worker's On*
	// methods as its handlers, and Run reads messages from it.
	*sprout.Conn
	// Logger receives the Worker's log output through the promoted Printf.
	*log.Logger
	// Session is consulted (IsSubscribed) to decide which replies are
	// announced to the peer and which announced replies are stored.
	*Session
	// MessageStore is the node store that requests are answered from and
	// that received nodes are added to.
	*MessageStore
	// subscriptionID is the value returned by
	// MessageStore.SubscribeToNewMessages in Run. It is passed along with
	// every Add and used to unsubscribe when Run exits.
	subscriptionID int
}

// NewWorker builds a Worker that speaks the sprout protocol over conn and
// answers requests from store. done becomes the Worker's Done channel.
//
// The Worker starts with a fresh Session and a default Logger that writes to
// the standard logger's destination using its prefix and flags, so the
// promoted Printf never dereferences a nil *log.Logger. Callers that want a
// different logger may assign w.Logger before calling Run.
//
// Every sprout handler hook on the connection is pointed at the matching
// Worker method. An error is returned if the sprout connection cannot be
// created.
func NewWorker(done <-chan struct{}, conn net.Conn, store *MessageStore) (*Worker, error) {
	sconn, err := sprout.NewConn(conn)
	if err != nil {
		return nil, fmt.Errorf("failed to create sprout conn: %w", err)
	}
	w := &Worker{
		Done:         done,
		Conn:         sconn,
		Logger:       log.New(log.Writer(), log.Prefix(), log.Flags()),
		Session:      NewSession(),
		MessageStore: store,
	}
	w.Conn.OnVersion = w.OnVersion
	w.Conn.OnList = w.OnList
	w.Conn.OnQuery = w.OnQuery
	w.Conn.OnAncestry = w.OnAncestry
	w.Conn.OnLeavesOf = w.OnLeavesOf
	w.Conn.OnResponse = w.OnResponse
	w.Conn.OnSubscribe = w.OnSubscribe
	w.Conn.OnUnsubscribe = w.OnUnsubscribe
	w.Conn.OnStatus = w.OnStatus
	w.Conn.OnAnnounce = w.OnAnnounce
	return w, nil
}

// Run processes messages from the connection until reading one fails or the
// Done channel is observed to be ready. While running, the Worker is
// subscribed to new messages in the store with HandleNewNode as the callback.
//
// On exit the deferred steps run in reverse order of registration: the store
// subscription is removed, "Shutting down" is logged, and finally the
// underlying connection is closed.
func (c *Worker) Run() {
	closeConn := func() {
		closeErr := c.Conn.Conn.Close()
		if closeErr == nil {
			c.Printf("Closed network connection")
			return
		}
		c.Printf("Failed closing connection: %v", closeErr)
	}
	defer closeConn()
	defer c.Printf("Shutting down")

	c.subscriptionID = c.MessageStore.SubscribeToNewMessages(c.HandleNewNode)
	defer c.MessageStore.UnsubscribeToNewMessages(c.subscriptionID)

	for {
		readErr := c.ReadMessage()
		if readErr != nil {
			c.Printf("failed to read sprout message: %v", readErr)
			return
		}
		// Done is only polled here, after a message has been handled; the
		// default case keeps the check non-blocking.
		select {
		case <-c.Done:
			c.Printf("Done channel closed")
			return
		default:
		}
	}
}

// HandleNewNode is the callback registered with the MessageStore in Run. It
// is invoked with each new node and decides whether to announce that node to
// the connected peer:
//
//   - identities and communities are never announced;
//   - replies are announced only when the Session is subscribed to the
//     reply's community;
//   - any other node type is logged and ignored.
//
// A failure to send the announcement is logged, not returned. All logging
// goes through the Worker's own logger, like the rest of the Worker.
func (c *Worker) HandleNewNode(node forest.Node) {
	c.Printf("Got new node: %v", node)
	switch n := node.(type) {
	case *forest.Identity:
		// shouldn't just announce random user ids unsolicited
	case *forest.Community:
		// maybe we should announce new communities?
	case *forest.Reply:
		if !c.IsSubscribed(&n.CommunityID) {
			return
		}
		if _, err := c.SendAnnounce([]forest.Node{n}); err != nil {
			c.Printf("Error announcing new reply: %v", err)
		}
	default:
		c.Printf("Unknown node type: %T", n)
	}
}

// OnVersion handles a version message from the peer. The peer's major
// version is compared with sprout.CurrentMajor and a status is sent back for
// messageID: ErrorProtocolTooOld when it is lower, ErrorProtocolTooNew when
// it is higher, and StatusOk when it matches. The minor version is only
// logged. An error is returned only when sending the status fails.
func (c *Worker) OnVersion(s *sprout.Conn, messageID sprout.MessageID, major, minor int) error {
	c.Printf("Received version: id:%d major:%d minor:%d", messageID, major, minor)
	switch {
	case major < sprout.CurrentMajor:
		if sendErr := s.SendStatus(messageID, sprout.ErrorProtocolTooOld); sendErr != nil {
			return fmt.Errorf("Failed to send protocol too old message: %w", sendErr)
		}
	case major > sprout.CurrentMajor:
		if sendErr := s.SendStatus(messageID, sprout.ErrorProtocolTooNew); sendErr != nil {
			return fmt.Errorf("Failed to send protocol too new message: %w", sendErr)
		}
	default:
		if sendErr := s.SendStatus(messageID, sprout.StatusOk); sendErr != nil {
			return fmt.Errorf("Failed to send okay message: %w", sendErr)
		}
	}
	return nil
}

// OnList handles a list request by asking the store for recent nodes of
// nodeType, passing quantity through, and sending them back as the response
// to messageID. A store failure is returned wrapped; otherwise the result of
// sending the response is returned.
func (c *Worker) OnList(s *sprout.Conn, messageID sprout.MessageID, nodeType fields.NodeType, quantity int) error {
	// requires better iteration on Store types
	nodes, err := c.MessageStore.Recent(nodeType, quantity)
	if err == nil {
		return s.SendResponse(messageID, nodes)
	}
	return fmt.Errorf("error listing recent nodes of type %d: %w", nodeType, err)
}

// OnQuery handles a query request. Each requested id is looked up in the
// store; the nodes that are present are collected in request order and sent
// as the response to messageID, while ids the store does not have are left
// out. A store lookup failure aborts the request and is returned wrapped.
func (c *Worker) OnQuery(s *sprout.Conn, messageID sprout.MessageID, nodeIds []*fields.QualifiedHash) error {
	found := make([]forest.Node, 0, len(nodeIds))
	for i := range nodeIds {
		node, ok, err := c.MessageStore.Get(nodeIds[i])
		if err != nil {
			return fmt.Errorf("failed checking for node %v in store: %w", nodeIds[i], err)
		}
		if !ok {
			continue
		}
		found = append(found, node)
	}
	return s.SendResponse(messageID, found)
}

// OnAncestry handles an ancestry request. Starting from nodeID it follows
// parent ids upward for at most levels steps and sends the ancestors found,
// nearest first, as the response to messageID. The walk stops early when a
// node's parent id equals the null hash or when the parent is not in the
// store.
//
// An error is returned if nodeID itself is not in the store or if a store
// lookup fails.
func (c *Worker) OnAncestry(s *sprout.Conn, messageID sprout.MessageID, nodeID *fields.QualifiedHash, levels int) error {
	start, found, err := c.MessageStore.Get(nodeID)
	if err != nil {
		return fmt.Errorf("failed looking for node %v: %w", nodeID, err)
	}
	if !found {
		return fmt.Errorf("asked for ancestry of unknown node %v", nodeID)
	}

	lineage := make([]forest.Node, 0, 1024)
	cursor := start
	for remaining := levels; remaining > 0; remaining-- {
		parentID := cursor.ParentID()
		if parentID.Equals(fields.NullHash()) {
			// no parent, we're done
			break
		}
		parent, ok, err := c.MessageStore.Get(parentID)
		if err != nil {
			return fmt.Errorf("couldn't look up node with id %v (parent of %v): %w", parentID, cursor.ID(), err)
		}
		if !ok {
			// we don't know any more ancestry, so we're done
			break
		}
		lineage = append(lineage, parent)
		cursor = parent
	}
	return s.SendResponse(messageID, lineage)
}

// OnLeavesOf handles a leaves_of request. It walks the tree rooted at nodeID
// breadth-first using the store's Children lookup and sends the leaf nodes it
// finds as the response to messageID.
//
// An id with no known children is treated as a leaf; this includes nodeID
// itself. Ids for which the store holds no node are skipped. At most quantity
// leaves are returned, so a quantity of zero or less yields an empty
// response.
//
// An error is returned if a store lookup fails.
func (c *Worker) OnLeavesOf(s *sprout.Conn, messageID sprout.MessageID, nodeID *fields.QualifiedHash, quantity int) error {
	pending := make([]*fields.QualifiedHash, 0, 1024)
	pending = append(pending, nodeID)
	leaves := make([]forest.Node, 0, 1024)
	for len(pending) > 0 && len(leaves) < quantity {
		// Pop the head of the queue up front so that every path through
		// the loop body makes progress and the walk always terminates.
		current := pending[0]
		pending = pending[1:]

		children, err := c.MessageStore.store.Children(current)
		if err != nil {
			return fmt.Errorf("failed fetching children for %v: %w", current, err)
		}
		if len(children) > 0 {
			// Interior node: keep descending.
			pending = append(pending, children...)
			continue
		}

		node, has, err := c.MessageStore.Get(current)
		if err != nil {
			return fmt.Errorf("failed fetching node for %v: %w", current, err)
		}
		if !has {
			// The id is known to the child index but the node itself is
			// missing from the store; there is nothing to send for it.
			continue
		}
		leaves = append(leaves, node)
	}
	return s.SendResponse(messageID, leaves)
}

// OnResponse handles a response message by adding each node it carries to
// the store, in order, tagged with this Worker's subscription id. The first
// failed Add stops processing and is returned wrapped. target is not used.
func (c *Worker) OnResponse(s *sprout.Conn, target sprout.MessageID, nodes []forest.Node) error {
	for i := range nodes {
		err := c.MessageStore.Add(nodes[i], c.subscriptionID)
		if err == nil {
			continue
		}
		return fmt.Errorf("failed to add node to store: %w", err)
	}
	return nil
}

// OnSubscribe handles a subscribe request: nodeID is recorded via Subscribe
// and an okay status is sent back for messageID. If sending the status fails,
// the failure is returned wrapped first as a send failure and then as a
// subscribe failure.
func (c *Worker) OnSubscribe(s *sprout.Conn, messageID sprout.MessageID, nodeID *fields.QualifiedHash) (err error) {
	c.Subscribe(nodeID)
	sendErr := s.SendStatus(messageID, sprout.StatusOk)
	if sendErr == nil {
		return nil
	}
	// Wrap twice so the resulting chain and message carry both layers of
	// context.
	statusErr := fmt.Errorf("Failed to send okay status: %w", sendErr)
	return fmt.Errorf("Error during subscribe: %w", statusErr)
}

// OnUnsubscribe handles an unsubscribe request: nodeID is removed via
// Unsubscribe and an okay status is sent back for messageID. If sending the
// status fails, the failure is returned wrapped first as a send failure and
// then as an unsubscribe failure.
func (c *Worker) OnUnsubscribe(s *sprout.Conn, messageID sprout.MessageID, nodeID *fields.QualifiedHash) (err error) {
	c.Unsubscribe(nodeID)
	sendErr := s.SendStatus(messageID, sprout.StatusOk)
	if sendErr == nil {
		return nil
	}
	// Wrap twice so the resulting chain and message carry both layers of
	// context.
	statusErr := fmt.Errorf("Failed to send okay status: %w", sendErr)
	return fmt.Errorf("Error during unsubscribe: %w", statusErr)
}

// OnStatus handles a status message from the peer. The status code and the
// id of the message it refers to are logged; nothing else is done and the
// returned error is always nil.
func (c *Worker) OnStatus(s *sprout.Conn, messageID sprout.MessageID, code sprout.StatusCode) error {
	c.Logger.Printf("Received status %d for message %d", code, messageID)
	return nil
}

// OnAnnounce handles an announce message from the peer. Nodes are processed
// in order:
//
//   - identities and communities are always added to the store;
//   - replies are added only when the Session is subscribed to the reply's
//     community, and are otherwise skipped;
//   - any other node type is an error.
//
// Processing stops at the first failure, which is returned wrapped, so an
// early error can no longer be masked by a later node that succeeds. When
// every node is handled, an okay status is sent for messageID and the result
// of that send is returned.
func (c *Worker) OnAnnounce(s *sprout.Conn, messageID sprout.MessageID, nodes []forest.Node) error {
	for _, node := range nodes {
		switch n := node.(type) {
		case *forest.Identity, *forest.Community:
			// Always stored.
		case *forest.Reply:
			if !c.Session.IsSubscribed(&n.CommunityID) {
				// Not subscribed to this community; ignore the reply.
				continue
			}
		default:
			return fmt.Errorf("Failed handling announce node: %w", fmt.Errorf("Unknown node type announced: %T", node))
		}
		if err := c.MessageStore.Add(node, c.subscriptionID); err != nil {
			return fmt.Errorf("Failed handling announce node: %w", err)
		}
	}
	return s.SendStatus(messageID, sprout.StatusOk)
}