@@ 29,69 29,68 @@ func collectFields(ctx context.Context) []graphql.CollectedField {
return fields
}
-func ColumnsFor(ctx context.Context, alias string,
- colMap map[string]string) []string {
-
- fields := collectFields(ctx)
- if len(fields) == 0 {
+func Scan(ctx context.Context, m Model) []interface{} {
+ qlFields := collectFields(ctx)
+ if len(qlFields) == 0 {
// Collect all fields if we are not in an active graphql context
- for qlCol, _ := range colMap {
- fields = append(fields, graphql.CollectedField{
- &ast.Field{Name: qlCol}, nil,
+ for _, field := range m.Fields().All() {
+ qlFields = append(qlFields, graphql.CollectedField{
+ &ast.Field{Name: field.GQL}, nil,
})
}
}
- sort.Slice(fields, func(a, b int) bool {
- return fields[a].Name < fields[b].Name
+ sort.Slice(qlFields, func(a, b int) bool {
+ return qlFields[a].Name < qlFields[b].Name
})
- var columns []string
- for _, qlCol := range fields {
- if sqlCol, ok := colMap[qlCol.Name]; ok {
- if alias != "" {
- columns = append(columns, pq.QuoteIdentifier(alias)+
- "."+pq.QuoteIdentifier(sqlCol))
- } else {
- columns = append(columns, pq.QuoteIdentifier(sqlCol))
- }
+ var fields []interface{}
+ for _, qlField := range qlFields {
+ if field, ok := m.Fields().GQL(qlField.Name); ok {
+ fields = append(fields, field.Ptr)
}
}
- return columns
-}
+ for _, field := range m.Fields().Anonymous() {
+ fields = append(fields, field.Ptr)
+ }
-func FieldsFor(ctx context.Context,
- colMap map[string]interface{}) []interface{} {
+ return fields
+}
- qlFields := collectFields(ctx)
- if len(qlFields) == 0 {
+func Columns(ctx context.Context, m Model) []string {
+ fields := collectFields(ctx)
+ if len(fields) == 0 {
// Collect all fields if we are not in an active graphql context
- for qlCol, _ := range colMap {
- qlFields = append(qlFields, graphql.CollectedField{
- &ast.Field{Name: qlCol}, nil,
+ for _, field := range m.Fields().All() {
+ fields = append(fields, graphql.CollectedField{
+ &ast.Field{Name: field.GQL}, nil,
})
}
}
- sort.Slice(qlFields, func(a, b int) bool {
- return qlFields[a].Name < qlFields[b].Name
+ sort.Slice(fields, func(a, b int) bool {
+ return fields[a].Name < fields[b].Name
})
- var fields []interface{}
- for _, qlField := range qlFields {
- if field, ok := colMap[qlField.Name]; ok {
- fields = append(fields, field)
+ var columns []string
+ for _, gql := range fields {
+ if field, ok := m.Fields().GQL(gql.Name); ok {
+ columns = append(columns, WithAlias(m.Alias(), field.SQL))
}
}
- return fields
+ for _, field := range m.Fields().Anonymous() {
+ columns = append(columns, WithAlias(m.Alias(), field.SQL))
+ }
+
+ return columns
}
func WithAlias(alias, col string) string {
if alias != "" {
- return alias + "." + col
+ return pq.QuoteIdentifier(alias) + "." + pq.QuoteIdentifier(col)
} else {
- return col
+ return pq.QuoteIdentifier(col)
}
}
@@ 3,13 3,75 @@ package database
import (
"context"
"fmt"
+ "reflect"
sq "github.com/Masterminds/squirrel"
)
-type Selectable interface {
- Select(ctx context.Context) []string
- Fields(ctx context.Context) []interface{}
+// Provides a mapping between PostgreSQL columns, GQL fields, and Go struct
+// fields for all of the data associated with a model.
+type FieldMap struct {
+ SQL string
+ GQL string
+ Ptr interface{}
+}
+
+type ModelFields struct {
+ Fields []*FieldMap
+
+ byGQL map[string]*FieldMap
+ bySQL map[string]*FieldMap
+ anon []*FieldMap
+}
+
+func (mf *ModelFields) buildCache() {
+ if mf.byGQL != nil && mf.bySQL != nil {
+ return
+ }
+
+ mf.byGQL = make(map[string]*FieldMap)
+ mf.bySQL = make(map[string]*FieldMap)
+ for _, f := range mf.Fields {
+ if f.GQL != "" {
+ mf.byGQL[f.GQL] = f
+ } else {
+ mf.anon = append(mf.anon, f)
+ }
+ mf.bySQL[f.SQL] = f
+ }
+}
+
+func (mf *ModelFields) GQL(name string) (*FieldMap, bool) {
+ mf.buildCache()
+ if f, ok := mf.byGQL[name]; !ok {
+ return nil, false
+ } else {
+ return f, true
+ }
+}
+
+func (mf *ModelFields) SQL(name string) (*FieldMap, bool) {
+ mf.buildCache()
+ if f, ok := mf.bySQL[name]; !ok {
+ return nil, false
+ } else {
+ return f, true
+ }
+}
+
+func (mf *ModelFields) All() []*FieldMap {
+ return mf.Fields
+}
+
+func (mf *ModelFields) Anonymous() []*FieldMap {
+ mf.buildCache()
+ return mf.anon
+}
+
+type Model interface {
+ Alias() string
+ Fields() *ModelFields
+ Table() string
}
func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
@@ 20,11 82,61 @@ func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
q = q.Columns(col)
case []string:
q = q.Columns(col...)
- case Selectable:
- q = q.Columns(col.Select(ctx)...)
+ case Model:
+ q = q.Columns(Columns(ctx, col)...)
default:
panic(fmt.Errorf("Unknown selectable type %T", col))
}
}
return q
}
+
+// Prepares an UPDATE statement which applies the changes in the input map to
+// the given model.
+func Apply(m Model, input map[string]interface{}) sq.UpdateBuilder {
+ // XXX: This relies on the GraphQL validator to prevent the user from
+ // updating columns they're not supposed to. Risky?
+ table := m.Table()
+ if m.Alias() != "" {
+ table += " " + m.Alias()
+ }
+ update := sq.Update(table).PlaceholderFormat(sq.Dollar)
+
+ defer func() {
+ // Some weird reflection errors don't get properly logged if they're
+ // caught at a higher level.
+ if err := recover(); err != nil {
+ fmt.Printf("%v\n", err)
+ panic(err)
+ }
+ }()
+
+ for field, value := range input {
+ f, ok := m.Fields().GQL(field)
+ if !ok {
+ continue
+ }
+
+ var (
+ pv reflect.Value = reflect.Indirect(reflect.ValueOf(f.Ptr))
+ rv reflect.Value = reflect.ValueOf(value)
+ )
+ if pv.Type().Kind() == reflect.Ptr {
+ if !rv.IsValid() {
+ pv.Set(reflect.Zero(pv.Type()))
+ update = update.Set(WithAlias(m.Alias(), f.SQL), nil)
+ } else {
+ if !pv.Elem().IsValid() {
+ pv.Set(reflect.New(pv.Type().Elem()))
+ }
+ reflect.Indirect(pv).Set(reflect.Indirect(rv))
+ update = update.Set(WithAlias(m.Alias(), f.SQL),
+ reflect.Indirect(rv).Interface())
+ }
+ } else {
+ panic(fmt.Errorf("TODO"))
+ }
+ }
+
+ return update
+}