~mna/webparts-flag

ref: e84af31ab0aac85bc4ec57598c0bbfb6f3c23827 webparts-flag/flag.go -rw-r--r-- 2.6 KiB
e84af31aMartin Angers initial commit 1 year, 7 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Package flag implements a command-line flags parser that uses
// struct tags to configure supported flags and returns any error
// it encounters, without printing anything automatically.
//
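// It uses the stdlib's flag package internally, and as such shares
// its behaviour regarding short and long flags: a flag may be given
// with either one or two leading dashes (-name or --name), and the
// two forms are equivalent.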
package flag

import (
	"flag"
	"fmt"
	"io/ioutil"
	"reflect"
	"strings"
	"time"
)

// Parse parses args into v, using struct tags to detect flags.
// The tag must be named "flag" and multiple flags may be set
// for the same field using a comma-separated list.
// v must be a pointer to a struct and the flags must be defined
// on fields with a type of string, int, bool or time.Duration.
//
// If v has a SetArgs method, it is called with the list
// of non-flag arguments.
//
// If v has a SetFlags method, it is called with the set of
// flags that were set by args (a map[string]bool).
//
// It panics if v is not a pointer to a struct or if a flag is
// defined with an unsupported type.
func Parse(args []string, v interface{}) error {
	// create a FlagSet that is silent and only returns any error
	// it encounters.
	fs := flag.NewFlagSet("", flag.ContinueOnError)
	fs.SetOutput(ioutil.Discard)
	fs.Usage = nil

	durationType := reflect.TypeOf(time.Duration(0))

	// extract the flags from the struct
	val := reflect.ValueOf(v).Elem()
	str := reflect.TypeOf(v).Elem()
	count := val.NumField()
	for i := 0; i < count; i++ {
		fld := val.Field(i)
		typ := str.Field(i)
		names := strings.Split(typ.Tag.Get("flag"), ",")

		for _, nm := range names {
			if nm == "" {
				continue
			}
			switch fld.Kind() {
			case reflect.Bool:
				fs.BoolVar(fld.Addr().Interface().(*bool), nm, fld.Bool(), "")
			case reflect.String:
				fs.StringVar(fld.Addr().Interface().(*string), nm, fld.String(), "")
			case reflect.Int:
				fs.IntVar(fld.Addr().Interface().(*int), nm, int(fld.Int()), "")
			default:
				switch typ.Type {
				case durationType:
					fs.DurationVar(fld.Addr().Interface().(*time.Duration), nm, fld.Interface().(time.Duration), "")
				default:
					panic(fmt.Sprintf("unsupported flag field kind: %s (%s: %s)", fld.Kind(), typ.Name, typ.Type))
				}
			}
		}
	}

	if err := fs.Parse(args); err != nil {
		return err
	}

	if sa, ok := v.(interface{ SetArgs([]string) }); ok {
		args := fs.Args()
		if len(args) == 0 {
			args = nil
		}
		sa.SetArgs(args)
	}
	if sf, ok := v.(interface{ SetFlags(map[string]bool) }); ok {
		set := make(map[string]bool)
		fs.Visit(func(fl *flag.Flag) {
			set[fl.Name] = true
		})
		if len(set) == 0 {
			set = nil
		}
		sf.SetFlags(set)
	}

	return nil
}