~mna/webparts-flag

ref: 306cfcede5b817c470e3564493265d6125d97956 webparts-flag/flag.go -rw-r--r-- 4.3 KiB
306cfcedMartin Angers implement parsing with and without env vars 1 year, 7 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Package flag implements a command-line flags parser that uses struct tags to
// configure supported flags and returns any error it encounters, without
// printing anything automatically. It can optionally read flag values from
// environment variables first, with the command-line flags used to override
// them.
//
// The struct tag to specify flags is `flag`, while the one to specify
// environment variables is `envconfig`. See the envconfig package for full
// details on struct tags configuration and decoding support:
// https://github.com/kelseyhightower/envconfig.
//
// Flag parsing uses the stdlib's flag package internally, and as such shares
// the same behaviour regarding short and long flags.
package flag

import (
	stdflag "flag"
	"fmt"
	"io/ioutil"
	"reflect"
	"strings"
	"time"

	"git.sr.ht/~mna/webparts/flag"
	"github.com/kelseyhightower/envconfig"
)

// Compile-time check that *Parser satisfies the flag.Parser interface of the
// webparts package.
var _ flag.Parser = (*Parser)(nil)

// Parser implements a flag parser that fills a struct from command-line
// arguments and, optionally, from environment variables. The zero value is
// usable and parses command-line flags only.
type Parser struct {
	// EnvVars indicates if environment variables are used to read flag values.
	// When true, Parse processes environment variables before the command-line
	// flags, so a flag given on the command line overrides the value read from
	// the environment.
	EnvVars bool

	// EnvPrefix is the prefix to use in front of each flag's environment
	// variable name. If it is empty, the name of the program (as read from the
	// args slice at index 0) is used, with dashes replaced with underscores.
	// It is only consulted when EnvVars is true.
	EnvPrefix string
}

// Parse fills v from args, using the `flag` struct tag of each field to know
// which command-line flags map to it. A field may declare several flag names
// as a comma-separated list. v must be a pointer to a struct, and tagged
// fields must be of type string, int, bool or time.Duration.
//
// When Parser.EnvVars is true, the environment variables are processed first
// and the command-line flags are applied on top of them.
//
// Optional hooks implemented by v are honoured:
//
//   - SetArgs([]string) receives the non-flag arguments;
//   - SetFlags(map[string]bool) receives the names of the flags that were set
//     by args;
//   - Validate() error is called once parsing succeeded, and its result is
//     returned as is.
//
// Parsing errors are returned; nothing is printed. It panics if v is not a
// pointer to a struct or if a flag is defined on a field with an unsupported
// type.
func (p *Parser) Parse(args []string, v interface{}) error {
	var err error

	// environment first, so that command-line flags can override it
	if p.EnvVars {
		err = p.parseEnvVars(args, v)
	}
	if err == nil {
		err = p.parseFlags(args, v)
	}
	if err != nil {
		return err
	}

	validator, ok := v.(interface{ Validate() error })
	if !ok {
		return nil
	}
	return validator.Validate()
}

// parseFlags registers one stdlib flag per name found in the `flag` struct
// tags of v, parses args with them and stores the values directly in the
// fields of v. Fields without a `flag` tag (or with empty names) are ignored.
//
// After a successful parse, v.SetArgs is called with the non-flag arguments
// and v.SetFlags with the set of flag names that were explicitly provided, if
// v implements those methods. Both receive nil rather than an empty value.
//
// It returns the error reported by the stdlib flag package, if any. It panics
// with a descriptive message if v is not a non-nil pointer to a struct, if a
// tagged field is unexported, or if a tagged field is not exactly of type
// bool, string, int or time.Duration.
//
// NOTE(review): args is handed to FlagSet.Parse unchanged, and the stdlib
// parser stops at the first non-flag argument. parseEnvVars reads args[0] as
// the program name; if callers do pass the program name at index 0, no flag
// would be parsed here - confirm what callers are expected to pass.
func (p *Parser) parseFlags(args []string, v interface{}) error {
	ptr := reflect.ValueOf(v)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
		panic(fmt.Sprintf("flag: v must be a non-nil pointer to a struct, got %T", v))
	}

	// create a FlagSet that is silent and only returns any error
	// it encounters.
	fs := stdflag.NewFlagSet("", stdflag.ContinueOnError)
	fs.SetOutput(ioutil.Discard)
	fs.Usage = nil

	// extract the flags from the struct
	val := ptr.Elem()
	str := val.Type()
	for i, count := 0, val.NumField(); i < count; i++ {
		fld, typ := val.Field(i), str.Field(i)

		for _, nm := range strings.Split(typ.Tag.Get("flag"), ",") {
			if nm == "" {
				continue
			}
			// reflect refuses to hand out a pointer to an unexported field, fail
			// with a message that names the offending field instead.
			if typ.PkgPath != "" {
				panic(fmt.Sprintf("flag field must be exported: %s", typ.Name))
			}

			// Dispatch on the concrete pointer type rather than on the kind, so
			// that defined types (e.g. `type Level int`) are reported as
			// unsupported instead of failing a type assertion. The current value
			// of the field is used as the flag's default, which preserves values
			// set before parsing (e.g. from environment variables).
			switch target := fld.Addr().Interface().(type) {
			case *bool:
				fs.BoolVar(target, nm, *target, "")
			case *string:
				fs.StringVar(target, nm, *target, "")
			case *int:
				fs.IntVar(target, nm, *target, "")
			case *time.Duration:
				fs.DurationVar(target, nm, *target, "")
			default:
				panic(fmt.Sprintf("unsupported flag field kind: %s (%s: %s)", fld.Kind(), typ.Name, typ.Type))
			}
		}
	}

	if err := fs.Parse(args); err != nil {
		return err
	}

	if sa, ok := v.(interface{ SetArgs([]string) }); ok {
		rest := fs.Args()
		if len(rest) == 0 {
			rest = nil
		}
		sa.SetArgs(rest)
	}
	if sf, ok := v.(interface{ SetFlags(map[string]bool) }); ok {
		var set map[string]bool
		// Visit only walks the flags that were actually set, so the map stays
		// nil when none were provided.
		fs.Visit(func(fl *stdflag.Flag) {
			if set == nil {
				set = make(map[string]bool)
			}
			set[fl.Name] = true
		})
		sf.SetFlags(set)
	}

	return nil
}

// parseEnvVars fills v from the environment using the envconfig package. The
// variable names are prefixed with Parser.EnvPrefix when it is set; otherwise
// the prefix is derived from args[0] when args is not empty, and no prefix is
// used when it is.
func (p *Parser) parseEnvVars(args []string, v interface{}) error {
	// an explicit prefix always wins, and without args there is nothing to
	// derive one from.
	if p.EnvPrefix != "" || len(args) == 0 {
		return envconfig.Process(p.EnvPrefix, v)
	}
	return envconfig.Process(prefixFromProgramName(args[0]), v)
}

// prefixFromProgramName derives an environment variable prefix from the
// program name, typically the first element of the args slice. Only the final
// path element is kept (both '/' and '\' are treated as separators, whatever
// the platform), a trailing extension such as ".exe" is dropped, and dashes
// are replaced with underscores. For example "/usr/local/bin/my-app" and
// `C:\tools\my-app.exe` both yield "my_app".
//
// A name made only of an extension (e.g. ".hidden") is kept as is so that the
// prefix does not silently become empty.
func prefixFromProgramName(name string) string {
	// the program is often invoked with a relative or absolute path, which
	// must not end up in the environment variable names.
	if i := strings.LastIndexAny(name, `/\`); i >= 0 {
		name = name[i+1:]
	}
	// i > 0 (and not >= 0) keeps dot-files intact.
	if i := strings.LastIndexByte(name, '.'); i > 0 {
		name = name[:i]
	}
	return strings.ReplaceAll(name, "-", "_")
}