@@ 0,0 1,152 @@
+// Package flag implements a command-line flags parser that uses struct tags to
+// configure supported flags and returns any error it encounters, without
+// printing anything automatically. It can optionally read flag values from
+// environment variables first, with the command-line flags used to override
+// them.
+//
+// The struct tag to specify flags is `flag`, while the one to specify
+// environment variables is `envconfig`. See the envconfig package for full
+// details on struct tags configuration and decoding support:
+// https://github.com/kelseyhightower/envconfig.
+//
+// Flag parsing uses the stdlib's flag package internally, and as such shares
+// the same behaviour regarding short and long flags.
+package flag
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "reflect"
+ "strings"
+ "time"
+
+ "github.com/kelseyhightower/envconfig"
+)
+
+// Parser implements a flag parser.
+type Parser struct {
+ // EnvVars indicates if environment variables are used to read flag values.
+ EnvVars bool
+
+ // EnvPrefix is the prefix to use in front of each flag's environment
+ // variable name. If it is empty, the name of the program (as read from the
+ // args slice at index 0) is used, with dashes replaced with underscores.
+ EnvPrefix string
+}
+
+// Parse parses args into v, using struct tags to detect flags. The tag must
+// be named "flag" and multiple flags may be set for the same field using a
+// comma-separated list. v must be a pointer to a struct and the flags must be
+// defined on fields with a type of string, int, bool or time.Duration.
+// If Parser.EnvVars is true, flag values are initialized from corresponding
+// environment variables first.
+//
+// After parsing, if v implements a Validate method that returns an error, it
+// is called and any non-nil error is returned as error.
+//
+// If v has a SetArgs method, it is called with the list of non-flag arguments.
+//
+// If v has a SetFlags method, it is called with the set of flags that were set
+// by args (a map[string]bool).
+//
+// It panics if v is not a pointer to a struct or if a flag is defined with an
+// unsupported type.
+func (p *Parser) Parse(args []string, v interface{}) error {
+ if p.EnvVars {
+ if err := p.parseEnvVars(args, v); err != nil {
+ return err
+ }
+ }
+
+ if err := p.parseFlags(args, v); err != nil {
+ return err
+ }
+
+ if val, ok := v.(interface{ Validate() error }); ok {
+ return val.Validate()
+ }
+ return nil
+}
+
+func (p *Parser) parseFlags(args []string, v interface{}) error {
+ if len(args) == 0 {
+ return nil
+ }
+
+ // create a FlagSet that is silent and only returns any error
+ // it encounters.
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ fs.Usage = nil
+
+ durationType := reflect.TypeOf(time.Duration(0))
+
+ // extract the flags from the struct
+ val := reflect.ValueOf(v).Elem()
+ str := reflect.TypeOf(v).Elem()
+ count := val.NumField()
+ for i := 0; i < count; i++ {
+ fld := val.Field(i)
+ typ := str.Field(i)
+ names := strings.Split(typ.Tag.Get("flag"), ",")
+
+ for _, nm := range names {
+ if nm == "" {
+ continue
+ }
+ switch fld.Kind() {
+ case reflect.Bool:
+ fs.BoolVar(fld.Addr().Interface().(*bool), nm, fld.Bool(), "")
+ case reflect.String:
+ fs.StringVar(fld.Addr().Interface().(*string), nm, fld.String(), "")
+ case reflect.Int:
+ fs.IntVar(fld.Addr().Interface().(*int), nm, int(fld.Int()), "")
+ default:
+ switch typ.Type {
+ case durationType:
+ fs.DurationVar(fld.Addr().Interface().(*time.Duration), nm, fld.Interface().(time.Duration), "")
+ default:
+ panic(fmt.Sprintf("unsupported flag field kind: %s (%s: %s)", fld.Kind(), typ.Name, typ.Type))
+ }
+ }
+ }
+ }
+
+ if err := fs.Parse(args[1:]); err != nil {
+ return err
+ }
+
+ if sa, ok := v.(interface{ SetArgs([]string) }); ok {
+ args := fs.Args()
+ if len(args) == 0 {
+ args = nil
+ }
+ sa.SetArgs(args)
+ }
+ if sf, ok := v.(interface{ SetFlags(map[string]bool) }); ok {
+ set := make(map[string]bool)
+ fs.Visit(func(fl *flag.Flag) {
+ set[fl.Name] = true
+ })
+ if len(set) == 0 {
+ set = nil
+ }
+ sf.SetFlags(set)
+ }
+
+ return nil
+}
+
+func (p *Parser) parseEnvVars(args []string, v interface{}) error {
+ prefix := p.EnvPrefix
+
+ if prefix == "" && len(args) > 0 {
+ prefix = prefixFromProgramName(args[0])
+ }
+ return envconfig.Process(prefix, v)
+}
+
+func prefixFromProgramName(name string) string {
+ return strings.ReplaceAll(name, "-", "_")
+}
@@ 0,0 1,314 @@
+package flag
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type F struct {
+ S string `flag:"s,string,long-string"`
+ I int `flag:"i,int"`
+ B bool `flag:"b"`
+ H bool `flag:"h,help"`
+ T time.Duration `flag:"t"`
+ N int
+ args []string
+ flags map[string]bool
+}
+
+func (f *F) SetArgs(args []string) {
+ f.args = args
+}
+
+func (f *F) SetFlags(flags map[string]bool) {
+ f.flags = flags
+}
+
+func TestParseFlags(t *testing.T) {
+ cases := []struct {
+ args []string
+ want *F
+ err string
+ }{
+ {
+ want: &F{},
+ },
+ {
+ args: []string{"toto"},
+ want: &F{
+ args: []string{"toto"},
+ },
+ },
+ {
+ args: []string{"-h"},
+ want: &F{
+ H: true,
+ flags: map[string]bool{"h": true},
+ },
+ },
+ {
+ args: []string{"-i", "10", "--int", "20"},
+ want: &F{
+ I: 20,
+ flags: map[string]bool{"i": true, "int": true},
+ },
+ },
+ {
+ args: []string{"-i", "10", "--int", "20"},
+ want: &F{
+ I: 20,
+ flags: map[string]bool{"i": true, "int": true},
+ },
+ },
+ {
+ args: []string{"-s", "a", "--string", "b", "-long-string", "c"},
+ want: &F{
+ S: "c",
+ flags: map[string]bool{"s": true, "string": true, "long-string": true},
+ },
+ },
+ {
+ args: []string{"-b", "--b", "-b"},
+ want: &F{
+ B: true,
+ flags: map[string]bool{"b": true},
+ },
+ },
+ {
+ args: []string{"-b", "-int", "1", "-string", "a", "arg1", "arg2"},
+ want: &F{
+ B: true,
+ I: 1,
+ S: "a",
+ args: []string{"arg1", "arg2"},
+ flags: map[string]bool{"b": true, "int": true, "string": true},
+ },
+ },
+ {
+ args: []string{"-n", "1"},
+ want: &F{},
+ err: "not defined: -n",
+ },
+ {
+ args: []string{"-t", "3s"},
+ want: &F{
+ T: 3 * time.Second,
+ flags: map[string]bool{"t": true},
+ },
+ },
+ {
+ args: []string{"-t", "nope"},
+ want: &F{},
+ err: "invalid value",
+ },
+ }
+
+ var p Parser
+ for _, c := range cases {
+ t.Run(strings.Join(c.args, " "), func(t *testing.T) {
+ var f F
+ args := append([]string{""}, c.args...)
+ err := p.Parse(args, &f)
+
+ if c.err != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), c.err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, c.want, &f)
+ })
+ }
+}
+
+func TestParseNoFlag(t *testing.T) {
+ type F struct {
+ V int
+ }
+ var p Parser
+
+ f := F{V: 4}
+ err := p.Parse([]string{"", "x"}, &f)
+ require.NoError(t, err)
+ require.Equal(t, 4, f.V)
+}
+
+type noFlagSetArgs struct {
+ args []string
+}
+
+func (n *noFlagSetArgs) SetArgs(args []string) {
+ n.args = args
+}
+
+func TestParseNoFlagSetArgs(t *testing.T) {
+ var p Parser
+ f := noFlagSetArgs{}
+ err := p.Parse([]string{"", "x"}, &f)
+ require.NoError(t, err)
+ require.Equal(t, []string{"x"}, f.args)
+}
+
+func TestParseArgsError(t *testing.T) {
+ type F struct {
+ X bool `flag:"x"`
+ }
+ var p Parser
+ f := F{}
+ err := p.Parse([]string{"", "-zz"}, &f)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "-zz")
+}
+
+func TestParseNotStructPointer(t *testing.T) {
+ var (
+ i int
+ p Parser
+ )
+ require.Panics(t, func() {
+ _ = p.Parse([]string{"-h"}, i)
+ })
+}
+
+func TestParseUnsupportedFlagType(t *testing.T) {
+ type F struct {
+ C *bool `flag:"c"`
+ }
+ var (
+ f F
+ p Parser
+ )
+ require.Panics(t, func() {
+ _ = p.Parse([]string{"", "-h"}, &f)
+ })
+}
+
+type E struct {
+ Addr string `flag:"addr"`
+ DB string `flag:"db"`
+ Help bool `flag:"h,help" ignored:"true"`
+ Version bool `flag:"v,version" ignored:"true"`
+}
+
+func (e *E) Validate() error {
+ if e.Help || e.Version {
+ return nil
+ }
+ if e.Addr == "" {
+ return errors.New("address must be set")
+ }
+ if e.DB == "" {
+ return errors.New("db must be set")
+ }
+ return nil
+}
+
+func TestParseEnvVars(t *testing.T) {
+ const progName = "mainer-test"
+
+ p := Parser{
+ EnvVars: true,
+ }
+
+ cases := []struct {
+ env string // prefix-less Key:val pairs, space-separated
+ args string // space-separated, index 0 added automatically
+ want E
+ errMsg string // error must contain that errMsg
+ }{
+ {
+ "",
+ "",
+ E{},
+ "address must be set",
+ },
+ {
+ "ADDR::1234 DB:x",
+ "",
+ E{Addr: ":1234", DB: "x"},
+ "",
+ },
+ {
+ "",
+ "-addr :2345 -db v",
+ E{Addr: ":2345", DB: "v"},
+ "",
+ },
+ {
+ "ADDR::1234",
+ "-addr :2345 -db x",
+ E{Addr: ":2345", DB: "x"},
+ "",
+ },
+ {
+ "HELP:true",
+ "-addr :2345",
+ E{Addr: ":2345"},
+ "db must be set",
+ },
+ {
+ "VERSION:1",
+ "-addr :2345 -db x",
+ E{Addr: ":2345", DB: "x"},
+ "",
+ },
+ {
+ "",
+ "-help",
+ E{Help: true},
+ "",
+ },
+ {
+ "",
+ "-v",
+ E{Version: true},
+ "",
+ },
+ {
+ "",
+ "-z",
+ E{},
+ "flag provided but not defined: -z",
+ },
+ }
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("%s|%s", c.env, c.args), func(t *testing.T) {
+ // set env vars
+ if c.env != "" {
+ envPairs := strings.Split(c.env, " ")
+ for _, pair := range envPairs {
+ ix := strings.Index(pair, ":")
+ require.True(t, ix >= 0, "%s: missing colon", pair)
+ key, val := pair[:ix], pair[ix+1:]
+ key = strings.ToUpper(prefixFromProgramName(progName)) + "_" + key
+ os.Setenv(key, val)
+ defer os.Unsetenv(key)
+ }
+ }
+
+ // parse args
+ args := []string{progName}
+ if c.args != "" {
+ args = append(args, strings.Split(c.args, " ")...)
+ }
+
+ var e E
+ err := p.Parse(args, &e)
+ if c.errMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), c.errMsg)
+ } else {
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, c.want, e)
+ })
+ }
+}