@@ 15,20 15,117 @@ import (
"git.sr.ht/~whereswaldon/ledger"
)
-type TxService struct {
- txLock sync.Mutex
+// StreamProvider helps contruct streams that emit a single, shared value.
+// S is an internal state type, protected by the stream provider's lock,
+// that is used to track the current state of the provider.
+// T is the result type emitted on streams from this provider.
+type StreamProvider[S, T any] struct {
+ lock sync.Mutex
+ broadcast *sync.Cond
+ value S
+ valuer func(S) (stream.Result[T], bool)
+}
+
+// NewStreamProvider constructs a stream provider using the provided
+// valuer function to transform its current state (S) into a stream result.
+// The boolean return value indicates whether the result should be emitted
+// over the stream or discarded. The valuer should deep copy all data it uses
+// in T to ensure that other invocations of valuer do not reference the
+// same memory.
+func NewStreamProvider[S, T any](valuer func(S) (stream.Result[T], bool)) *StreamProvider[S, T] {
+ sp := &StreamProvider[S, T]{
+ valuer: valuer,
+ }
+ sp.broadcast = sync.NewCond(&sp.lock)
+ return sp
+}
+
+// Update runs fn with the provider's lock held, passing the current
+// state to fn and setting the state to the return value of fn.
+func (s *StreamProvider[S, T]) Update(fn func(oldState S) S) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ s.value = fn(s.value)
+ s.broadcast.Broadcast()
+}
+
+// newValue returns the result of the valuer function for S's current
+// value.
+func (s *StreamProvider[S, T]) newValue() (stream.Result[T], bool) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ return s.valuer(s.value)
+}
+
+// Stream returns a stream bound to the lifetime of the provided context.
+// Each time Update is invoked on s, all streams will run s's valuer function
+// and will emit the results if the valuer's second return value was true.
+func (s *StreamProvider[S, T]) Stream(ctx context.Context) <-chan stream.Result[T] {
+ out := make(chan stream.Result[T])
+ broadcasted := make(chan struct{})
+ go func() {
+ defer close(broadcasted)
+ for {
+ s.lock.Lock()
+ s.broadcast.Wait()
+ select {
+ case <-ctx.Done():
+ s.lock.Unlock()
+ return
+ case broadcasted <- struct{}{}:
+ s.lock.Unlock()
+ }
+ }
+ }()
+ go func() {
+ defer close(out)
+ emit := out
+ emitVal, shouldEmit := s.newValue()
+ if !shouldEmit {
+ emit = nil
+ }
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case emit <- emitVal:
+ emit = nil
+ case <-broadcasted:
+ emit = out
+ emitVal, shouldEmit = s.newValue()
+ if !shouldEmit {
+ emit = nil
+ }
+ }
+ }
+ }()
+ return out
+}
+
+type txServiceState struct {
loaded bool
transactions []*ledger.Transaction
loadErr error
txFile string
+}
- broadcast *sync.Cond
+type TxService struct {
+ provider *StreamProvider[txServiceState, []*ledger.Transaction]
}
func NewTxService() *TxService {
- tx := &TxService{}
- tx.broadcast = sync.NewCond(&tx.txLock)
- return tx
+ return &TxService{
+ provider: NewStreamProvider(func(state txServiceState) (stream.Result[[]*ledger.Transaction], bool) {
+ txs := make([]*ledger.Transaction, len(state.transactions))
+ for i, tx := range state.transactions {
+ txCopy := *tx
+ txs[i] = &(txCopy)
+ }
+ loadErr := state.loadErr
+ loaded := state.loaded
+ return stream.ResultFrom(txs, loadErr), loaded
+ }),
+ }
}
// OpenLedger reconfigures the service using the contents of the provided
@@ 36,9 133,10 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) {
var err error
defer func() {
if err != nil {
- t.txLock.Lock()
- t.loadErr = err
- t.txLock.Unlock()
+ t.provider.Update(func(state txServiceState) txServiceState {
+ state.loadErr = err
+ return state
+ })
}
}()
defer func() {
@@ 92,85 190,36 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) {
err = fmt.Errorf("failed parsing input file: %w", err)
return
}
- t.txLock.Lock()
- defer t.txLock.Unlock()
- t.transactions = transactions
- t.txFile = outFilePath
- t.loadErr = nil
- t.broadcast.Broadcast()
- t.loaded = true
+ t.provider.Update(func(state txServiceState) txServiceState {
+ return txServiceState{
+ transactions: transactions,
+ txFile: outFilePath,
+ loadErr: nil,
+ loaded: true,
+ }
+ })
}
func (t *TxService) UpdateLedger(txs []*ledger.Transaction) {
- t.txLock.Lock()
- defer t.txLock.Unlock()
- var b bytes.Buffer
- for _, tx := range txs {
- b.Write(toTxBytes(*tx))
- }
- if err := os.WriteFile(t.txFile, b.Bytes(), 0o644); err != nil {
- t.loadErr = fmt.Errorf("failed rewriting input file: %w", err)
- return
- }
- transactions, err := ledger.ParseLedgerFile(t.txFile)
- if err != nil {
- t.loadErr = fmt.Errorf("failed parsing input file: %w", err)
- return
- }
- t.transactions = transactions
- t.broadcast.Broadcast()
-}
-
-func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] {
- out := make(chan stream.Result[[]*ledger.Transaction])
- broadcasted := make(chan struct{})
- newValue := func() (stream.Result[[]*ledger.Transaction], bool) {
- t.txLock.Lock()
- txs := make([]*ledger.Transaction, len(t.transactions))
- for i, tx := range t.transactions {
- txCopy := *tx
- txs[i] = &(txCopy)
+ t.provider.Update(func(state txServiceState) txServiceState {
+ var b bytes.Buffer
+ for _, tx := range txs {
+ b.Write(toTxBytes(*tx))
}
- loadErr := t.loadErr
- loaded := t.loaded
- t.txLock.Unlock()
- return stream.ResultFrom(txs, loadErr), loaded
- }
- go func() {
- defer close(broadcasted)
- for {
- t.txLock.Lock()
- t.broadcast.Wait()
- select {
- case <-ctx.Done():
- t.txLock.Unlock()
- return
- case broadcasted <- struct{}{}:
- t.txLock.Unlock()
- }
+ if err := os.WriteFile(state.txFile, b.Bytes(), 0o644); err != nil {
+ state.loadErr = fmt.Errorf("failed rewriting input file: %w", err)
+ return state
}
- }()
- go func() {
- defer close(out)
- emit := out
- emitVal, shouldEmit := newValue()
- if !shouldEmit {
- emit = nil
- }
- for {
- select {
- case <-ctx.Done():
- return
- case emit <- emitVal:
- emit = nil
- case <-broadcasted:
- emit = out
- emitVal, shouldEmit = newValue()
- if !shouldEmit {
- emit = nil
- }
- }
+ transactions, err := ledger.ParseLedgerFile(state.txFile)
+ if err != nil {
+ state.loadErr = fmt.Errorf("failed parsing input file: %w", err)
+ return state
}
- }()
- return out
+ state.transactions = transactions
+ return state
+ })
+}
+
+func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] {
+ return t.provider.Stream(ctx)
}