~evanj/errgroupcount

45e72d8ba8fcca5ae47d06e98648bb2047cfadc6 — Evan M Jones 3 years ago
Feat(*): Project init. Breaking into its own module, outside
git.evanjon.es/i/mini-eggs-github-io.
4 files changed, 248 insertions(+), 0 deletions(-)

A errgroupcount.go
A errgroupcount_test.go
A go.mod
A go.sum
A  => errgroupcount.go +126 -0
@@ 1,126 @@
// Why? Isn't errgroup.Group{} enough?
//
// No.
//
// It has happened for too many times now where I have to start up a list of
// goroutines, have to keep track of their errors, but don't care about ALL
// their errors. So I've made this. A canonical example is something like the
// following (this is taken from my own personal blog). The requirements are:
// (1) Given a "slug" value we check to see if there's a piece of content in the
// database of with a "slug" of that value.
// (2) The piece of content can either have a type of "page" or "post"
// (3) We only want to return an http.StatusNotFound page if we cannot find a
// "page" or "post" with the slug value. E.G. We only care if we hit two errors
// (not found) instead of just one (as only the one case is supported by the
// normal errgroup.
//
// Aside: for this personal case about "slug" clashing is not cared about.
//
// Example:
//
// func (s SearchEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 	slug := s.Param(r, "slug")
// 	eg, _ := errgroupcount.WithContext(r.Context())
// 	once := sync.Once{}
//
// 	eg.Go(func() error {
// 		item, err := s.content.Search("page", slug)
// 		if err != nil {
// 			return err
// 		}
// 		once.Do(func() {
// 			s.HTML(w, r, http.StatusOK, "item.html", map[string]interface{}{"Item": item})
// 		})
// 		return nil
// 	})
//
// 	eg.Go(func() error {
// 		item, err := s.content.Search("post", slug)
// 		if err != nil {
// 			return err
// 		}
// 		once.Do(func() {
// 			s.HTML(w, r, http.StatusOK, "item.html", map[string]interface{}{"Item": item})
// 		})
// 		return nil
// 	})
//
// 	if err := eg.WaitCount(2); err != nil {
// 		s.ErrorString(w, r, http.StatusNotFound, "failed to find content")
// 	}
// }
//
// Enjoy.
//
package errgroupcount

import (
	"context"
	"sync"
)

// Package errgroupcount provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
//
// TYPES
//
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
	cancel func()

	wg sync.WaitGroup

	errMutex sync.Mutex
	err      []error
}

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
	ctx, cancel := context.WithCancel(ctx)
	return &Group{cancel: cancel}, ctx
}

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
	g.wg.Add(1)

	go func() {
		defer g.wg.Done()

		if err := f(); err != nil {
			g.errMutex.Lock()
			g.err = append(g.err, err)
			if g.cancel != nil {
				g.cancel()
			}
			g.errMutex.Unlock()
		}
	}()
}

// WaitCount blocks until all function calls from the Go method have returned,
// then will pop shift and return the first error received if and only if the
// total error count is greater than the int provided to WaitCount. If the total
// error count is less than the value supplied to WaitCount nil is returned.
// `top int` must be greated than zero, otherise nil is always returned.
func (g *Group) WaitCount(top int) (err error) {
	g.wg.Wait()
	if g.cancel != nil {
		g.cancel()
	}
	if top > 0 && len(g.err) >= top {
		err, g.err = g.err[0], g.err[1:]
		return err
	}
	return nil
}

A  => errgroupcount_test.go +106 -0
@@ 1,106 @@
package errgroupcount_test

import (
	"errors"
	"testing"
	"time"

	"git.evanjon.es/i/errgroupcount"
	"github.com/stretchr/testify/assert"
)

var (
	ErrOne = errors.New("this is the first error")
	ErrTwo = errors.New("this is the second error")
)

func TestBasic(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return nil })
	eg.Go(func() error { return nil })
	eg.Go(func() error { return nil })
	assert.Equal(t, nil, eg.WaitCount(3))
}

func TestZero(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	assert.Equal(t, nil, eg.WaitCount(0))
	assert.Equal(t, nil, eg.WaitCount(-1))
}

func TestBasicError(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	assert.Equal(t, ErrOne, eg.WaitCount(1))
}

func TestBasicFirst(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error {
		time.Sleep(time.Millisecond * 100) // We can assume eg obtains ErrOne first.
		return ErrTwo
	})
	assert.Equal(t, ErrOne, eg.WaitCount(1))
}

func TestIsError(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return nil })
	assert.Equal(t, ErrOne, eg.WaitCount(1))
}

func TestNoError(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return nil })
	assert.Equal(t, nil, eg.WaitCount(2))
}

func TestIsErrorTwo(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	assert.Equal(t, ErrOne, eg.WaitCount(7))
}

func TestIsErrorThree(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	eg.Go(func() error { return ErrOne })
	assert.Equal(t, nil, eg.WaitCount(8))
}

func TestLots(t *testing.T) {
	t.Parallel()
	eg := errgroupcount.Group{}
	top := 1000
	for x := 0; x < top/2; x++ {
		eg.Go(func() error { return ErrOne })
	}
	for x := 0; x < top/2; x++ {
		eg.Go(func() error { return nil })
	}
	assert.Equal(t, ErrOne, eg.WaitCount(top/2))
	assert.Equal(t, nil, eg.WaitCount(top/2+1))
	assert.Equal(t, nil, eg.WaitCount(top))
}

A  => go.mod +5 -0
@@ 1,5 @@
module git.evanjon.es/i/errgroupcount

go 1.12

require github.com/stretchr/testify v1.4.0

A  => go.sum +11 -0
@@ 1,11 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=