~sircmpwn/gql.sr.ht

6a1a8f1031f3c2c7ea60a66c62eea8a7c8af7237 — Drew DeVault 4 years ago 646e9d9
Streamline database support code
2 files changed, 153 insertions(+), 42 deletions(-)

M database/ql.go
M database/sq.go
M database/ql.go => database/ql.go +36 -37
@@ 29,69 29,68 @@ func collectFields(ctx context.Context) []graphql.CollectedField {
	return fields
}

func ColumnsFor(ctx context.Context, alias string,
	colMap map[string]string) []string {

	fields := collectFields(ctx)
	if len(fields) == 0 {
func Scan(ctx context.Context, m Model) []interface{} {
	qlFields := collectFields(ctx)
	if len(qlFields) == 0 {
		// Collect all fields if we are not in an active graphql context
		for qlCol, _ := range colMap {
			fields = append(fields, graphql.CollectedField{
				&ast.Field{Name: qlCol}, nil,
		for _, field := range m.Fields().All() {
			qlFields = append(qlFields, graphql.CollectedField{
				&ast.Field{Name: field.GQL}, nil,
			})
		}
	}

	sort.Slice(fields, func(a, b int) bool {
		return fields[a].Name < fields[b].Name
	sort.Slice(qlFields, func(a, b int) bool {
		return qlFields[a].Name < qlFields[b].Name
	})

	var columns []string
	for _, qlCol := range fields {
		if sqlCol, ok := colMap[qlCol.Name]; ok {
			if alias != "" {
				columns = append(columns, pq.QuoteIdentifier(alias)+
					"."+pq.QuoteIdentifier(sqlCol))
			} else {
				columns = append(columns, pq.QuoteIdentifier(sqlCol))
			}
	var fields []interface{}
	for _, qlField := range qlFields {
		if field, ok := m.Fields().GQL(qlField.Name); ok {
			fields = append(fields, field.Ptr)
		}
	}

	return columns
}
	for _, field := range m.Fields().Anonymous() {
		fields = append(fields, field.Ptr)
	}

func FieldsFor(ctx context.Context,
	colMap map[string]interface{}) []interface{} {
	return fields
}

	qlFields := collectFields(ctx)
	if len(qlFields) == 0 {
func Columns(ctx context.Context, m Model) []string {
	fields := collectFields(ctx)
	if len(fields) == 0 {
		// Collect all fields if we are not in an active graphql context
		for qlCol, _ := range colMap {
			qlFields = append(qlFields, graphql.CollectedField{
				&ast.Field{Name: qlCol}, nil,
		for _, field := range m.Fields().All() {
			fields = append(fields, graphql.CollectedField{
				&ast.Field{Name: field.GQL}, nil,
			})
		}
	}

	sort.Slice(qlFields, func(a, b int) bool {
		return qlFields[a].Name < qlFields[b].Name
	sort.Slice(fields, func(a, b int) bool {
		return fields[a].Name < fields[b].Name
	})

	var fields []interface{}
	for _, qlField := range qlFields {
		if field, ok := colMap[qlField.Name]; ok {
			fields = append(fields, field)
	var columns []string
	for _, gql := range fields {
		if field, ok := m.Fields().GQL(gql.Name); ok {
			columns = append(columns, WithAlias(m.Alias(), field.SQL))
		}
	}

	return fields
	for _, field := range m.Fields().Anonymous() {
		columns = append(columns, WithAlias(m.Alias(), field.SQL))
	}

	return columns
}

func WithAlias(alias, col string) string {
	if alias != "" {
		return alias + "." + col
		return pq.QuoteIdentifier(alias) + "." + pq.QuoteIdentifier(col)
	} else {
		return col
		return pq.QuoteIdentifier(col)
	}
}

M database/sq.go => database/sq.go +117 -5
@@ 3,13 3,75 @@ package database
import (
	"context"
	"fmt"
	"reflect"

	sq "github.com/Masterminds/squirrel"
)

type Selectable interface {
	Select(ctx context.Context) []string
	Fields(ctx context.Context) []interface{}
// Provides a mapping between PostgreSQL columns, GQL fields, and Go struct
// fields for all of the data associated with a model.
type FieldMap struct {
	SQL string
	GQL string
	Ptr interface{}
}

type ModelFields struct {
	Fields []*FieldMap

	byGQL map[string]*FieldMap
	bySQL map[string]*FieldMap
	anon  []*FieldMap
}

func (mf *ModelFields) buildCache() {
	if mf.byGQL != nil && mf.bySQL != nil {
		return
	}

	mf.byGQL = make(map[string]*FieldMap)
	mf.bySQL = make(map[string]*FieldMap)
	for _, f := range mf.Fields {
		if f.GQL != "" {
			mf.byGQL[f.GQL] = f
		} else {
			mf.anon = append(mf.anon, f)
		}
		mf.bySQL[f.SQL] = f
	}
}

func (mf *ModelFields) GQL(name string) (*FieldMap, bool) {
	mf.buildCache()
	if f, ok := mf.byGQL[name]; !ok {
		return nil, false
	} else {
		return f, true
	}
}

func (mf *ModelFields) SQL(name string) (*FieldMap, bool) {
	mf.buildCache()
	if f, ok := mf.bySQL[name]; !ok {
		return nil, false
	} else {
		return f, true
	}
}

func (mf *ModelFields) All() []*FieldMap {
	return mf.Fields
}

func (mf *ModelFields) Anonymous() []*FieldMap {
	mf.buildCache()
	return mf.anon
}

type Model interface {
	Alias()  string
	Fields() *ModelFields
	Table()  string
}

func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {


@@ 20,11 82,61 @@ func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
			q = q.Columns(col)
		case []string:
			q = q.Columns(col...)
		case Selectable:
			q = q.Columns(col.Select(ctx)...)
		case Model:
			q = q.Columns(Columns(ctx, col)...)
		default:
			panic(fmt.Errorf("Unknown selectable type %T", col))
		}
	}
	return q
}

// Prepares an UPDATE statement which applies the changes in the input map to
// the given model.
func Apply(m Model, input map[string]interface{}) sq.UpdateBuilder {
	// XXX: This relies on the GraphQL validator to prevent the user from
	// updating columns they're not supposed to. Risky?
	table := m.Table()
	if m.Alias() != "" {
		table += " " + m.Alias()
	}
	update := sq.Update(table).PlaceholderFormat(sq.Dollar)

	defer func() {
		// Some weird reflection errors don't get properly logged if they're
		// caught at a higher level.
		if err := recover(); err != nil {
			fmt.Printf("%v\n", err)
			panic(err)
		}
	}()

	for field, value := range input {
		f, ok := m.Fields().GQL(field)
		if !ok {
			continue
		}

		var (
			pv reflect.Value = reflect.Indirect(reflect.ValueOf(f.Ptr))
			rv reflect.Value = reflect.ValueOf(value)
		)
		if pv.Type().Kind() == reflect.Ptr {
			if !rv.IsValid() {
				pv.Set(reflect.Zero(pv.Type()))
				update = update.Set(WithAlias(m.Alias(), f.SQL), nil)
			} else {
				if !pv.Elem().IsValid() {
					pv.Set(reflect.New(pv.Type().Elem()))
				}
				reflect.Indirect(pv).Set(reflect.Indirect(rv))
				update = update.Set(WithAlias(m.Alias(), f.SQL),
					reflect.Indirect(rv).Interface())
			}
		} else {
			panic(fmt.Errorf("TODO"))
		}
	}

	return update
}