~mna/webparts-flag

3c1884910ba40f218f0b5bbc99e7ed7239f242f2 — Martin Angers 1 year, 6 months ago 306cfce
test with env vars
2 files changed, 137 insertions(+), 6 deletions(-)

M flag.go
M flag_test.go
M flag.go => flag.go +5 -1
@@ 73,6 73,10 @@ func (p *Parser) Parse(args []string, v interface{}) error {
}

func (p *Parser) parseFlags(args []string, v interface{}) error {
	if len(args) == 0 {
		return nil
	}

	// create a FlagSet that is silent and only returns any error
	// it encounters.
	fs := stdflag.NewFlagSet("", stdflag.ContinueOnError)


@@ 112,7 116,7 @@ func (p *Parser) parseFlags(args []string, v interface{}) error {
		}
	}

	if err := fs.Parse(args); err != nil {
	if err := fs.Parse(args[1:]); err != nil {
		return err
	}


M flag_test.go => flag_test.go +132 -5
@@ 1,6 1,9 @@
package flag

import (
	"errors"
	"fmt"
	"os"
	"strings"
	"testing"
	"time"


@@ 110,7 113,8 @@ func TestParseFlags(t *testing.T) {
	for _, c := range cases {
		t.Run(strings.Join(c.args, " "), func(t *testing.T) {
			var f F
			err := p.Parse(c.args, &f)
			args := append([]string{""}, c.args...)
			err := p.Parse(args, &f)

			if c.err != "" {
				require.Error(t, err)


@@ 131,7 135,7 @@ func TestParseNoFlag(t *testing.T) {
	var p Parser

	f := F{V: 4}
	err := p.Parse([]string{"x"}, &f)
	err := p.Parse([]string{"", "x"}, &f)
	require.NoError(t, err)
	require.Equal(t, 4, f.V)
}


@@ 147,7 151,7 @@ func (n *noFlagSetArgs) SetArgs(args []string) {
func TestParseNoFlagSetArgs(t *testing.T) {
	var p Parser
	f := noFlagSetArgs{}
	err := p.Parse([]string{"x"}, &f)
	err := p.Parse([]string{"", "x"}, &f)
	require.NoError(t, err)
	require.Equal(t, []string{"x"}, f.args)
}


@@ 158,7 162,7 @@ func TestParseArgsError(t *testing.T) {
	}
	var p Parser
	f := F{}
	err := p.Parse([]string{"-zz"}, &f)
	err := p.Parse([]string{"", "-zz"}, &f)
	require.Error(t, err)
	require.Contains(t, err.Error(), "-zz")
}


@@ 182,6 186,129 @@ func TestParseUnsupportedFlagType(t *testing.T) {
		p Parser
	)
	require.Panics(t, func() {
		_ = p.Parse([]string{"-h"}, &f)
		_ = p.Parse([]string{"", "-h"}, &f)
	})
}

type E struct {
	Addr    string `flag:"addr"`
	DB      string `flag:"db"`
	Help    bool   `flag:"h,help" ignored:"true"`
	Version bool   `flag:"v,version" ignored:"true"`
}

func (e *E) Validate() error {
	if e.Help || e.Version {
		return nil
	}
	if e.Addr == "" {
		return errors.New("address must be set")
	}
	if e.DB == "" {
		return errors.New("db must be set")
	}
	return nil
}

func TestParseEnvVars(t *testing.T) {
	const progName = "mainer-test"

	p := Parser{
		EnvVars: true,
	}

	cases := []struct {
		env    string // prefix-less Key:val pairs, space-separated
		args   string // space-separated, index 0 added automatically
		want   E
		errMsg string // error must contain that errMsg
	}{
		{
			"",
			"",
			E{},
			"address must be set",
		},
		{
			"ADDR::1234 DB:x",
			"",
			E{Addr: ":1234", DB: "x"},
			"",
		},
		{
			"",
			"-addr :2345 -db v",
			E{Addr: ":2345", DB: "v"},
			"",
		},
		{
			"ADDR::1234",
			"-addr :2345 -db x",
			E{Addr: ":2345", DB: "x"},
			"",
		},
		{
			"HELP:true",
			"-addr :2345",
			E{Addr: ":2345"},
			"db must be set",
		},
		{
			"VERSION:1",
			"-addr :2345 -db x",
			E{Addr: ":2345", DB: "x"},
			"",
		},
		{
			"",
			"-help",
			E{Help: true},
			"",
		},
		{
			"",
			"-v",
			E{Version: true},
			"",
		},
		{
			"",
			"-z",
			E{},
			"flag provided but not defined: -z",
		},
	}
	for _, c := range cases {
		t.Run(fmt.Sprintf("%s|%s", c.env, c.args), func(t *testing.T) {
			// set env vars
			if c.env != "" {
				envPairs := strings.Split(c.env, " ")
				for _, pair := range envPairs {
					ix := strings.Index(pair, ":")
					require.True(t, ix >= 0, "%s: missing colon", pair)
					key, val := pair[:ix], pair[ix+1:]
					key = strings.ToUpper(prefixFromProgramName(progName)) + "_" + key
					os.Setenv(key, val)
					defer os.Unsetenv(key)
				}
			}

			// parse args
			args := []string{progName}
			if c.args != "" {
				args = append(args, strings.Split(c.args, " ")...)
			}

			var e E
			err := p.Parse(args, &e)
			if c.errMsg != "" {
				require.Error(t, err)
				require.Contains(t, err.Error(), c.errMsg)
			} else {
				require.NoError(t, err)
			}

			require.Equal(t, c.want, e)
		})
	}
}