@@ 1,228 0,0 @@
-package main
-
-import (
- "fmt"
- "io"
- "strings"
-
- "git.sr.ht/~sircmpwn/go-bare/schema"
-)
-
-type Context struct {
- unions []*schema.UserDefinedType
- unionMembers map[schema.Type]interface{}
-}
-
-func genTypes(w io.Writer, types []schema.SchemaType) {
- fmt.Fprintf(w, `
-// THIS FILE WAS GENERATED BY A TOOL, DO NOT EDIT
-
-import (
- "errors"
- "git.sr.ht/~sircmpwn/go-bare"
-)
-`)
-
- ctx := Context{
- unions: nil,
- unionMembers: make(map[schema.Type]interface{}),
- }
- for _, ty := range types {
- switch ty := ty.(type) {
- case *schema.UserDefinedType:
- ctx.genUserType(w, ty)
- case *schema.UserDefinedEnum:
- ctx.genUserEnum(w, ty)
- }
- }
-
- if len(ctx.unions) > 0 {
- fmt.Fprintf(w, "\nfunc init() {\n")
- for _, udt := range ctx.unions {
- fmt.Fprintf(w, "\tbare.RegisterUnion((*%s)(nil)).\n", udt.Name())
- ut, _ := udt.Type().(*schema.UnionType)
- for i, ty := range ut.Types() {
- tag := ty.Tag()
- switch ty := ty.Type().(type) {
- case *schema.NamedUserType:
- fmt.Fprintf(w, "\t\tMember(*new(%s), %d)", ty.Name(), tag)
- default:
- panic(fmt.Errorf("TODO: Implement unions with primitive types"))
- }
- if i < len(ut.Types()) - 1 {
- fmt.Fprintf(w, ".\n")
- }
- }
- fmt.Fprintf(w, "\n")
- }
- fmt.Fprintf(w, "}\n")
- }
-}
-
-func (ctx *Context) genUserType(w io.Writer, udt *schema.UserDefinedType) {
- if udt.Type().Kind() == schema.Union {
- ctx.genUserUnion(w, udt)
- return
- }
-
- fmt.Fprintf(w, "\ntype %s ", udt.Name())
- genType(w, udt.Type(), 0)
- fmt.Fprintf(w, "\n")
-
- fmt.Fprintf(w, "\nfunc (t *%s) Decode(data []byte) error {", udt.Name())
- fmt.Fprintf(w, "\n\treturn bare.Unmarshal(data, t)")
- fmt.Fprintf(w, "\n}\n")
-
- fmt.Fprintf(w, "\nfunc (t *%s) Encode() ([]byte, error) {", udt.Name())
- fmt.Fprintf(w, "\n\treturn bare.Marshal(t)")
- fmt.Fprintf(w, "\n}\n")
-}
-
-func (ctx *Context) genUserEnum(w io.Writer, ude *schema.UserDefinedEnum) {
- // TODO: Disambiguate between enums with conflicting value names
- fmt.Fprintf(w, "\ntype %s %s\n", ude.Name(), primitiveType(ude.Kind()))
- fmt.Fprintf(w, "\nconst (")
- for i, val := range ude.Values() {
- if i == 0 {
- fmt.Fprintf(w, "\n\t%s %s = %d", val.Name(), ude.Name(), val.Value())
- } else {
- fmt.Fprintf(w, "\n\t%s = %d", val.Name(), val.Value())
- }
- }
- fmt.Fprintf(w, "\n)\n")
-
- fmt.Fprintf(w, "\nfunc (t %s) String() string {", ude.Name())
- fmt.Fprintf(w, "\n\tswitch (t) {")
- for _, val := range ude.Values() {
- fmt.Fprintf(w, "\n\tcase %s:", val.Name())
- fmt.Fprintf(w, "\n\t\treturn \"%s\"", val.Name())
- }
- fmt.Fprintf(w, "\n\t}")
- fmt.Fprintf(w, "\n\tpanic(errors.New(\"Invalid %s value\"))", ude.Name())
- fmt.Fprintf(w, "\n}\n")
-}
-
-func (ctx *Context) genUserUnion(w io.Writer, udt *schema.UserDefinedType) {
- fmt.Fprintf(w, "\ntype %s interface {", udt.Name())
- fmt.Fprintf(w, "\n\tbare.Union")
- fmt.Fprintf(w, "\n}\n")
-
- ut, _ := udt.Type().(*schema.UnionType)
- for _, ty := range ut.Types() {
- // XXX: This doesn't actually work the way it looks like it ought to
- if _, ok := ctx.unionMembers[ty.Type()]; ok {
- continue
- }
-
- ctx.unionMembers[ty.Type()] = nil
-
- switch ty := ty.Type().(type) {
- case *schema.NamedUserType:
- fmt.Fprintf(w, "\nfunc (_ %s) IsUnion() { }\n", ty.Name())
- default:
- panic(fmt.Errorf("TODO: Implement unions with primitive types"))
- }
- }
-
- ctx.unions = append(ctx.unions, udt)
-}
-
-func genType(w io.Writer, ty schema.Type, indent int) {
- switch ty := ty.(type) {
- case *schema.PrimitiveType:
- fmt.Fprintf(w, "%s", primitiveType(ty.Kind()))
- case *schema.DataType:
- if ty.Kind() == schema.DataArray {
- fmt.Fprintf(w, "[%d]byte", ty.Length())
- } else {
- fmt.Fprintf(w, "[]byte")
- }
- case *schema.StructType:
- maxName := 0
- for _, field := range ty.Fields() {
- if len(field.Name()) > maxName {
- maxName = len(field.Name())
- }
- }
-
- fmt.Fprintf(w, "struct {\n")
- for _, field := range ty.Fields() {
- genIndent(w, indent + 1)
- n := fieldName(field.Name())
- fmt.Fprintf(w, "%s ", n)
- for i := len(n); i < maxName; i++ {
- fmt.Fprintf(w, " ")
- }
- genType(w, field.Type(), indent + 1)
- fmt.Fprintf(w, " `bare:\"%s\"`", field.Name())
- fmt.Fprintf(w, "\n")
- }
- genIndent(w, indent)
- fmt.Fprintf(w, "}")
- case *schema.NamedUserType:
- fmt.Fprintf(w, "%s", ty.Name())
- case *schema.MapType:
- fmt.Fprintf(w, "map[")
- genType(w, ty.Key(), indent)
- fmt.Fprintf(w, "]")
- genType(w, ty.Value(), indent)
- case *schema.ArrayType:
- if ty.Kind() == schema.Array {
- fmt.Fprintf(w, "[%d]", ty.Length())
- } else {
- fmt.Fprintf(w, "[]")
- }
- genType(w, ty.Member(), indent)
- case *schema.OptionalType:
- fmt.Fprintf(w, "*")
- genType(w, ty.Subtype(), indent)
- default:
- panic(fmt.Errorf("Unimplemented schema type: %T", ty))
- }
-}
-
-func genUnion(w io.Writer, ut *schema.UnionType, indent int) {
-}
-
-func primitiveType(kind schema.TypeKind) string {
- switch kind {
- case schema.U8:
- return "uint8"
- case schema.U16:
- return "uint16"
- case schema.U32:
- return "uint32"
- case schema.U64:
- return "uint64"
- case schema.I8:
- return "int8"
- case schema.I16:
- return "int16"
- case schema.I32:
- return "int32"
- case schema.I64:
- return "int64"
- case schema.F32:
- return "float32"
- case schema.F64:
- return "float64"
- case schema.Bool:
- return "bool"
- case schema.String:
- return "string"
- case schema.Void:
- return "struct{}"
- }
- panic(fmt.Errorf("Invalid primitive type %d", kind))
-}
-
-func genIndent(w io.Writer, indent int) {
- for ; indent > 0; indent-- {
- fmt.Fprintf(w, "\t")
- }
-}
-
-func fieldName(n string) string {
- // TODO: Correct initialisms
- return strings.ToUpper(n[:1]) + n[1:]
-}
@@ 1,29 1,235 @@
package main
import (
+ "bytes"
"fmt"
+ "go/format"
+ "io/ioutil"
"log"
"os"
+ "strings"
+ "text/template"
"git.sr.ht/~sircmpwn/getopt"
"git.sr.ht/~sircmpwn/go-bare/schema"
)
+const templateString = `
+package {{.package}}
+
+// Code generated by go-bare/cmd/gen, DO NOT EDIT.
+
+import (
+ "errors"
+ "git.sr.ht/~sircmpwn/go-bare"
+)
+
+{{ define "type" }}
+ {{- if eq (typeKind .) "PrimitiveType" -}}
+ {{ primitiveType .Kind }}
+ {{- else if eq (typeKind .) "DataType" -}}
+ [{{if gt .Length 0 }}{{.Length}}{{end}}]byte
+ {{- else if eq (typeKind .) "ArrayType" -}}
+ [{{if gt .Length 0 }}{{.Length}}{{end}}]{{template "type" .Member}}
+ {{- else if eq (typeKind .) "StructType" -}}
+ struct {
+ {{- range .Fields }}
+ {{ capitalize .Name }} {{ template "type" .Type }} {{ structTag .Name }}
+ {{- end -}}
+ }
+ {{- else if eq (typeKind .) "NamedUserType" -}}
+ {{.Name}}
+ {{- else if eq (typeKind .) "MapType" -}}
+ map[{{template "type" .Key}}]{{template "type" .Value}}
+ {{- else if eq (typeKind .) "OptionalType" -}}
+ *{{template "type" .Subtype}}
+ {{- end -}}
+{{ end }}
+
+{{with .schema}}
+
+{{range .UserTypes}}
+ type {{ .Name }} {{ template "type" .Type }}
+
+ func (t *{{ .Name }}) Decode(data []byte) error {
+ return bare.Unmarshal(data, t)
+ }
+
+ func (t *{{ .Name }}) Encode() ([]byte, error) {
+ return bare.Marshal(t)
+ }
+{{end}}
+
+{{range .Enums}}
+type {{ .Name }} {{ primitiveType .Kind }}
+
+{{ $name := .Name }}
+
+const (
+ {{- range $i, $el := .Values }}
+ {{ .Name }} {{ $name }} = {{ .Value }}
+ {{- end -}}
+ )
+
+ func (t {{ .Name }}) String() string {
+ switch (t) {
+ {{- range .Values }}
+ case {{ .Name }}:
+ return "{{ .Name }}"
+ {{- end -}}
+ }
+ panic(errors.New("Invalid {{.Name}} value"))
+ }
+{{end}}
+
+{{ if gt (len .Unions) 0 }}
+ {{range .Unions}}
+ type {{ .Name }} interface {
+ bare.Union
+ }
+
+ {{range .Type.Types}}
+ func (_ {{.Type.Name}}) IsUnion() {}
+ {{end}}
+ {{end}}
+
+ func init() {
+ {{- range .Unions}}
+ bare.RegisterUnion((*{{.Name}})(nil)).
+ {{ $len := len .Type.Types }}
+ {{range $i, $el := .Type.Types}}
+ Member(*new({{ template "type" $el.Type}}), {{$el.Tag}}){{- if not (last $len $i) -}}.{{end}}
+ {{end}}
+ {{ end }}
+ }
+{{ end}}
+
+{{end}}
+`
+
+var funcs = template.FuncMap{
+ "typeKind": func(ty interface{}) string {
+ switch ty := ty.(type) {
+ case *schema.PrimitiveType:
+ return "PrimitiveType"
+ case *schema.DataType:
+ return "DataType"
+ case *schema.StructType:
+ return "StructType"
+ case *schema.NamedUserType:
+ return "NamedUserType"
+ case *schema.MapType:
+ return "MapType"
+ case *schema.ArrayType:
+ return "ArrayType"
+ case *schema.OptionalType:
+ return "OptionalType"
+ default:
+ panic(fmt.Sprintf("Unimplemented schema type: %T", ty))
+ }
+ },
+ "primitiveType": func(t schema.TypeKind) string {
+ switch t {
+ case schema.U8:
+ return "uint8"
+ case schema.U16:
+ return "uint16"
+ case schema.U32:
+ return "uint32"
+ case schema.U64:
+ return "uint64"
+ case schema.I8:
+ return "int8"
+ case schema.I16:
+ return "int16"
+ case schema.I32:
+ return "int32"
+ case schema.I64:
+ return "int64"
+ case schema.F32:
+ return "float32"
+ case schema.F64:
+ return "float64"
+ case schema.Bool:
+ return "bool"
+ case schema.String:
+ return "string"
+ case schema.Void:
+ return "struct{}"
+ }
+ panic(fmt.Errorf("Invalid primitive type %d", t))
+ },
+ "structTag": func(name string) string {
+ return fmt.Sprintf("`bare:\"%s\"`", name)
+ },
+ "capitalize": func(s string) string {
+ return strings.ToUpper(s[:1]) + s[1:]
+ },
+ "last": func(len, i int) bool {
+ return i+1 == len
+ },
+}
+
func main() {
+ cfg := parseArgs()
+ out := &bytes.Buffer{}
+
+ tmpl, err := template.New("").Funcs(funcs).Parse(templateString)
+ if err != nil {
+ log.Fatalf("error parsing template: %v", err)
+ }
+
+ types := parseSchema(cfg.In, cfg.Skip)
+
+ data := make(map[string]interface{})
+
+ data["package"] = cfg.PackageName
+ data["schema"] = types
+
+ err = tmpl.Execute(out, data)
+ if err != nil {
+ log.Fatalf("error executing template: %v", err)
+ }
+
+ // Format generated code
+ formatted, err := format.Source(out.Bytes())
+ if err != nil {
+ log.Println(out.String())
+ log.Fatalf("--- error formatting source code: %v", err)
+ }
+
+ err = ioutil.WriteFile(cfg.Out, formatted, 0644)
+ if err != nil {
+ log.Fatalf("error writing output to %s: %e", cfg.Out, err)
+ }
+}
+
+type Config struct {
+ PackageName string
+ In string
+ Out string
+ Skip map[string]bool
+}
+
+func parseArgs() *Config {
+ cfg := &Config{}
+
log.SetFlags(0)
opts, optind, err := getopt.Getopts(os.Args, "hs:p:")
if err != nil {
log.Fatalf("error: %e", err)
}
- pkg := "gen"
- skip := make(map[string]interface{})
+
+ cfg.PackageName = "gen"
+ cfg.Skip = make(map[string]bool)
+
for _, opt := range opts {
switch opt.Option {
case 'p':
- pkg = opt.Value
+ cfg.PackageName = opt.Value
case 's':
- skip[opt.Value] = nil
+ cfg.Skip[opt.Value] = true
case 'h':
log.Println("Usage: gen [-p <package>] [-s <skip type>] <input.bare> <output.go>")
os.Exit(0)
@@ 34,36 240,51 @@ func main() {
if len(args) != 2 {
log.Fatal("Usage: gen [-p <package>] <input.bare> <output.go>")
}
- in := args[0]
- out := args[1]
- inf, err := os.Open(in)
+ cfg.In = args[0]
+ cfg.Out = args[1]
+
+ return cfg
+}
+
+type Types struct {
+ UserTypes []*schema.UserDefinedType
+ Enums []*schema.UserDefinedEnum
+ Unions []*schema.UserDefinedType
+}
+
+func parseSchema(path string, skip map[string]bool) Types {
+ inf, err := os.Open(path)
if err != nil {
- log.Fatalf("error opening %s: %e", in, err)
+ log.Fatalf("error opening %s: %e", path, err)
}
defer inf.Close()
- types, err := schema.Parse(inf)
+ schemaTypes, err := schema.Parse(inf)
if err != nil {
- log.Fatalf("error parsing %s: %e", in, err)
+ log.Fatalf("error parsing %s: %e", path, err)
}
- outf, err := os.Create(out)
- if err != nil {
- log.Fatalf("error opening %s for writing: %e", out, err)
- }
- defer outf.Close()
- fmt.Fprintf(outf, "package %s\n", pkg)
-
- if len(skip) != 0 {
- var typesp []schema.SchemaType
- for _, ty := range types {
- if _, ok := skip[ty.Name()]; !ok {
- typesp = append(typesp, ty)
+ types := Types{}
+
+ for _, ty := range schemaTypes {
+ if skip[ty.Name()] {
+ continue
+ }
+
+ switch ty := ty.(type) {
+ case *schema.UserDefinedType:
+ if ty.Type().Kind() == schema.Union {
+ types.Unions = append(types.Unions, ty)
+ continue
}
+ types.UserTypes = append(types.UserTypes, ty)
+
+ case *schema.UserDefinedEnum:
+ types.Enums = append(types.Enums, ty)
+
}
- types = typesp
}
- genTypes(outf, types)
+ return types
}
@@ 1,6 1,6 @@
package example
-// THIS FILE WAS GENERATED BY A TOOL, DO NOT EDIT
+// Code generated by go-bare/cmd/gen, DO NOT EDIT.
import (
"errors"
@@ 17,37 17,11 @@ func (t *PublicKey) Encode() ([]byte, error) {
return bare.Marshal(t)
}
-type Department uint8
-
-const (
- ACCOUNTING Department = 0
- ADMINISTRATION = 1
- CUSTOMER_SERVICE = 2
- DEVELOPMENT = 3
- JSMITH = 99
-)
-
-func (t Department) String() string {
- switch (t) {
- case ACCOUNTING:
- return "ACCOUNTING"
- case ADMINISTRATION:
- return "ADMINISTRATION"
- case CUSTOMER_SERVICE:
- return "CUSTOMER_SERVICE"
- case DEVELOPMENT:
- return "DEVELOPMENT"
- case JSMITH:
- return "JSMITH"
- }
- panic(errors.New("Invalid Department value"))
-}
-
type Customer struct {
- Name string `bare:"name"`
- Email string `bare:"email"`
- Address Address `bare:"address"`
- Orders []struct {
+ Name string `bare:"name"`
+ Email string `bare:"email"`
+ Address Address `bare:"address"`
+ Orders []struct {
OrderId int64 `bare:"orderId"`
Quantity int32 `bare:"quantity"`
} `bare:"orders"`
@@ 63,12 37,12 @@ func (t *Customer) Encode() ([]byte, error) {
}
type Employee struct {
- Name string `bare:"name"`
- Email string `bare:"email"`
- Address Address `bare:"address"`
- Department Department `bare:"department"`
- HireDate Time `bare:"hireDate"`
- PublicKey *PublicKey `bare:"publicKey"`
+ Name string `bare:"name"`
+ Email string `bare:"email"`
+ Address Address `bare:"address"`
+ Department Department `bare:"department"`
+ HireDate Time `bare:"hireDate"`
+ PublicKey *PublicKey `bare:"publicKey"`
Metadata map[string][]byte `bare:"metadata"`
}
@@ 90,21 64,11 @@ func (t *TerminatedEmployee) Encode() ([]byte, error) {
return bare.Marshal(t)
}
-type Person interface {
- bare.Union
-}
-
-func (_ Customer) IsUnion() { }
-
-func (_ Employee) IsUnion() { }
-
-func (_ TerminatedEmployee) IsUnion() { }
-
type Address struct {
Address [4]string `bare:"address"`
- City string `bare:"city"`
- State string `bare:"state"`
- Country string `bare:"country"`
+ City string `bare:"city"`
+ State string `bare:"state"`
+ Country string `bare:"country"`
}
func (t *Address) Decode(data []byte) error {
@@ 115,9 79,46 @@ func (t *Address) Encode() ([]byte, error) {
return bare.Marshal(t)
}
+type Department uint8
+
+const (
+ ACCOUNTING Department = 0
+ ADMINISTRATION Department = 1
+ CUSTOMER_SERVICE Department = 2
+ DEVELOPMENT Department = 3
+ JSMITH Department = 99
+)
+
+func (t Department) String() string {
+ switch t {
+ case ACCOUNTING:
+ return "ACCOUNTING"
+ case ADMINISTRATION:
+ return "ADMINISTRATION"
+ case CUSTOMER_SERVICE:
+ return "CUSTOMER_SERVICE"
+ case DEVELOPMENT:
+ return "DEVELOPMENT"
+ case JSMITH:
+ return "JSMITH"
+ }
+ panic(errors.New("Invalid Department value"))
+}
+
+type Person interface {
+ bare.Union
+}
+
+func (_ Customer) IsUnion() {}
+
+func (_ Employee) IsUnion() {}
+
+func (_ TerminatedEmployee) IsUnion() {}
+
func init() {
bare.RegisterUnion((*Person)(nil)).
Member(*new(Customer), 0).
Member(*new(Employee), 1).
Member(*new(TerminatedEmployee), 2)
+
}