~shabbyrobe/apiauth

205cf9a2e39557139d3e9422fbd7eb6b1ebc0b16 — Blake Williams 3 months ago
Initial sketch
3 files changed, 217 insertions(+), 0 deletions(-)

A cmd/apiauth/main.go
A go.mod
A go.sum
A  => cmd/apiauth/main.go +210 -0
@@ 1,210 @@
package main

import (
	"bufio"
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"log"
	"net/textproto"
	"net/url"
	"os"
	"os/exec"
	"strings"
	"time"

	_ "github.com/davecgh/go-spew/spew"
)

func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

func parseHeader(raw string) (name, val string, ok bool) {
	br := bufio.NewReader(strings.NewReader(raw + "\r\n\r\n"))
	tp := textproto.NewReader(br)
	hdr, err := tp.ReadMIMEHeader()
	if err != nil {
		return "", "", false
	}
	for k, v := range hdr {
		return strings.ToLower(k), v[0], true
	}
	return "", "", false
}

type Args struct {
	args    []string
	outArgs []string
	idx     int
}

func NewArgs(args []string) *Args {
	return &Args{
		args: append(args, ""),
		idx:  -1,
	}
}

func (args *Args) next() (valid bool) {
	args.idx++
	return args.idx < len(args.args)-1
}

func (args *Args) keep() {
	args.outArgs = append(args.outArgs, args.args[args.idx])
}

func (args *Args) consumeFlagValue(short, long string, keep bool) (k, v string, ok bool) {
	if args.idx >= len(args.args)-1 {
		return "", "", false
	}

	flag := args.args[args.idx]
	next := args.args[args.idx+1]

	var take int
	if (short != "" && flag == short) || (long != "" && flag == long) {
		take, k, v, ok = 2, flag, next, true

	} else if idx := strings.Index(flag, short+"="); short != "" && idx >= 0 {
		args.idx += 1
		take, k, v, ok = 1, flag[0:len(short)], flag[len(short)+1:], true

	} else if idx := strings.Index(flag, long+"="); long != "" && idx >= 0 {
		args.idx += 1
		take, k, v, ok = 1, flag[0:len(long)], flag[len(long)+1:], true
	}

	if keep {
		args.outArgs = append(args.outArgs, args.args[args.idx:args.idx+take]...)
	}
	if take > 0 {
		args.idx += take - 1
	}

	return k, v, ok
}

func run() error {
	gmt, err := time.LoadLocation("Etc/GMT")
	if err != nil {
		return err
	}

	inArgs := NewArgs(os.Args[1:])

	var contentType string
	var method string
	var rawURL string
	var date string
	var dateFound bool
	var contentHash string
	var accessID string
	var accessKey string

	for inArgs.next() {
		if _, v, ok := inArgs.consumeFlagValue("-H", "--header", true); ok {
			if headerName, headerVal, ok := parseHeader(v); ok {
				if headerName == "content-type" {
					contentType = headerVal
				} else if headerName == "date" {
					date, dateFound = headerVal, true
				} else if headerName == "x-authorization-content-sha256" {
					contentHash = headerVal
				}
			}

		} else if _, v, ok := inArgs.consumeFlagValue("-X", "--request", true); ok {
			method = strings.ToUpper(v)

		} else if _, v, ok := inArgs.consumeFlagValue("", "--access-id", false); ok {
			accessID = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--access-key", false); ok {
			accessKey = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--url", true); ok {
			if rawURL != "" {
				return fmt.Errorf("--url can only be passed once")
			}
			rawURL = v

		} else {
			inArgs.keep()
		}
	}

	if rawURL == "" {
		return fmt.Errorf("--url is required")
	}
	if method == "" {
		method = "GET"
	}
	if accessID == "" {
		return fmt.Errorf("--access-id required")
	}
	if accessKey == "" {
		return fmt.Errorf("--access-key required")
	}

	u, err := url.Parse(rawURL)
	if err != nil {
		return err
	}

	uriPath := u.EscapedPath()
	if uriPath == "" {
		uriPath = "/"
	}
	if u.RawQuery != "" {
		uriPath = uriPath + "?" + u.RawQuery
	}

	if date == "" {
		// Ruby's Time.httpdate function rejects 'UTC':
		date = time.Now().In(gmt).Format(time.RFC1123)
	}

	canonical := strings.Join([]string{
		method,
		contentType,
		contentHash,
		uriPath,
		date,
	}, ",")

	mac := hmac.New(sha1.New, []byte(accessKey))
	mac.Write([]byte(canonical))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	debugCmd := "curl"
	var curlArgs []string
	for _, arg := range inArgs.outArgs {
		debugCmd += fmt.Sprintf(" %q", arg)
		curlArgs = append(curlArgs, arg)
	}

	if !dateFound {
		debugCmd += fmt.Sprintf(" -H %q", "Date: "+date)
		curlArgs = append(curlArgs, "-H", "Date: "+date)
	}

	debugCmd += fmt.Sprintf(" -H %q",
		fmt.Sprintf("Authorization: APIAuth %s:%s", accessID, signature))

	curlArgs = append(curlArgs, "-H", fmt.Sprintf("Authorization: APIAuth %s:%s", accessID, signature))

	cmd := exec.Command("curl", curlArgs...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if err := cmd.Run(); err != nil {
		return err
	}

	return nil
}

A  => go.mod +5 -0
@@ 1,5 @@
module go.shabbyrobe.org/apiauth

go 1.21.5

require github.com/davecgh/go-spew v1.1.1

A  => go.sum +2 -0
@@ 1,2 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=