A => errgroupcount.go +126 -0
@@ 1,126 @@
+// Why? Isn't errgroup.Group{} enough?
+//
+// No.
+//
+// It has happened for too many times now where I have to start up a list of
+// goroutines, have to keep track of their errors, but don't care about ALL
+// their errors. So I've made this. A canonical example is something like the
+// following (this is taken from my own personal blog). The requirements are:
+// (1) Given a "slug" value we check to see if there's a piece of content in the
+// database of with a "slug" of that value.
+// (2) The piece of content can either have a type of "page" or "post"
+// (3) We only want to return an http.StatusNotFound page if we cannot find a
+// "page" or "post" with the slug value. E.G. We only care if we hit two errors
+// (not found) instead of just one (as only the one case is supported by the
+// normal errgroup.
+//
+// Aside: for this personal case about "slug" clashing is not cared about.
+//
+// Example:
+//
+// func (s SearchEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+// slug := s.Param(r, "slug")
+// eg, _ := errgroupcount.WithContext(r.Context())
+// once := sync.Once{}
+//
+// eg.Go(func() error {
+// item, err := s.content.Search("page", slug)
+// if err != nil {
+// return err
+// }
+// once.Do(func() {
+// s.HTML(w, r, http.StatusOK, "item.html", map[string]interface{}{"Item": item})
+// })
+// return nil
+// })
+//
+// eg.Go(func() error {
+// item, err := s.content.Search("post", slug)
+// if err != nil {
+// return err
+// }
+// once.Do(func() {
+// s.HTML(w, r, http.StatusOK, "item.html", map[string]interface{}{"Item": item})
+// })
+// return nil
+// })
+//
+// if err := eg.WaitCount(2); err != nil {
+// s.ErrorString(w, r, http.StatusNotFound, "failed to find content")
+// }
+// }
+//
+// Enjoy.
+//
+package errgroupcount
+
+import (
+ "context"
+ "sync"
+)
+
+// Package errgroupcount provides synchronization, error propagation, and Context
+// cancelation for groups of goroutines working on subtasks of a common task.
+//
+// TYPES
+//
+// A Group is a collection of goroutines working on subtasks that are part of
+// the same overall task.
+//
+// A zero Group is valid and does not cancel on error.
+type Group struct {
+ cancel func()
+
+ wg sync.WaitGroup
+
+ errMutex sync.Mutex
+ err []error
+}
+
+// WithContext returns a new Group and an associated Context derived from ctx.
+//
+// The derived Context is canceled the first time a function passed to Go
+// returns a non-nil error or the first time Wait returns, whichever occurs
+// first.
+func WithContext(ctx context.Context) (*Group, context.Context) {
+ ctx, cancel := context.WithCancel(ctx)
+ return &Group{cancel: cancel}, ctx
+}
+
+// Go calls the given function in a new goroutine.
+//
+// The first call to return a non-nil error cancels the group; its error will be
+// returned by Wait.
+func (g *Group) Go(f func() error) {
+ g.wg.Add(1)
+
+ go func() {
+ defer g.wg.Done()
+
+ if err := f(); err != nil {
+ g.errMutex.Lock()
+ g.err = append(g.err, err)
+ if g.cancel != nil {
+ g.cancel()
+ }
+ g.errMutex.Unlock()
+ }
+ }()
+}
+
+// WaitCount blocks until all function calls from the Go method have returned,
+// then will pop shift and return the first error received if and only if the
+// total error count is greater than the int provided to WaitCount. If the total
+// error count is less than the value supplied to WaitCount nil is returned.
+// `top int` must be greated than zero, otherise nil is always returned.
+func (g *Group) WaitCount(top int) (err error) {
+ g.wg.Wait()
+ if g.cancel != nil {
+ g.cancel()
+ }
+ if top > 0 && len(g.err) >= top {
+ err, g.err = g.err[0], g.err[1:]
+ return err
+ }
+ return nil
+}
A => errgroupcount_test.go +106 -0
@@ 1,106 @@
+package errgroupcount_test
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "git.evanjon.es/i/errgroupcount"
+ "github.com/stretchr/testify/assert"
+)
+
+var (
+ ErrOne = errors.New("this is the first error")
+ ErrTwo = errors.New("this is the second error")
+)
+
+func TestBasic(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return nil })
+ eg.Go(func() error { return nil })
+ eg.Go(func() error { return nil })
+ assert.Equal(t, nil, eg.WaitCount(3))
+}
+
+func TestZero(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ assert.Equal(t, nil, eg.WaitCount(0))
+ assert.Equal(t, nil, eg.WaitCount(-1))
+}
+
+func TestBasicError(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ assert.Equal(t, ErrOne, eg.WaitCount(1))
+}
+
+func TestBasicFirst(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error {
+ time.Sleep(time.Millisecond * 100) // We can assume eg obtains ErrOne first.
+ return ErrTwo
+ })
+ assert.Equal(t, ErrOne, eg.WaitCount(1))
+}
+
+func TestIsError(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return nil })
+ assert.Equal(t, ErrOne, eg.WaitCount(1))
+}
+
+func TestNoError(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return nil })
+ assert.Equal(t, nil, eg.WaitCount(2))
+}
+
+func TestIsErrorTwo(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ assert.Equal(t, ErrOne, eg.WaitCount(7))
+}
+
+func TestIsErrorThree(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ eg.Go(func() error { return ErrOne })
+ assert.Equal(t, nil, eg.WaitCount(8))
+}
+
+func TestLots(t *testing.T) {
+ t.Parallel()
+ eg := errgroupcount.Group{}
+ top := 1000
+ for x := 0; x < top/2; x++ {
+ eg.Go(func() error { return ErrOne })
+ }
+ for x := 0; x < top/2; x++ {
+ eg.Go(func() error { return nil })
+ }
+ assert.Equal(t, ErrOne, eg.WaitCount(top/2))
+ assert.Equal(t, nil, eg.WaitCount(top/2+1))
+ assert.Equal(t, nil, eg.WaitCount(top))
+}
A => go.mod +5 -0
@@ 1,5 @@
+module git.evanjon.es/i/errgroupcount
+
+go 1.12
+
+require github.com/stretchr/testify v1.4.0
A => go.sum +11 -0
@@ 1,11 @@
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=