From 4b42cdd3896ad988e58ba043378c515b437ebc6a Mon Sep 17 00:00:00 2001 From: Chris Waldon Date: Sun, 26 Feb 2023 10:55:04 -0500 Subject: [PATCH] cmd/rosebud: refactor tx service to extract helper type Signed-off-by: Chris Waldon --- cmd/rosebud/tx-service.go | 217 +++++++++++++++++++++++--------------- 1 file changed, 133 insertions(+), 84 deletions(-) diff --git a/cmd/rosebud/tx-service.go b/cmd/rosebud/tx-service.go index 5aa705e..66e7eed 100644 --- a/cmd/rosebud/tx-service.go +++ b/cmd/rosebud/tx-service.go @@ -15,20 +15,117 @@ import ( "git.sr.ht/~whereswaldon/ledger" ) -type TxService struct { - txLock sync.Mutex +// StreamProvider helps contruct streams that emit a single, shared value. +// S is an internal state type, protected by the stream provider's lock, +// that is used to track the current state of the provider. +// T is the result type emitted on streams from this provider. +type StreamProvider[S, T any] struct { + lock sync.Mutex + broadcast *sync.Cond + value S + valuer func(S) (stream.Result[T], bool) +} + +// NewStreamProvider constructs a stream provider using the provided +// valuer function to transform its current state (S) into a stream result. +// The boolean return value indicates whether the result should be emitted +// over the stream or discarded. The valuer should deep copy all data it uses +// in T to ensure that other invocations of valuer do not reference the +// same memory. +func NewStreamProvider[S, T any](valuer func(S) (stream.Result[T], bool)) *StreamProvider[S, T] { + sp := &StreamProvider[S, T]{ + valuer: valuer, + } + sp.broadcast = sync.NewCond(&sp.lock) + return sp +} + +// Update runs fn with the provider's lock held, passing the current +// state to fn and setting the state to the return value of fn. +func (s *StreamProvider[S, T]) Update(fn func(oldState S) S) { + s.lock.Lock() + defer s.lock.Unlock() + s.value = fn(s.value) + s.broadcast.Broadcast() +} + +// newValue returns the result of the valuer function for S's current +// value. +func (s *StreamProvider[S, T]) newValue() (stream.Result[T], bool) { + s.lock.Lock() + defer s.lock.Unlock() + return s.valuer(s.value) +} + +// Stream returns a stream bound to the lifetime of the provided context. +// Each time Update is invoked on s, all streams will run s's valuer function +// and will emit the results if the valuer's second return value was true. +func (s *StreamProvider[S, T]) Stream(ctx context.Context) <-chan stream.Result[T] { + out := make(chan stream.Result[T]) + broadcasted := make(chan struct{}) + go func() { + defer close(broadcasted) + for { + s.lock.Lock() + s.broadcast.Wait() + select { + case <-ctx.Done(): + s.lock.Unlock() + return + case broadcasted <- struct{}{}: + s.lock.Unlock() + } + } + }() + go func() { + defer close(out) + emit := out + emitVal, shouldEmit := s.newValue() + if !shouldEmit { + emit = nil + } + for { + select { + case <-ctx.Done(): + return + case emit <- emitVal: + emit = nil + case <-broadcasted: + emit = out + emitVal, shouldEmit = s.newValue() + if !shouldEmit { + emit = nil + } + } + } + }() + return out +} + +type txServiceState struct { loaded bool transactions []*ledger.Transaction loadErr error txFile string +} - broadcast *sync.Cond +type TxService struct { + provider *StreamProvider[txServiceState, []*ledger.Transaction] } func NewTxService() *TxService { - tx := &TxService{} - tx.broadcast = sync.NewCond(&tx.txLock) - return tx + return &TxService{ + provider: NewStreamProvider(func(state txServiceState) (stream.Result[[]*ledger.Transaction], bool) { + txs := make([]*ledger.Transaction, len(state.transactions)) + for i, tx := range state.transactions { + txCopy := *tx + txs[i] = &(txCopy) + } + loadErr := state.loadErr + loaded := state.loaded + return stream.ResultFrom(txs, loadErr), loaded + }), + } } // OpenLedger reconfigures the service using the contents of the provided @@ -36,9 +133,10 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) { var err error defer func() { if err != nil { - t.txLock.Lock() - t.loadErr = err - t.txLock.Unlock() + t.provider.Update(func(state txServiceState) txServiceState { + state.loadErr = err + return state + }) } }() defer func() { @@ -92,85 +190,36 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) { err = fmt.Errorf("failed parsing input file: %w", err) return } - t.txLock.Lock() - defer t.txLock.Unlock() - t.transactions = transactions - t.txFile = outFilePath - t.loadErr = nil - t.broadcast.Broadcast() - t.loaded = true + t.provider.Update(func(state txServiceState) txServiceState { + return txServiceState{ + transactions: transactions, + txFile: outFilePath, + loadErr: nil, + loaded: true, + } + }) } func (t *TxService) UpdateLedger(txs []*ledger.Transaction) { - t.txLock.Lock() - defer t.txLock.Unlock() - var b bytes.Buffer - for _, tx := range txs { - b.Write(toTxBytes(*tx)) - } - if err := os.WriteFile(t.txFile, b.Bytes(), 0o644); err != nil { - t.loadErr = fmt.Errorf("failed rewriting input file: %w", err) - return - } - transactions, err := ledger.ParseLedgerFile(t.txFile) - if err != nil { - t.loadErr = fmt.Errorf("failed parsing input file: %w", err) - return - } - t.transactions = transactions - t.broadcast.Broadcast() -} - -func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] { - out := make(chan stream.Result[[]*ledger.Transaction]) - broadcasted := make(chan struct{}) - newValue := func() (stream.Result[[]*ledger.Transaction], bool) { - t.txLock.Lock() - txs := make([]*ledger.Transaction, len(t.transactions)) - for i, tx := range t.transactions { - txCopy := *tx - txs[i] = &(txCopy) + t.provider.Update(func(state txServiceState) txServiceState { + var b bytes.Buffer + for _, tx := range txs { + b.Write(toTxBytes(*tx)) } - loadErr := t.loadErr - loaded := t.loaded - t.txLock.Unlock() - return stream.ResultFrom(txs, loadErr), loaded - } - go func() { - defer close(broadcasted) - for { - t.txLock.Lock() - t.broadcast.Wait() - select { - case <-ctx.Done(): - t.txLock.Unlock() - return - case broadcasted <- struct{}{}: - t.txLock.Unlock() - } + if err := os.WriteFile(state.txFile, b.Bytes(), 0o644); err != nil { + state.loadErr = fmt.Errorf("failed rewriting input file: %w", err) + return state } - }() - go func() { - defer close(out) - emit := out - emitVal, shouldEmit := newValue() - if !shouldEmit { - emit = nil - } - for { - select { - case <-ctx.Done(): - return - case emit <- emitVal: - emit = nil - case <-broadcasted: - emit = out - emitVal, shouldEmit = newValue() - if !shouldEmit { - emit = nil - } - } + transactions, err := ledger.ParseLedgerFile(state.txFile) + if err != nil { + state.loadErr = fmt.Errorf("failed parsing input file: %w", err) + return state } - }() - return out + state.transactions = transactions + return state + }) +} + +func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] { + return t.provider.Stream(ctx) } -- 2.45.2