~damien/shopping

857f9fe3008e8a40a365d9e0fb00e5049c212915 — Damien Radtke 2 years ago 8d5c972
Various changes
M api/apiclient/apiclient.go => api/apiclient/apiclient.go +11 -6
@@ 1,6 1,8 @@
package apiclient

import (
	"errors"
	"log"
	"bytes"
	"context"
	"fmt"


@@ 23,7 25,7 @@ func (e Error) Error() string {

type Client struct {
	from string
	hc *http.Client
	hc   *http.Client
}

func New(from string) Client {


@@ 38,31 40,34 @@ func New(from string) Client {
}

func (c Client) Call(ctx context.Context, serviceMethod string, arg, retval interface{}) error {
	if c.hc == nil {
		return errors.New("apiclient: no HTTP client!")
	}
	b, err := json.EncodeClientRequest(serviceMethod, arg)
	if err != nil {
		return fmt.Errorf("callAPI: failed to encode request: %w", err)
		return fmt.Errorf("apiclient: failed to encode request: %w", err)
	}
	url, err := dns.LookupService(ctx, "api")
	if err != nil {
		return fmt.Errorf("callAPI: failed to look up api service: %w", err)
		return fmt.Errorf("apiclient: failed to look up api service: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://"+url, bytes.NewReader(b))
	if err != nil {
		return fmt.Errorf("callAPI: failed to build request: %w", err)
		return fmt.Errorf("apiclient: failed to build request: %w", err)
	}
	httputil.AddAuthHeader(c.from, req.Header)
	req.Header.Set(httputil.RequestIDHeader, httputil.RequestID(ctx).String())
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.hc.Do(req)
	if err != nil {
		return fmt.Errorf("callAPI: failed to call API: %w", err)
		return fmt.Errorf("apiclient: failed to call API: %w", err)
	}
	defer resp.Body.Close()
	if err := json.DecodeClientResponse(resp.Body, &retval); err != nil {
		if apierr, ok := err.(*json.Error); ok {
			return Error{Data: apierr.Data}
		}
		return fmt.Errorf("callAPI: failed to decode return value: %w", err)
		return fmt.Errorf("apiclient: failed to decode return value: %w", err)
	}
	return nil
}

M api/apiserver/apiserver.go => api/apiserver/apiserver.go +2 -8
@@ 2,7 2,6 @@
package apiserver

import (
	"context"
	"errors"
	"net/http"
	"time"


@@ 13,7 12,7 @@ import (
	"shopping.io/httputil"
)

func Server(addr string) *http.Server {
func NewServer(addr string) *http.Server {
	return &http.Server{
		Addr:    addr,
		Handler: Handler(),


@@ 21,8 20,6 @@ func Server(addr string) *http.Server {
}

func Handler() http.Handler {
	deps := NewDependencies(context.Background())

	s := rpc.NewServer()
	s.RegisterCodec(json.NewCodec(), "application/json")
	s.RegisterInterceptFunc(func(i *rpc.RequestInfo) *http.Request {


@@ 39,9 36,6 @@ func Handler() http.Handler {
		return nil
	})

	authService := Auth{Users: deps.Users, Secrets: deps.Secrets}
	if err := s.RegisterService(authService, ""); err != nil {
		panic(err)
	}
	registerServices(s)
	return s
}

M api/apiserver/auth.go => api/apiserver/auth.go +2 -5
@@ 5,18 5,15 @@ import (

	"shopping.io"
	"shopping.io/busi/auth"
	"shopping.io/db"
	"shopping.io/httputil"
	"shopping.io/secrets"
)

type Auth struct {
	Users   db.Users
	Secrets secrets.Client
	d Dependencies
}

func (s Auth) Login(r *http.Request, req *auth.LoginReq, userID *shopping.ID) error {
	id, err := auth.Login(r.Context(), s.Users, s.Secrets, *req)
	id, err := auth.Login(r.Context(), s.d.DB, s.d.Secrets, *req)
	if err != nil {
		return err
	}

R api/apiserver/deps.go => api/apiserver/dependencies.go +2 -2
@@ 10,13 10,13 @@ import (
)

type Dependencies struct {
	Users   db.Users
	DB db.DB
	Secrets secrets.Client
}

func NewDependencies(ctx context.Context) Dependencies {
	return Dependencies{
		Users:   postgres.MustNewUsers(ctx, "db"),
		DB:   postgres.MustNewDB(ctx, "db"),
		Secrets: vault.MustNew(ctx),
	}
}

A api/apiserver/services.go => api/apiserver/services.go +19 -0
@@ 0,0 1,19 @@
package apiserver

import (
	"context"

	"github.com/gorilla/rpc/v2"
)

func registerServices(s *rpc.Server) {
	deps := NewDependencies(context.Background())

	for _, v := range []interface{}{
		Auth{deps},
	} {
		if err := s.RegisterService(v, ""); err != nil {
			panic(err)
		}
	}
}

M busi/auth/auth.go => busi/auth/auth.go +1 -1
@@ 13,7 13,7 @@ type LoginReq struct {
	Username, Password string
}

func Login(ctx context.Context, users db.Users, sec secrets.Client, req LoginReq) (shopping.ID, error) {
func Login(ctx context.Context, db db.UserFinder, sec secrets.Client, req LoginReq) (shopping.ID, error) {
	httputil.Log(ctx, "attempting login as %s", req.Username)
	if req.Username == "damien" && req.Password == "letmein" {
		return shopping.NewID(), nil

M cmd/api/main.go => cmd/api/main.go +1 -1
@@ 21,7 21,7 @@ func run() error {
	addr := httputil.Addr()
	log.Printf("== serving api on %s ==", addr)

	server := apiserver.Server(addr)
	server := apiserver.NewServer(addr)
	done := make(chan struct{})
	httputil.HandleSignals(server, 2*time.Second, done)
	defer func() {

A db/db.go => db/db.go +20 -0
@@ 0,0 1,20 @@
package db

import (
	"context"

	"shopping.io"
)

type DB interface {
	UserCreator
	UserFinder
}

type UserCreator interface {
	UserCreate(context.Context, shopping.User) (shopping.User, error)
}

type UserFinder interface {
	UserFind(context.Context, string) (shopping.User, error)
}

M db/postgres/postgres.go => db/postgres/postgres.go +50 -0
@@ 12,10 12,60 @@ import (
	"github.com/jackc/pgx/v4"

	"shopping.io/dns"
	"shopping.io"
)

const sslmode = "disable"

type DB struct {
	q queryer
}

func NewDB(ctx context.Context, serviceName string) (DB, error) {
	db, err := connect(ctx, serviceName)
	if err != nil {
		return DB{}, fmt.Errorf("NewDB: %w", err)
	}
	return DB{q: db}, nil
}

func MustNewDB(ctx context.Context, serviceName string) DB {
	db, err := NewDB(ctx, serviceName)
	if err != nil {
		panic(err)
	}
	return db
}

func (d DB) Begin(ctx context.Context) (db DB, finish func(), err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("DB.Begin: %w", err)
		}
	}()

	var tx pgx.Tx
	if tx, err = d.q.Begin(ctx); err != nil {
		return d, nil, err
	}

	finish = func() {
		tx.Rollback(ctx)
	}

	return DB{q: tx}, finish, nil
}

func (d DB) Transaction(ctx context.Context, f func(DB) error) (err error) {
	defer shopping.ErrFmt(&err, "DB.Transaction")
	if tx, err := d.q.Begin(ctx); err != nil {
		return err
	} else {
		defer finish(ctx, tx, &err)
		return f(DB{q: tx})
	}
}

type queryer interface {
	Begin(ctx context.Context) (pgx.Tx, error)
	Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)

M db/postgres/postgrestest/postgrestest.go => db/postgres/postgrestest/postgrestest.go +3 -3
@@ 7,7 7,7 @@ import (
	"shopping.io/db/postgres"
)

func Users(t *testing.T) postgres.Users {
func DB(t *testing.T) postgres.DB {
	if testing.Short() {
		t.Skip("short mode requested, skipping integration test")
	}


@@ 17,10 17,10 @@ func Users(t *testing.T) postgres.Users {
		ctx, cancel = context.WithDeadline(ctx, deadline)
		t.Cleanup(cancel)
	}
	users, finish, err := postgres.MustNewUsers(ctx, "db").Begin(ctx)
	db, finish, err := postgres.MustNewDB(ctx, "db").Begin(ctx)
	if err != nil {
		t.Fatalf("postgrestest.Users: %s", err)
	}
	t.Cleanup(finish)
	return users
	return db
}

M db/postgres/users.go => db/postgres/users.go +7 -59
@@ 2,76 2,24 @@ package postgres

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v4"

	"shopping.io"
	"shopping.io/db"
)

type Users struct {
	db queryer
}

func NewUsers(ctx context.Context, serviceName string) (Users, error) {
	db, err := connect(ctx, serviceName)
	if err != nil {
		return Users{}, fmt.Errorf("NewUsers: %w", err)
	}
	return Users{db}, nil
}

func MustNewUsers(ctx context.Context, serviceName string) Users {
	users, err := NewUsers(ctx, serviceName)
	if err != nil {
		panic(err)
	}
	return users
}

func (u Users) Begin(ctx context.Context) (users Users, finish func(), err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("Users.Begin: %w", err)
		}
	}()

	var tx pgx.Tx
	if tx, err = u.db.Begin(ctx); err != nil {
		return u, nil, err
	}

	finish = func() {
		tx.Rollback(ctx)
	}

	return Users{tx}, finish, nil
}

func (u Users) Transaction(ctx context.Context, f func(db.Users) error) (err error) {
	defer shopping.ErrFmt(&err, "Users.Transaction")
	if tx, err := u.db.Begin(ctx); err != nil {
		return err
	} else {
		defer finish(ctx, tx, &err)
		return f(Users{tx})
	}
}

func (u Users) Create(ctx context.Context, user shopping.User) (retUser shopping.User, err error) {
	defer shopping.ErrFmt(&err, "Users.Create")
func (d DB) UserCreate(ctx context.Context, user shopping.User) (retUser shopping.User, err error) {
	defer shopping.ErrFmt(&err, "DB.UserCreate")
	user.ID = shopping.NewID()
	_, err = u.db.Exec(ctx, `INSERT INTO users (id, username, display_name) VALUES ($1, $2, $3)`, user.ID, user.Username, user.DisplayName)
	_, err = d.q.Exec(ctx, `INSERT INTO users (id, username, display_name) VALUES ($1, $2, $3)`, user.ID, user.Username, user.DisplayName)
	return user, err
}

func (u Users) Find(ctx context.Context, username string) (user shopping.User, err error) {
	defer shopping.ErrFmt(&err, "Users.Find")
func (d DB) UserFind(ctx context.Context, username string) (user shopping.User, err error) {
	defer shopping.ErrFmt(&err, "DB.UserFind")
	user = shopping.User{Username: username}
	if err = u.db.QueryRow(ctx, "SELECT id, display_name FROM users WHERE username = $1", username).Scan(&user.ID, &user.DisplayName); err == pgx.ErrNoRows {
	if err = d.q.QueryRow(ctx, "SELECT id, display_name FROM users WHERE username = $1", username).Scan(&user.ID, &user.DisplayName); err == pgx.ErrNoRows {
		return user, nil
	}
	return user, err
}

var _ db.Users = Users{}

M db/postgres/users_test.go => db/postgres/users_test.go +3 -3
@@ 9,9 9,9 @@ import (
)

func TestCreateAndFindUser(t *testing.T) {
	users := postgrestest.Users(t)
	db := postgrestest.DB(t)

	homerCreate, err := users.Create(context.Background(), shopping.User{
	homerCreate, err := db.UserCreate(context.Background(), shopping.User{
		Username:    "homer",
		DisplayName: "Homer Simpson",
	})


@@ 22,7 22,7 @@ func TestCreateAndFindUser(t *testing.T) {
		t.Fatal("create: did not expect zero id")
	}

	homerFind, err := users.Find(context.Background(), "homer")
	homerFind, err := db.UserFind(context.Background(), "homer")
	if err != nil {
		t.Fatalf("failed to find user: %s", err)
	}

M httputil/httputil.go => httputil/httputil.go +11 -0
@@ 7,6 7,7 @@ import (
	"html/template"
	"log"
	"net/http"
	"runtime"

	"github.com/gorilla/csrf"
)


@@ 23,3 24,13 @@ func CSRFField(r *http.Request) template.HTML {
func Log(ctx context.Context, format string, v ...interface{}) {
	log.Printf(fmt.Sprintf("[%s] %s", RequestID(ctx), format), v...)
}

func LogStackTrace(ctx context.Context) {
	for skip := 2; true; skip += 1 {
		_, file, line, ok := runtime.Caller(skip)
		if !ok {
			return
		}
		Log(ctx, "\tat %s:%d", file, line)
	}
}

M httputil/middleware.go => httputil/middleware.go +1 -0
@@ 37,6 37,7 @@ func Recover(next http.Handler) http.Handler {
		defer func() {
			if err := recover(); err != nil {
				Log(r.Context(), "panic: %s", err)
				LogStackTrace(r.Context())
				http.Error(w, "unexpected error", http.StatusInternalServerError)
			}
		}()

M web/server.go => web/server.go +1 -0
@@ 26,6 26,7 @@ func NewServer(addr string, debug bool) *http.Server {
		debug: debug,
		router: mux.NewRouter(),
		sessionStore: newSessionStore(debug),
		api: apiclient.New("web"),
	}
	s.router.Use(httputil.SetRequestID)
	s.router.Use(httputil.Recover)