~shabbyrobe/apiauth

334d1103f830453ac9692b659e965d466322315d — Blake Williams 3 months ago 6962f33 master
Add more useful stuff
A LICENSE => LICENSE +21 -0
@@ 0,0 1,21 @@
MIT License

Copyright (c) 2024 Blake Williams <code@shabbyrobe.org>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

A README.md => README.md +24 -0
@@ 0,0 1,24 @@
# apiauth: Tools for working with the Ruby APIAuth authentication scheme in Go

The Ruby library is found here: https://github.com/mgomes/api_auth

TODO:

- Leeway in dates?
- Add new digest algos
- Add verification subcommand
- Umm, tests?


## Expectation management

This is a tool I hack on for my own amusement in an ad-hoc fashion. No
stability guarantees are made, the code is not guaranteed to work, and anything
may be changed, renamed or removed at any time as I see fit.

If you intend to depend on any of this, I strongly recommend you copy-paste
pieces as-needed (including tests and license/attribution) into your own
project, or fork it for your own purposes.

Bug reports and feature suggestions are welcome, but code contributions will
not be accepted.

A cmd/apiauth/cmd_curl.go => cmd/apiauth/cmd_curl.go +189 -0
@@ 0,0 1,189 @@
package main

import (
	"bufio"
	"errors"
	"fmt"
	"net/textproto"
	"os"
	"os/exec"
	"strings"

	"go.shabbyrobe.org/apiauth"
)

func cmdCurl(args []string) error {
	inArgs := NewArgs(args)

	var contentType string
	var method string
	var rawURL string
	var date string
	var dateFound bool
	var contentHash string
	var digestType string

	var creds apiauth.Credentials

	for inArgs.next() {
		if _, v, ok := inArgs.consumeFlagValue("-H", "--header", true); ok {
			if headerName, headerVal, ok := parseHeader(v); ok {
				if headerName == "content-type" {
					contentType = headerVal
				} else if headerName == "date" {
					date, dateFound = headerVal, true
				} else if headerName == "x-authorization-content-sha256" {
					contentHash = headerVal
				}
			}

		} else if _, v, ok := inArgs.consumeFlagValue("-X", "--request", true); ok {
			method = strings.ToUpper(v)

		} else if _, v, ok := inArgs.consumeFlagValue("", "--access-id", false); ok {
			creds.AccessID = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--secret", false); ok {
			creds.Secret = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--digest-type", false); ok {
			// NOTE: this needs to be different to the other commands because '--digest' is a cURL option:
			digestType = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--url", true); ok {
			if rawURL != "" {
				return fmt.Errorf("--url can only be passed once")
			}
			rawURL = v

		} else {
			inArgs.keep()
		}
	}

	if rawURL == "" {
		return fmt.Errorf("--url is required")
	}
	if method == "" {
		method = "GET"
	}
	if creds.AccessID == "" {
		return fmt.Errorf("--access-id required")
	}
	if creds.Secret == "" {
		return fmt.Errorf("--secret required")
	}

	date, _ = apiauth.CoalesceRawDate(date)

	signer := apiauth.Signer{
		Creds:      creds,
		DigestType: apiauth.DigestType(digestType),
	}

	signInput, err := apiauth.SignInputFromRawValues(
		method,
		contentType,
		contentHash,
		rawURL,
		date,
	)
	if errors.Is(err, apiauth.ErrDateRequired) {
		return fmt.Errorf("--date required")
	} else if err != nil {
		return err
	}

	headerValue, err := signer.AuthHeaderValue(signInput)
	if err != nil {
		return err
	}

	debugCmd := "curl"
	var curlArgs []string
	for _, arg := range inArgs.outArgs {
		debugCmd += fmt.Sprintf(" %q", arg)
		curlArgs = append(curlArgs, arg)
	}

	if !dateFound {
		debugCmd += fmt.Sprintf(" -H %q", "Date: "+date)
		curlArgs = append(curlArgs, "-H", "Date: "+date)
	}

	debugCmd += fmt.Sprintf(" -H %q", fmt.Sprintf("Authorization: %s", headerValue))
	curlArgs = append(curlArgs, "-H", fmt.Sprintf("Authorization: %s", headerValue))

	cmd := exec.Command("curl", curlArgs...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if err := cmd.Run(); err != nil {
		return err
	}

	return nil
}
func parseHeader(raw string) (name, val string, ok bool) {
	br := bufio.NewReader(strings.NewReader(raw + "\r\n\r\n"))
	tp := textproto.NewReader(br)
	hdr, err := tp.ReadMIMEHeader()
	if err != nil {
		return "", "", false
	}
	for k, v := range hdr {
		return strings.ToLower(k), v[0], true
	}
	return "", "", false
}

type Args struct {
	args    []string
	outArgs []string
	idx     int
}

func NewArgs(args []string) *Args {
	return &Args{
		args: append(args, ""),
		idx:  -1,
	}
}

func (args *Args) next() (valid bool) {
	args.idx++
	return args.idx < len(args.args)-1
}

func (args *Args) keep() {
	args.outArgs = append(args.outArgs, args.args[args.idx])
}

func (args *Args) consumeFlagValue(short, long string, keep bool) (k, v string, ok bool) {
	if args.idx >= len(args.args)-1 {
		return "", "", false
	}

	flag := args.args[args.idx]
	next := args.args[args.idx+1]

	var take int
	if (short != "" && flag == short) || (long != "" && flag == long) {
		take, k, v, ok = 2, flag, next, true

	} else if idx := strings.Index(flag, short+"="); short != "" && idx >= 0 {
		take, k, v, ok = 1, flag[0:len(short)], flag[len(short)+1:], true

	} else if idx := strings.Index(flag, long+"="); long != "" && idx >= 0 {
		take, k, v, ok = 1, flag[0:len(long)], flag[len(long)+1:], true
	}

	if keep {
		args.outArgs = append(args.outArgs, args.args[args.idx:args.idx+take]...)
	}
	if take > 0 {
		args.idx += take - 1
	}

	return k, v, ok
}

A cmd/apiauth/cmd_header.go => cmd/apiauth/cmd_header.go +97 -0
@@ 0,0 1,97 @@
package main

import (
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"time"

	"go.shabbyrobe.org/apiauth"
)

func cmdHeader(args []string) error {
	var inputFlags signInputFlags
	var creds apiauth.Credentials
	var outfmt string

	var fs = flag.NewFlagSet("apiauth header", 0)
	fs.StringVar(&outfmt, "fmt", "value", ""+
		"Output format. 'value' (header value only), 'gen' (full header, including 'Date' if --date "+
		"not passed, 'full' (like 'gen', but always, shows 'Date'), 'json'")

	inputFlags.Attach(fs, "")
	attachCredsFlags(&creds, fs, "")

	if err := fs.Parse(args); err != nil {
		return err
	}
	if len(fs.Args()) > 0 {
		return fmt.Errorf("unexpected args")
	}

	coalesceEnvCreds(&creds, "")

	if err := creds.Validate(); errors.Is(err, apiauth.ErrCredentialsMissingAccessID) {
		return fmt.Errorf("--access-id required")
	} else if err := creds.Validate(); errors.Is(err, apiauth.ErrCredentialsMissingSecret) {
		return fmt.Errorf("--secret required")
	} else if err != nil {
		return err
	}

	var dateGenerated bool
	if outfmt != "value" && inputFlags.date.IsZero() {
		inputFlags.date.Time = time.Now()
		dateGenerated = true
	}

	signInput, err := inputFlags.SignInput()
	if errors.Is(err, apiauth.ErrDateRequired) {
		return fmt.Errorf("--date required if --fmt is 'value'; pass a --date or use another --fmt (i.e. 'full')")
	} else if err != nil {
		return err
	}
	if err := signInput.Validate(); err != nil {
		return err
	}

	signer := apiauth.Signer{
		Creds:      creds,
		DigestType: inputFlags.DigestType(),
	}
	headerValue, err := signer.AuthHeaderValue(signInput)
	if err != nil {
		return err
	}

	switch outfmt {
	case "value":
		fmt.Println(headerValue)

	case "gen":
		fmt.Println("Authorization: " + headerValue)
		if dateGenerated {
			fmt.Println("Date: " + apiauth.FormatDate(inputFlags.date.Time))
		}

	case "full":
		fmt.Println("Authorization: " + headerValue)
		fmt.Println("Date: " + apiauth.FormatDate(inputFlags.date.Time))

	case "json":
		bts, _ := json.MarshalIndent(struct {
			Authorization string
			Date          string
		}{
			Authorization: headerValue,
			Date:          apiauth.FormatDate(inputFlags.date.Time),
		}, "", "  ")
		fmt.Println(string(bts))

	default:
		return fmt.Errorf("unknown --fmt")
	}

	return nil
}

A cmd/apiauth/cmd_proxy.go => cmd/apiauth/cmd_proxy.go +74 -0
@@ 0,0 1,74 @@
package main

import (
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"time"

	"go.shabbyrobe.org/apiauth"
)

func cmdProxy(args []string) error {
	var inputFlags signInputFlags
	var creds apiauth.Credentials
	var host string
	var rawUpstream string

	var fs = flag.NewFlagSet("apiauth proxy", 0)
	fs.StringVar(&host, "host", "127.0.0.1:0", "Listen (use :0 to assign arbitrary port)")
	fs.StringVar(&rawUpstream, "upstream", "", "Proxy this URL")

	inputFlags.Attach(fs, "")
	attachCredsFlags(&creds, fs, "")

	if err := fs.Parse(args); err != nil {
		return err
	}
	if len(fs.Args()) > 0 {
		return fmt.Errorf("unexpected args")
	}

	coalesceEnvCreds(&creds, "")

	upstream, err := url.Parse(rawUpstream)
	if err != nil {
		return fmt.Errorf("apiauth: --upstream url %q is invalid: %w", rawUpstream, err)
	} else if upstream.Scheme == "" {
		return fmt.Errorf("apiauth: --upstream url %q does not contain a scheme", rawUpstream)
	}

	signer := apiauth.Signer{
		Creds:      creds,
		DigestType: inputFlags.DigestType(),
	}

	proxy := &httputil.ReverseProxy{
		FlushInterval: 1 * time.Second,
		Rewrite: func(rq *httputil.ProxyRequest) {
			if err := signer.SignRequest(rq.Out); err != nil {
				log.Println(err)
			}

			rq.SetURL(upstream)
		},
	}

	srv := http.Server{
		Addr:    host,
		Handler: proxy,
	}

	ln, err := net.Listen("tcp", host)
	if err != nil {
		return err
	}

	log.Println("Listening:", ln.Addr())

	return srv.Serve(ln)
}

A cmd/apiauth/cmd_sign.go => cmd/apiauth/cmd_sign.go +58 -0
@@ 0,0 1,58 @@
package main

import (
	"errors"
	"flag"
	"fmt"

	"go.shabbyrobe.org/apiauth"
)

func cmdSign(args []string) error {
	var inputFlags signInputFlags
	var creds apiauth.Credentials

	var fs = flag.NewFlagSet("apiauth sign", 0)
	inputFlags.Attach(fs, "")
	attachCredsFlags(&creds, fs, "")

	if err := fs.Parse(args); err != nil {
		return err
	}
	if len(fs.Args()) > 0 {
		return fmt.Errorf("unexpected args")
	}

	coalesceEnvCreds(&creds, "")

	if err := creds.Validate(); errors.Is(err, apiauth.ErrCredentialsMissingAccessID) {
		// This is fine
	} else if err := creds.Validate(); errors.Is(err, apiauth.ErrCredentialsMissingSecret) {
		return fmt.Errorf("--secret required")
	} else if err != nil {
		return err
	}

	signInput, err := inputFlags.SignInput()
	if errors.Is(err, apiauth.ErrDateRequired) {
		return fmt.Errorf("--date required")
	} else if err != nil {
		return err
	}
	if err := signInput.Validate(); err != nil {
		return err
	}

	signer := apiauth.Signer{
		Creds:      creds,
		DigestType: inputFlags.DigestType(),
	}
	signature, err := signer.SignInput(signInput)
	if err != nil {
		return err
	}

	fmt.Println(signature)

	return nil
}

A cmd/apiauth/date.go => cmd/apiauth/date.go +55 -0
@@ 0,0 1,55 @@
package main

import (
	"fmt"
	"time"
)

type LooseDate struct {
	time.Time
}

func (ld LooseDate) Type() string { return "string" }

func (ld *LooseDate) Set(s string) error {
	dt, err := parseLooseDate(s)
	if err != nil {
		return err
	}
	*ld = LooseDate{dt}
	return nil
}

var looseDatePatterns = []string{
	time.RFC3339Nano,
	time.RFC3339,
	time.UnixDate,
	time.RubyDate,
	time.RFC1123Z,
	time.RFC1123,
	"Mon _2 Jan 2006 15:04:05 MST",
	"2006-01-02T15:04:05.999999",
	"2006-01-02T15:04:05.999",
	"2006-01-02T15:04:05Z",
	"2006-01-02T15:04:05",
	"2006-01-02T15:04Z",
	"2006-01-02T15:04",
	"20060102T15:04:05Z07:00",
	"20060102T15:04:05.999999999Z07:00",
	"20060102T15:04:05.999999",
	"20060102T15:04:05.999",
	"20060102T15:04:05Z",
	"20060102T15:04:05",
	"20060102T15:04Z",
	"20060102T15:04",
}

func parseLooseDate(s string) (t time.Time, err error) {
	for _, ptn := range looseDatePatterns {
		t, err := time.Parse(ptn, s)
		if err == nil {
			return t, nil
		}
	}
	return t, fmt.Errorf("date %q cannot be parsed using any known format", s)
}

A cmd/apiauth/flags.go => cmd/apiauth/flags.go +65 -0
@@ 0,0 1,65 @@
package main

import (
	"flag"
	"os"

	"go.shabbyrobe.org/apiauth"
)

type signInputFlags struct {
	contentType   string
	method        string
	rawURL        string
	date          LooseDate
	contentHash   string
	rawDigestType string
}

func (flags signInputFlags) DigestType() apiauth.DigestType {
	return apiauth.DigestType(flags.rawDigestType)
}

func (flags *signInputFlags) Attach(fs *flag.FlagSet, prefix string) {
	fs.StringVar(&flags.contentType, prefix+"content-type", "", "")
	fs.StringVar(&flags.method, prefix+"method", "GET", "")
	fs.StringVar(&flags.rawURL, prefix+"url", "", "")
	fs.Var(&flags.date, prefix+"date", "")
	fs.StringVar(&flags.contentHash, prefix+"content-hash", "", "")
	fs.StringVar(&flags.rawDigestType, prefix+"digest", "", "Digest algo (sha1, sha256; default is 'sha1')")
}

func (flags *signInputFlags) SignInput() (apiauth.SignInput, error) {
	date := ""
	if !flags.date.IsZero() {
		date = apiauth.FormatDate(flags.date.Time)
	}
	return apiauth.SignInputFromRawValues(
		flags.method,
		flags.contentType,
		flags.contentHash,
		flags.rawURL,
		date,
	)
}

func attachCredsFlags(
	creds *apiauth.Credentials,
	fs *flag.FlagSet,
	prefix string,
) {
	fs.StringVar(&creds.AccessID, "access-id", "", "Visible portion of the access ID used in the Authorization header.")
	fs.StringVar(&creds.Secret, "secret", "", "Secret used to create the HMAC.")
}

func coalesceEnvCreds(
	creds *apiauth.Credentials,
	prefix string,
) {
	if creds.AccessID == "" {
		creds.AccessID = os.Getenv(prefix + "APIAUTH_ACCESS_ID")
	}
	if creds.Secret == "" {
		creds.Secret = os.Getenv(prefix + "APIAUTH_SECRET")
	}
}

M cmd/apiauth/main.go => cmd/apiauth/main.go +45 -161
@@ 1,189 1,73 @@
package main

import (
	"bufio"
	_ "embed"
	"errors"
	"fmt"
	"log"
	"net/textproto"
	"os"
	"os/exec"
	"strings"

	_ "github.com/davecgh/go-spew/spew"
	"go.shabbyrobe.org/apiauth"
)

func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}
//go:embed splash.txt
var splash string

func parseHeader(raw string) (name, val string, ok bool) {
	br := bufio.NewReader(strings.NewReader(raw + "\r\n\r\n"))
	tp := textproto.NewReader(br)
	hdr, err := tp.ReadMIMEHeader()
	if err != nil {
		return "", "", false
	}
	for k, v := range hdr {
		return strings.ToLower(k), v[0], true
	}
	return "", "", false
}
var usage = strings.TrimSpace(`
{splash}

type Args struct {
	args    []string
	outArgs []string
	idx     int
}
apiauth: Tools for working with the Ruby APIAuth authentication scheme.

func NewArgs(args []string) *Args {
	return &Args{
		args: append(args, ""),
		idx:  -1,
	}
}
The Ruby library is found here: https://github.com/mgomes/api_auth

func (args *Args) next() (valid bool) {
	args.idx++
	return args.idx < len(args.args)-1
}
Commands:
  curl     Best-effort wrapper for a curl invocation, adding APIAuth headers
  header   Create an APIAuth Authorization header value
  proxy    Proxy unauthenticated requests to an APIAuth-authenticated server
  sign     Create an APIAuth signature

func (args *Args) keep() {
	args.outArgs = append(args.outArgs, args.args[args.idx])
}
Caveats:
  - The 'curl' subcommand does not support positional arguments; it only supports
    using the --url flag to specify the request URL.
  - The 'curl' subcommand does not try to be perfect. It's less likely to break
    down for simple use cases, but most things should be supported.
  - The 'proxy' subcommand is absolutely not a production server.
`)

func (args *Args) consumeFlagValue(short, long string, keep bool) (k, v string, ok bool) {
	if args.idx >= len(args.args)-1 {
		return "", "", false
	}

	flag := args.args[args.idx]
	next := args.args[args.idx+1]

	var take int
	if (short != "" && flag == short) || (long != "" && flag == long) {
		take, k, v, ok = 2, flag, next, true

	} else if idx := strings.Index(flag, short+"="); short != "" && idx >= 0 {
		args.idx += 1
		take, k, v, ok = 1, flag[0:len(short)], flag[len(short)+1:], true
func init() {
	usage = strings.Replace(usage, "{splash}\n", splash, 1)
}

	} else if idx := strings.Index(flag, long+"="); long != "" && idx >= 0 {
		args.idx += 1
		take, k, v, ok = 1, flag[0:len(long)], flag[len(long)+1:], true
	}
var ErrUsage = errors.New("usage")

	if keep {
		args.outArgs = append(args.outArgs, args.args[args.idx:args.idx+take]...)
	}
	if take > 0 {
		args.idx += take - 1
func main() {
	if err := run(); errors.Is(err, ErrUsage) {
		fmt.Println(usage)
		os.Exit(2)
	} else if err != nil {
		log.Fatal(err)
	}

	return k, v, ok
}

func run() error {
	inArgs := NewArgs(os.Args[1:])

	var contentType string
	var method string
	var rawURL string
	var date string
	var dateFound bool
	var contentHash string
	args := os.Args[1:]

	var creds apiauth.Credentials

	for inArgs.next() {
		if _, v, ok := inArgs.consumeFlagValue("-H", "--header", true); ok {
			if headerName, headerVal, ok := parseHeader(v); ok {
				if headerName == "content-type" {
					contentType = headerVal
				} else if headerName == "date" {
					date, dateFound = headerVal, true
				} else if headerName == "x-authorization-content-sha256" {
					contentHash = headerVal
				}
			}

		} else if _, v, ok := inArgs.consumeFlagValue("-X", "--request", true); ok {
			method = strings.ToUpper(v)

		} else if _, v, ok := inArgs.consumeFlagValue("", "--access-id", false); ok {
			creds.AccessID = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--access-key", false); ok {
			creds.Secret = v

		} else if _, v, ok := inArgs.consumeFlagValue("", "--url", true); ok {
			if rawURL != "" {
				return fmt.Errorf("--url can only be passed once")
			}
			rawURL = v

		} else {
			inArgs.keep()
		}
	cmd := ""
	if len(args) > 0 {
		cmd = args[0]
	}

	if rawURL == "" {
		return fmt.Errorf("--url is required")
	switch cmd {
	case "curl":
		return cmdCurl(args[1:])
	case "header":
		return cmdHeader(args[1:])
	case "proxy":
		return cmdProxy(args[1:])
	case "sign":
		return cmdSign(args[1:])
	default:
		return ErrUsage
	}
	if method == "" {
		method = "GET"
	}
	if creds.AccessID == "" {
		return fmt.Errorf("--access-id required")
	}
	if creds.Secret == "" {
		return fmt.Errorf("--access-key required")
	}

	date, _ = apiauth.CoalesceRawDate(date)

	signer := apiauth.Signer{Creds: creds}

	signInput, err := apiauth.SignInputFromRawValues(
		method,
		contentType,
		contentHash,
		rawURL,
		date,
	)
	if err != nil {
		return err
	}

	headerValue, err := signer.AuthHeaderValue(signInput)
	if err != nil {
		return err
	}

	debugCmd := "curl"
	var curlArgs []string
	for _, arg := range inArgs.outArgs {
		debugCmd += fmt.Sprintf(" %q", arg)
		curlArgs = append(curlArgs, arg)
	}

	if !dateFound {
		debugCmd += fmt.Sprintf(" -H %q", "Date: "+date)
		curlArgs = append(curlArgs, "-H", "Date: "+date)
	}

	debugCmd += fmt.Sprintf(" -H %q", fmt.Sprintf("Authorization: %s", headerValue))
	curlArgs = append(curlArgs, "-H", fmt.Sprintf("Authorization: %s", headerValue))

	cmd := exec.Command("curl", curlArgs...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if err := cmd.Run(); err != nil {
		return err
	}

	return nil
}

A cmd/apiauth/splash.txt => cmd/apiauth/splash.txt +7 -0
@@ 0,0 1,7 @@

░█████╗░██████╗░██╗░█████╗░██╗░░░██╗████████╗██╗░░██╗ 
██╔══██╗██╔══██╗██║██╔══██╗██║░░░██║╚══██╔══╝██║░░██║    .--.             
███████║██████╔╝██║███████║██║░░░██║░░░██║░░░███████║   /.-. '----------.
██╔══██║██╔═══╝░██║██╔══██║██║░░░██║░░░██║░░░██╔══██║   \'-' .--"--""-"-'
██║░░██║██║░░░░░██║██║░░██║╚██████╔╝░░░██║░░░██║░░██║    '--'        jgs
╚═╝░░╚═╝╚═╝░░░░░╚═╝╚═╝░░╚═╝░╚═════╝░░░░╚═╝░░░╚═╝░░╚═╝ 

A creds.go => creds.go +25 -0
@@ 0,0 1,25 @@
package apiauth

import (
	"errors"
)

type Credentials struct {
	AccessID string
	Secret   string
}

func (creds Credentials) Validate() (rerr error) {
	if creds.AccessID == "" {
		rerr = errors.Join(rerr, ErrCredentialsMissingAccessID)
	}
	if creds.Secret == "" {
		rerr = errors.Join(rerr, ErrCredentialsMissingSecret)
	}
	return rerr
}

var (
	ErrCredentialsMissingAccessID = errors.New("apiauth: credentials must have an access ID")
	ErrCredentialsMissingSecret   = errors.New("apiauth: credentials must have a secret")
)

A digest.go => digest.go +43 -0
@@ 0,0 1,43 @@
package apiauth

import (
	"crypto/sha1"
	"crypto/sha256"
	"hash"
)

type DigestType string

const (
	DigestSHA1   DigestType = "sha1"
	DigestSHA256 DigestType = "sha256"
)

func (dtyp DigestType) Validate() error {
	if _, err := dtyp.Ident(); err != nil {
		return err
	}
	return nil
}

func (dtyp DigestType) Hasher() (func() hash.Hash, error) {
	switch dtyp {
	case DigestSHA1, "":
		return sha1.New, nil
	case DigestSHA256:
		return sha256.New, nil
	default:
		return nil, &ErrDigestTypeUnknown{dtyp}
	}
}

func (dtyp DigestType) Ident() (string, error) {
	switch dtyp {
	case DigestSHA1, "":
		return "APIAuth", nil
	case DigestSHA256:
		return "APIAuth-HMAC-SHA256", nil
	default:
		return "", &ErrDigestTypeUnknown{dtyp}
	}
}

A errors.go => errors.go +41 -0
@@ 0,0 1,41 @@
package apiauth

import (
	"errors"
	"fmt"
)

var (
	ErrDateRequired       = errors.New("apiauth: date must not be empty")
	ErrDateHeaderRequired = errors.New("apiauth: 'Date' header must be present in request")
	ErrURLRequired        = errors.New("apiauth: url must not be empty")
	ErrMethodRequired     = errors.New("apiauth: method must not be empty")
)

type ErrDateInvalid struct {
	Input string
	Inner error
}

func (err *ErrDateInvalid) Unwrap() error { return err.Inner }
func (err *ErrDateInvalid) Error() string {
	return fmt.Sprintf("apiauth: date %q is not in RFC1123 format: %s", err.Input, err.Inner)
}

type ErrURLInvalid struct {
	Input string
	Inner error
}

func (err *ErrURLInvalid) Unwrap() error { return err.Inner }
func (err *ErrURLInvalid) Error() string {
	return fmt.Sprintf("apiauth: url %q is not valid: %s", err.Input, err.Inner)
}

type ErrDigestTypeUnknown struct {
	DigestType DigestType
}

func (err *ErrDigestTypeUnknown) Error() string {
	return fmt.Sprintf("apiauth: digest type %q is unknown", err.DigestType)
}

M sign.go => sign.go +40 -30
@@ 2,37 2,52 @@ package apiauth

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"time"
)

type Credentials struct {
	AccessID string
	Secret   string
}

const AuthContentHeader = "x-authorization-content-sha256"

type Signer struct {
	Creds Credentials
	Creds      Credentials
	DigestType DigestType
}

func (s Signer) AuthHeaderValue(input SignInput) (string, error) {
	sig, err := input.Signature(s.Creds)
	sig, err := s.SignInput(input)
	if err != nil {
		return "", err
	}

	header := sig.AuthHeaderValue(s.Creds)
	return header, nil
	return sig.AuthHeaderValue(s.Creds)
}

func (s Signer) SignInput(input SignInput) (Signature, error) {
	return s.Sign(input.CanonicalString())
}

func (s Signer) Sign(canonical CanonicalString) (Signature, error) {
	hasher, err := s.DigestType.Hasher()
	if err != nil {
		return Signature{}, err
	}

	mac := hmac.New(hasher, []byte(s.Creds.Secret))
	mac.Write([]byte(canonical))
	digest := mac.Sum(nil)

	signature := base64.StdEncoding.EncodeToString(digest)
	return Signature{
		Value:      signature,
		AccessID:   s.Creds.AccessID,
		DigestType: s.DigestType,
	}, nil
}

func (s Signer) Sign(rq *http.Request) error {
func (s Signer) SignRequest(rq *http.Request) error {
	if dt, coalesced := CoalesceRawDate(rq.Header.Get("Date")); coalesced {
		rq.Header.Set("Date", dt)
	}


@@ 50,7 65,6 @@ func (s Signer) Sign(rq *http.Request) error {
	rq.Header.Set("Authorization", header)

	return nil

}

type SignInput struct {


@@ 65,17 79,21 @@ func SignInputFromRawValues(
	method string,
	contentType string,
	contentHash string,
	rawUrl string,
	rawURL string,
	date string,
) (SignInput, error) {
	if date == "" {
		return SignInput{}, ErrDateRequired
	}

	dt, err := time.Parse(time.RFC1123, date)
	if err != nil {
		return SignInput{}, fmt.Errorf("apiauth: date %q is not in RFC1123 format: %w", date, err)
		return SignInput{}, &ErrDateInvalid{Input: date, Inner: err}
	}

	u, err := url.Parse(rawUrl)
	u, err := url.Parse(rawURL)
	if err != nil {
		return SignInput{}, err
		return SignInput{}, &ErrURLInvalid{Input: rawURL, Inner: err}
	}

	return SignInput{


@@ 97,12 115,12 @@ func SignInputFromRequest(rq *http.Request) (SignInput, error) {

	rawDate := rq.Header.Get("Date")
	if rawDate == "" {
		return SignInput{}, fmt.Errorf("apiauth: Date header must be present in request")
		return SignInput{}, ErrDateHeaderRequired
	}

	dt, err := time.Parse(time.RFC1123, rawDate)
	if err != nil {
		return SignInput{}, fmt.Errorf("apiauth: date %q is not in RFC1123 format: %w", rawDate, err)
		return SignInput{}, &ErrDateInvalid{Input: rawDate, Inner: err}
	}
	si.Date = dt



@@ 111,13 129,13 @@ func SignInputFromRequest(rq *http.Request) (SignInput, error) {

func (si SignInput) Validate() (rerr error) {
	if si.URL == nil {
		rerr = errors.Join(rerr, fmt.Errorf("apiauth: URL is required"))
		rerr = errors.Join(rerr, ErrURLRequired)
	}
	if si.Date.IsZero() {
		rerr = errors.Join(rerr, fmt.Errorf("apiauth: Date is required"))
		rerr = errors.Join(rerr, ErrDateRequired)
	}
	if si.Method == "" {
		rerr = errors.Join(rerr, fmt.Errorf("apiauth: Method is required"))
		rerr = errors.Join(rerr, ErrMethodRequired)
	}
	return rerr
}


@@ 131,11 149,3 @@ func (si SignInput) CanonicalString() CanonicalString {
		FormatDate(si.Date),
	)
}

func (si SignInput) Signature(creds Credentials) (Signature, error) {
	canonical := si.CanonicalString()
	mac := hmac.New(sha1.New, []byte(creds.Secret))
	mac.Write([]byte(canonical))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))
	return Signature(signature), nil
}

M signature.go => signature.go +11 -3
@@ 2,8 2,16 @@ package apiauth

import "fmt"

type Signature string
type Signature struct {
	DigestType DigestType
	AccessID   string
	Value      string
}

func (sig Signature) AuthHeaderValue(creds Credentials) string {
	return fmt.Sprintf("APIAuth %s:%s", creds.AccessID, sig)
func (sig Signature) AuthHeaderValue(creds Credentials) (string, error) {
	ident, err := sig.DigestType.Ident()
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%s %s:%s", ident, creds.AccessID, sig.Value), nil
}

A verify.go => verify.go +98 -0
@@ 0,0 1,98 @@
package apiauth

import (
	"errors"
	"fmt"
	"net/http"
	"regexp"
	"strings"
)

func (s Signer) VerifySignature(canonical CanonicalString, sig Signature) error {
	check, err := s.Sign(canonical)
	if err != nil {
		return fmt.Errorf("apiauth: could not verify signature: %w", err)
	}
	if check != sig {
		return ErrSignatureMismatch
	}
	return nil
}

func (s Signer) VerifyRequest(rq *http.Request) error {
	if _, ok := rq.Header["authorization"]; !ok {
		return ErrAuthorizationHeaderMissing
	}

	auth := rq.Header.Get("authorization")
	sig, err := ParseAuthorizationHeader(auth)
	if err != nil {
		return err
	}

	input, err := SignInputFromRequest(rq)
	if err != nil {
		return err
	}

	canonical := input.CanonicalString()

	return s.VerifySignature(canonical, sig)
}

func ParseAuthorizationHeader(auth string) (Signature, error) {
	match := authHeaderMatcher.FindStringSubmatch(auth)
	if match == nil {
		return Signature{}, &ErrAuthorizationHeaderFormatInvalid{Header: auth}
	}

	value := match[authHeaderValueIdx]
	parts := strings.SplitN(value, ":", 2)
	if len(parts) != 2 {
		return Signature{}, &ErrAuthorizationHeaderFormatInvalid{Header: auth}
	}

	digestType := DigestType(match[authHeaderAlgoIdx])
	if err := digestType.Validate(); err != nil {
		return Signature{}, fmt.Errorf("apiauth: could not create digest from Authorization header value %q: %w", auth, err)
	}

	return Signature{
		DigestType: digestType,
		AccessID:   parts[1],
		Value:      parts[2],
	}, nil
}

var (
	ErrAuthorizationHeaderMissing = errors.New("apiauth: authorization header missing")
	ErrSignatureMismatch          = errors.New("apiauth: verify signature mismatch")
)

type ErrAuthorizationHeaderFormatInvalid struct {
	Header string
}

func (err *ErrAuthorizationHeaderFormatInvalid) Error() string {
	return fmt.Sprintf("apiauth: 'Authorization' header %q is not in the expected format ('APIAuth access-id:signature' or 'APIAuth-HMAC-ALGO access-id:signature')", err.Header)
}

var (
	authHeaderMatcher = regexp.MustCompile(`` +
		`^` +
		`APIAuth` +
		`(-HMAC-(?P<algo>[^ ]+))?` +
		`[ ]+` +
		`(?P<value>.*)` +
		`[ \t]*` +
		`$` +
		``)
	authHeaderAlgoIdx  = authHeaderMatcher.SubexpIndex("algo")
	authHeaderValueIdx = authHeaderMatcher.SubexpIndex("value")
)

func init() {
	if authHeaderAlgoIdx < 0 || authHeaderValueIdx < 0 {
		panic("broken matcher")
	}
}