~whereswaldon/rosebud

4b42cdd3896ad988e58ba043378c515b437ebc6a — Chris Waldon 9 months ago ac25ca3 main
cmd/rosebud: refactor tx service to extract helper type

Signed-off-by: Chris Waldon <christopher.waldon.dev@gmail.com>
1 files changed, 133 insertions(+), 84 deletions(-)

M cmd/rosebud/tx-service.go
M cmd/rosebud/tx-service.go => cmd/rosebud/tx-service.go +133 -84
@@ 15,20 15,117 @@ import (
	"git.sr.ht/~whereswaldon/ledger"
)

type TxService struct {
	txLock       sync.Mutex
// StreamProvider helps contruct streams that emit a single, shared value.
// S is an internal state type, protected by the stream provider's lock,
// that is used to track the current state of the provider.
// T is the result type emitted on streams from this provider.
type StreamProvider[S, T any] struct {
	lock      sync.Mutex
	broadcast *sync.Cond
	value     S
	valuer    func(S) (stream.Result[T], bool)
}

// NewStreamProvider constructs a stream provider using the provided
// valuer function to transform its current state (S) into a stream result.
// The boolean return value indicates whether the result should be emitted
// over the stream or discarded. The valuer should deep copy all data it uses
// in T to ensure that other invocations of valuer do not reference the
// same memory.
func NewStreamProvider[S, T any](valuer func(S) (stream.Result[T], bool)) *StreamProvider[S, T] {
	sp := &StreamProvider[S, T]{
		valuer: valuer,
	}
	sp.broadcast = sync.NewCond(&sp.lock)
	return sp
}

// Update runs fn with the provider's lock held, passing the current
// state to fn and setting the state to the return value of fn.
func (s *StreamProvider[S, T]) Update(fn func(oldState S) S) {
	s.lock.Lock()
	defer s.lock.Unlock()
	s.value = fn(s.value)
	s.broadcast.Broadcast()
}

// newValue returns the result of the valuer function for S's current
// value.
func (s *StreamProvider[S, T]) newValue() (stream.Result[T], bool) {
	s.lock.Lock()
	defer s.lock.Unlock()
	return s.valuer(s.value)
}

// Stream returns a stream bound to the lifetime of the provided context.
// Each time Update is invoked on s, all streams will run s's valuer function
// and will emit the results if the valuer's second return value was true.
func (s *StreamProvider[S, T]) Stream(ctx context.Context) <-chan stream.Result[T] {
	out := make(chan stream.Result[T])
	broadcasted := make(chan struct{})
	go func() {
		defer close(broadcasted)
		for {
			s.lock.Lock()
			s.broadcast.Wait()
			select {
			case <-ctx.Done():
				s.lock.Unlock()
				return
			case broadcasted <- struct{}{}:
				s.lock.Unlock()
			}
		}
	}()
	go func() {
		defer close(out)
		emit := out
		emitVal, shouldEmit := s.newValue()
		if !shouldEmit {
			emit = nil
		}
		for {
			select {
			case <-ctx.Done():
				return
			case emit <- emitVal:
				emit = nil
			case <-broadcasted:
				emit = out
				emitVal, shouldEmit = s.newValue()
				if !shouldEmit {
					emit = nil
				}
			}
		}
	}()
	return out
}

type txServiceState struct {
	loaded       bool
	transactions []*ledger.Transaction
	loadErr      error
	txFile       string
}

	broadcast *sync.Cond
type TxService struct {
	provider *StreamProvider[txServiceState, []*ledger.Transaction]
}

func NewTxService() *TxService {
	tx := &TxService{}
	tx.broadcast = sync.NewCond(&tx.txLock)
	return tx
	return &TxService{
		provider: NewStreamProvider(func(state txServiceState) (stream.Result[[]*ledger.Transaction], bool) {
			txs := make([]*ledger.Transaction, len(state.transactions))
			for i, tx := range state.transactions {
				txCopy := *tx
				txs[i] = &(txCopy)
			}
			loadErr := state.loadErr
			loaded := state.loaded
			return stream.ResultFrom(txs, loadErr), loaded
		}),
	}
}

// OpenLedger reconfigures the service using the contents of the provided


@@ 36,9 133,10 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) {
	var err error
	defer func() {
		if err != nil {
			t.txLock.Lock()
			t.loadErr = err
			t.txLock.Unlock()
			t.provider.Update(func(state txServiceState) txServiceState {
				state.loadErr = err
				return state
			})
		}
	}()
	defer func() {


@@ 92,85 190,36 @@ func (t *TxService) OpenLedger(contents io.ReadCloser) {
		err = fmt.Errorf("failed parsing input file: %w", err)
		return
	}
	t.txLock.Lock()
	defer t.txLock.Unlock()
	t.transactions = transactions
	t.txFile = outFilePath
	t.loadErr = nil
	t.broadcast.Broadcast()
	t.loaded = true
	t.provider.Update(func(state txServiceState) txServiceState {
		return txServiceState{
			transactions: transactions,
			txFile:       outFilePath,
			loadErr:      nil,
			loaded:       true,
		}
	})
}

func (t *TxService) UpdateLedger(txs []*ledger.Transaction) {
	t.txLock.Lock()
	defer t.txLock.Unlock()
	var b bytes.Buffer
	for _, tx := range txs {
		b.Write(toTxBytes(*tx))
	}
	if err := os.WriteFile(t.txFile, b.Bytes(), 0o644); err != nil {
		t.loadErr = fmt.Errorf("failed rewriting input file: %w", err)
		return
	}
	transactions, err := ledger.ParseLedgerFile(t.txFile)
	if err != nil {
		t.loadErr = fmt.Errorf("failed parsing input file: %w", err)
		return
	}
	t.transactions = transactions
	t.broadcast.Broadcast()
}

func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] {
	out := make(chan stream.Result[[]*ledger.Transaction])
	broadcasted := make(chan struct{})
	newValue := func() (stream.Result[[]*ledger.Transaction], bool) {
		t.txLock.Lock()
		txs := make([]*ledger.Transaction, len(t.transactions))
		for i, tx := range t.transactions {
			txCopy := *tx
			txs[i] = &(txCopy)
	t.provider.Update(func(state txServiceState) txServiceState {
		var b bytes.Buffer
		for _, tx := range txs {
			b.Write(toTxBytes(*tx))
		}
		loadErr := t.loadErr
		loaded := t.loaded
		t.txLock.Unlock()
		return stream.ResultFrom(txs, loadErr), loaded
	}
	go func() {
		defer close(broadcasted)
		for {
			t.txLock.Lock()
			t.broadcast.Wait()
			select {
			case <-ctx.Done():
				t.txLock.Unlock()
				return
			case broadcasted <- struct{}{}:
				t.txLock.Unlock()
			}
		if err := os.WriteFile(state.txFile, b.Bytes(), 0o644); err != nil {
			state.loadErr = fmt.Errorf("failed rewriting input file: %w", err)
			return state
		}
	}()
	go func() {
		defer close(out)
		emit := out
		emitVal, shouldEmit := newValue()
		if !shouldEmit {
			emit = nil
		}
		for {
			select {
			case <-ctx.Done():
				return
			case emit <- emitVal:
				emit = nil
			case <-broadcasted:
				emit = out
				emitVal, shouldEmit = newValue()
				if !shouldEmit {
					emit = nil
				}
			}
		transactions, err := ledger.ParseLedgerFile(state.txFile)
		if err != nil {
			state.loadErr = fmt.Errorf("failed parsing input file: %w", err)
			return state
		}
	}()
	return out
		state.transactions = transactions
		return state
	})
}

func (t *TxService) TxStream(ctx context.Context) <-chan stream.Result[[]*ledger.Transaction] {
	return t.provider.Stream(ctx)
}