From fbfe6ad96677b22d38fae3005c0a10c88bb303cc Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Mon, 12 Sep 2022 23:06:30 -0400 Subject: [PATCH] Add support for Enums --- main.go | 141 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 100 insertions(+), 41 deletions(-) diff --git a/main.go b/main.go index 8bcb88c..1f68611 100644 --- a/main.go +++ b/main.go @@ -47,6 +47,7 @@ var ( protoreflect.FloatKind: "float", protoreflect.DoubleKind: "double", protoreflect.MessageKind: "none", + protoreflect.EnumKind: "uint32", } // decodeLUT is a look-up table from protobuf kind to the decode function to @@ -95,14 +96,14 @@ var ( fieldDecl{binding: "f_unknown_payload_bool", typ: "bool", abbr: "frpc.payload.bool", name: "Bool"}, fieldDecl{binding: "f_unknown_payload_bytes", typ: "bytes", abbr: "frpc.payload.bytes", name: "Bytes"}, fieldDecl{binding: "f_unknown_payload_string", typ: "string", abbr: "frpc.payload.string", name: "String"}, - fieldDecl{binding: "f_unknown_payload_uint8", typ: "uint8", abbr: "frpc.payload.uint8", name: "Uint8"}, - fieldDecl{binding: "f_unknown_payload_uint16", typ: "uint16", abbr: "frpc.payload.uint16", name: "Uint16"}, - fieldDecl{binding: "f_unknown_payload_uint32", typ: "uint32", abbr: "frpc.payload.uint32", name: "Uint32"}, - fieldDecl{binding: "f_unknown_payload_uint64", typ: "uint64", abbr: "frpc.payload.uint64", name: "Uint64"}, - fieldDecl{binding: "f_unknown_payload_int32", typ: "int32", abbr: "frpc.payload.int32", name: "Int32"}, - fieldDecl{binding: "f_unknown_payload_int64", typ: "int64", abbr: "frpc.payload.int64", name: "Int64"}, - fieldDecl{binding: "f_unknown_payload_float32", typ: "float", abbr: "frpc.payload.float32", name: "Float32"}, - fieldDecl{binding: "f_unknown_payload_float64", typ: "double", abbr: "frpc.payload.float64", name: "Float64"}, + fieldDecl{binding: "f_unknown_payload_uint8", typ: "uint8", abbr: "frpc.payload.uint8", name: "Uint8", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_uint16", typ: "uint16", abbr: "frpc.payload.uint16", name: "Uint16", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_uint32", typ: "uint32", abbr: "frpc.payload.uint32", name: "Uint32", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_uint64", typ: "uint64", abbr: "frpc.payload.uint64", name: "Uint64", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_int32", typ: "int32", abbr: "frpc.payload.int32", name: "Int32", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_int64", typ: "int64", abbr: "frpc.payload.int64", name: "Int64", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_float32", typ: "float", abbr: "frpc.payload.float32", name: "Float32", base: "DEC"}, + fieldDecl{binding: "f_unknown_payload_float64", typ: "double", abbr: "frpc.payload.float64", name: "Float64", base: "DEC"}, } customFields = []fieldDecl{ @@ -111,24 +112,46 @@ var ( otherFields = []fieldDecl{ fieldDecl{binding: "f_message_type", typ: "string", abbr: "frpc.message.type", name: "Type"}, + fieldDecl{binding: "f_enum_string", typ: "string", abbr: "frpc.enum.string", name: "Value"}, } ) -type operation struct { - service string - method string - inFunc string - outFunc string +type Operation struct { + Service string + Method string + InFunc string + OutFunc string } -func getMessagesDecl(desc protoreflect.MessageDescriptor) []protoreflect.MessageDescriptor { - descs := make([]protoreflect.MessageDescriptor, 0) +type FlatFile struct { + Messages []protoreflect.MessageDescriptor + Enums []protoreflect.EnumDescriptor +} + +func flattenFile(desc protoreflect.FileDescriptor) FlatFile { + file := FlatFile{ + Messages: make([]protoreflect.MessageDescriptor, 0), + Enums: make([]protoreflect.EnumDescriptor, 0), + } + for i := 0; i < desc.Enums().Len(); i += 1 { + file.Enums = append(file.Enums, desc.Enums().Get(i)) + } for i := 0; i < desc.Messages().Len(); i += 1 { - it := desc.Messages().Get(i) - descs = append(descs, it) - descs = append(descs, getMessagesDecl(it)...) + flattenMessage(desc.Messages().Get(i), &file) + } + return file +} + +func flattenMessage(desc protoreflect.MessageDescriptor, file *FlatFile) { + file.Messages = append(file.Messages, desc) + + for i := 0; i < desc.Enums().Len(); i += 1 { + file.Enums = append(file.Enums, desc.Enums().Get(i)) + } + + for i := 0; i < desc.Messages().Len(); i += 1 { + flattenMessage(desc.Messages().Get(i), file) } - return descs } func getMessageDissectFuncName(desc protoreflect.MessageDescriptor) string { @@ -146,6 +169,16 @@ func getFieldFieldName(desc protoreflect.FieldDescriptor) string { return fmt.Sprintf(`f_%s`, name) } +func getEnumFieldName(desc protoreflect.EnumDescriptor) string { + name := strings.ReplaceAll(strings.ToLower(string(desc.FullName())), ".", "_") + return fmt.Sprintf(`f_%s`, name) +} + +func getEnumLUTname(desc protoreflect.EnumDescriptor) string { + name := strings.ReplaceAll(strings.ToLower(string(desc.FullName())), ".", "_") + return fmt.Sprintf(`enum_%s_lut`, name) +} + // generateFile generates a _ascii.pb.go file containing gRPC service definitions. func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile { filename := file.GeneratedFilenamePrefix + ".frpc.lua" @@ -167,13 +200,19 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated fieldDecls = append(fieldDecls, customFields...) fieldDecls = append(fieldDecls, otherFields...) - var messages []protoreflect.MessageDescriptor - for _, msg := range file.Messages { - messages = append(messages, msg.Desc) - messages = append(messages, getMessagesDecl(msg.Desc)...) + flatFile := flattenFile(file.Desc) + + for _, enum := range flatFile.Enums { + decl := fieldDecl{ + binding: getEnumFieldName(enum), + typ: "uint32", + abbr: "frpc." + strings.ToLower(string(enum.FullName())), + name: string(enum.Name()), + } + fieldDecls = append(fieldDecls, decl) } - for _, msg := range messages { + for _, msg := range flatFile.Messages { fullName := string(msg.FullName()) decl := fieldDecl{ binding: getMessageFieldName(msg), @@ -227,13 +266,23 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated g.P("}") g.P() + for _, enum := range flatFile.Enums { + write(`local %s = {`, getEnumLUTname(enum)) + for i := 0; i < enum.Values().Len(); i += 1 { + val := enum.Values().Get(i) + write4(`[%d] = "%s",`, val.Number(), val.Name()) + } + write(`}`) + g.P() + } + // Forward declaration is necessary for Lua (if we don't want to build a DAG, etc) - for _, msg := range messages { + for _, msg := range flatFile.Messages { write(`local %s = function() end`, getMessageDissectFuncName(msg)) } g.P() - for _, msg := range messages { + for _, msg := range flatFile.Messages { write(`%s = function(buf, pinfo, tree)`, getMessageDissectFuncName(msg)) write4(`local subtree = tree:add(%s, buf())`, getMessageFieldName(msg)) write4(`subtree:add(f_message_type, "%s"):set_generated()`, string(msg.Name())) @@ -251,17 +300,28 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated continue } + builtInOnce.Do(func() { + write4(`-- Built-in types`) + }) + kind := field.Kind() fieldName := strings.ToLower(field.TextName()) + if kind == protoreflect.EnumKind { + enum := field.Enum() + write4(`local %s, noffset = decode_uint32(buf, noffset)`, fieldName) + write4(`local %s_str = %s[%s:uint()]`, fieldName, getEnumLUTname(enum), fieldName) + write4(`local enum_t = subtree:add(%s, %s)`, getEnumFieldName(enum), fieldName) + write4(`enum_t:add(f_enum_string, %s_str):set_generated()`, fieldName) + write4(`enum_t:append_text(string.format(" (%%s)", %s_str))`, fieldName) + g.P() + continue + } + decoder, exists := decodeLUT[kind] if !exists { continue } - builtInOnce.Do(func() { - write4(`-- Built-in types`) - }) - write4(`local %s, noffset = %s(buf, noffset)`, fieldName, decoder) write4(`subtree:add(%s, %s)`, getFieldFieldName(field), fieldName) g.P() @@ -340,20 +400,19 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated g.P() } - operationLUT := map[int]operation{ - 0: operation{method: "Heartbeat", inFunc: "nil", outFunc: "nil"}, - 1: operation{method: "Ping", inFunc: "nil", outFunc: "nil"}, - 2: operation{method: "Pong", inFunc: "nil", outFunc: "nil"}, + operationLUT := map[int]Operation{ + 0: Operation{Method: "Heartbeat", InFunc: "nil", OutFunc: "nil"}, + 1: Operation{Method: "Ping", InFunc: "nil", OutFunc: "nil"}, + 2: Operation{Method: "Pong", InFunc: "nil", OutFunc: "nil"}, } counter := 10 for _, svc := range file.Services { - // g.P("Service: ", svc.GoName) for _, method := range svc.Methods { - op := operation{ - service: string(svc.Desc.Name()), - method: string(method.Desc.Name()), - inFunc: getMessageDissectFuncName(method.Desc.Input()), - outFunc: getMessageDissectFuncName(method.Desc.Output()), + op := Operation{ + Service: string(svc.Desc.Name()), + Method: string(method.Desc.Name()), + InFunc: getMessageDissectFuncName(method.Desc.Input()), + OutFunc: getMessageDissectFuncName(method.Desc.Output()), } operationLUT[counter] = op counter += 1 @@ -363,7 +422,7 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated write(`local operation_lut =`) write(`{`) for opID, op := range operationLUT { - write4(`[%d] = { service = "%s", method = "%s", in_func = %s, out_func = %s },`, opID, op.service, op.method, op.inFunc, op.outFunc) + write4(`[%d] = { service = "%s", method = "%s", in_func = %s, out_func = %s },`, opID, op.Service, op.Method, op.InFunc, op.OutFunc) } write(`}`) g.P() -- 2.45.2