@@ 1,6 1,9 @@
package flag
import (
+ "errors"
+ "fmt"
+ "os"
"strings"
"testing"
"time"
@@ 110,7 113,8 @@ func TestParseFlags(t *testing.T) {
for _, c := range cases {
t.Run(strings.Join(c.args, " "), func(t *testing.T) {
var f F
- err := p.Parse(c.args, &f)
+ args := append([]string{""}, c.args...)
+ err := p.Parse(args, &f)
if c.err != "" {
require.Error(t, err)
@@ 131,7 135,7 @@ func TestParseNoFlag(t *testing.T) {
var p Parser
f := F{V: 4}
- err := p.Parse([]string{"x"}, &f)
+ err := p.Parse([]string{"", "x"}, &f)
require.NoError(t, err)
require.Equal(t, 4, f.V)
}
@@ 147,7 151,7 @@ func (n *noFlagSetArgs) SetArgs(args []string) {
func TestParseNoFlagSetArgs(t *testing.T) {
var p Parser
f := noFlagSetArgs{}
- err := p.Parse([]string{"x"}, &f)
+ err := p.Parse([]string{"", "x"}, &f)
require.NoError(t, err)
require.Equal(t, []string{"x"}, f.args)
}
@@ 158,7 162,7 @@ func TestParseArgsError(t *testing.T) {
}
var p Parser
f := F{}
- err := p.Parse([]string{"-zz"}, &f)
+ err := p.Parse([]string{"", "-zz"}, &f)
require.Error(t, err)
require.Contains(t, err.Error(), "-zz")
}
@@ 182,6 186,129 @@ func TestParseUnsupportedFlagType(t *testing.T) {
p Parser
)
require.Panics(t, func() {
- _ = p.Parse([]string{"-h"}, &f)
+ _ = p.Parse([]string{"", "-h"}, &f)
})
}
+
+type E struct {
+ Addr string `flag:"addr"`
+ DB string `flag:"db"`
+ Help bool `flag:"h,help" ignored:"true"`
+ Version bool `flag:"v,version" ignored:"true"`
+}
+
+func (e *E) Validate() error {
+ if e.Help || e.Version {
+ return nil
+ }
+ if e.Addr == "" {
+ return errors.New("address must be set")
+ }
+ if e.DB == "" {
+ return errors.New("db must be set")
+ }
+ return nil
+}
+
+func TestParseEnvVars(t *testing.T) {
+ const progName = "mainer-test"
+
+ p := Parser{
+ EnvVars: true,
+ }
+
+ cases := []struct {
+ env string // prefix-less Key:val pairs, space-separated
+ args string // space-separated, index 0 added automatically
+ want E
+ errMsg string // error must contain that errMsg
+ }{
+ {
+ "",
+ "",
+ E{},
+ "address must be set",
+ },
+ {
+ "ADDR::1234 DB:x",
+ "",
+ E{Addr: ":1234", DB: "x"},
+ "",
+ },
+ {
+ "",
+ "-addr :2345 -db v",
+ E{Addr: ":2345", DB: "v"},
+ "",
+ },
+ {
+ "ADDR::1234",
+ "-addr :2345 -db x",
+ E{Addr: ":2345", DB: "x"},
+ "",
+ },
+ {
+ "HELP:true",
+ "-addr :2345",
+ E{Addr: ":2345"},
+ "db must be set",
+ },
+ {
+ "VERSION:1",
+ "-addr :2345 -db x",
+ E{Addr: ":2345", DB: "x"},
+ "",
+ },
+ {
+ "",
+ "-help",
+ E{Help: true},
+ "",
+ },
+ {
+ "",
+ "-v",
+ E{Version: true},
+ "",
+ },
+ {
+ "",
+ "-z",
+ E{},
+ "flag provided but not defined: -z",
+ },
+ }
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("%s|%s", c.env, c.args), func(t *testing.T) {
+ // set env vars
+ if c.env != "" {
+ envPairs := strings.Split(c.env, " ")
+ for _, pair := range envPairs {
+ ix := strings.Index(pair, ":")
+ require.True(t, ix >= 0, "%s: missing colon", pair)
+ key, val := pair[:ix], pair[ix+1:]
+ key = strings.ToUpper(prefixFromProgramName(progName)) + "_" + key
+ os.Setenv(key, val)
+ defer os.Unsetenv(key)
+ }
+ }
+
+ // parse args
+ args := []string{progName}
+ if c.args != "" {
+ args = append(args, strings.Split(c.args, " ")...)
+ }
+
+ var e E
+ err := p.Parse(args, &e)
+ if c.errMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), c.errMsg)
+ } else {
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, c.want, e)
+ })
+ }
+}