~emersion/sinwon

476a35955f666ccb1d01fdacb0a52b80115bf214 — Simon Ser 7 months ago d9c1ec8
Add support for refresh tokens

Closes: https://todo.sr.ht/~emersion/sinwon/21
6 files changed, 202 insertions(+), 104 deletions(-)

M db.go
M entity.go
M middleware.go
M oauth2.go
M schema.sql
M user.go
M db.go => db.go +21 -5
@@ 15,6 15,11 @@ var schema string

var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	`
		ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
		ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
		CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
	`,
}

var errNoDBRows = sql.ErrNoRows


@@ 210,7 215,7 @@ func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]Auth
		SELECT id, client_id, client_name, client_uri, token.expires_at
		FROM Client,
		(
			SELECT client, MAX(expires_at) as expires_at
			SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
			FROM AccessToken
			WHERE user = ?
			GROUP BY client


@@ 255,10 260,21 @@ func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*Acces
	return &token, err
}

func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
	return db.db.QueryRowContext(ctx, `
		INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
		VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
		INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
			expires_at, refresh_hash, refresh_expires_at)
		VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
			:refresh_hash, :refresh_expires_at)
		ON CONFLICT(id) DO UPDATE SET
			hash = :hash,
			user = :user,
			client = :client,
			scope = :scope,
			issued_at = :issued_at,
			expires_at = :expires_at,
			refresh_hash = :refresh_hash,
			refresh_expires_at = :refresh_expires_at
		RETURNING id
	`, entityArgs(token)...).Scan(&token.ID)
}


@@ 301,7 317,7 @@ func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, err
func (db *DB) Maintain(ctx context.Context) error {
	_, err := db.db.ExecContext(ctx, `
		DELETE FROM AccessToken
		WHERE timediff('now', expires_at) > 0
		WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
	`)
	if err != nil {
		return err

M entity.go => entity.go +67 -41
@@ 8,6 8,7 @@ import (
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"


@@ 16,8 17,9 @@ import (
)

const (
	accessTokenExpiration = 30 * 24 * time.Hour
	authCodeExpiration    = 10 * time.Minute
	accessTokenExpiration  = 30 * 24 * time.Hour
	refreshTokenExpiration = 2 * accessTokenExpiration
	authCodeExpiration     = 10 * time.Minute
)

type entity interface {


@@ 67,32 69,37 @@ func (id ID[T]) Value() (driver.Value, error) {
	}
}

type nullString string
type nullValue struct {
	ptr interface{}
}

var (
	_ sql.Scanner   = (*nullString)(nil)
	_ driver.Valuer = (*nullString)(nil)
	_ sql.Scanner   = nullValue{nil}
	_ driver.Valuer = nullValue{nil}
)

func (ptr *nullString) Scan(v interface{}) error {
func (nv nullValue) Scan(v interface{}) error {
	out := reflect.ValueOf(nv.ptr).Elem()
	if v == nil {
		*ptr = ""
		out.SetZero()
		return nil
	}
	s, ok := v.(string)
	if !ok {
		return fmt.Errorf("cannot scan nullStringPtr from %T", v)

	rv := reflect.ValueOf(v)
	if rv.Type() != out.Type() {
		return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
	}
	*ptr = nullString(s)

	out.Set(rv)
	return nil
}

func (ptr *nullString) Value() (driver.Value, error) {
	if *ptr == "" {
func (nv nullValue) Value() (driver.Value, error) {
	in := reflect.ValueOf(nv.ptr).Elem()
	if in.IsZero() {
		return nil, nil
	} else {
		return string(*ptr), nil
	}
	return in.Interface(), nil
}

type User struct {


@@ 106,7 113,7 @@ func (user *User) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":            &user.ID,
		"username":      &user.Username,
		"password_hash": (*nullString)(&user.PasswordHash),
		"password_hash": nullValue{&user.PasswordHash},
		"admin":         &user.Admin,
	}
}


@@ 164,9 171,9 @@ func (client *Client) columns() map[string]interface{} {
		"client_id":          &client.ClientID,
		"client_secret_hash": &client.ClientSecretHash,
		"owner":              &client.Owner,
		"redirect_uris":      (*nullString)(&client.RedirectURIs),
		"client_name":        (*nullString)(&client.ClientName),
		"client_uri":         (*nullString)(&client.ClientURI),
		"redirect_uris":      nullValue{&client.RedirectURIs},
		"client_name":        nullValue{&client.ClientName},
		"client_uri":         nullValue{&client.ClientURI},
	}
}



@@ 186,6 193,9 @@ type AccessToken struct {
	Scope     string
	IssuedAt  time.Time
	ExpiresAt time.Time

	RefreshHash      []byte
	RefreshExpiresAt time.Time
}

func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {


@@ 199,25 209,35 @@ func (token *AccessToken) Generate(expiration time.Duration) (secret string, err
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
	token = &AccessToken{
func (token *AccessToken) GenerateRefresh() (secret string, err error) {
	secret, hash, err := generateSecret()
	if err != nil {
		return "", fmt.Errorf("failed to generate refresh token secret: %v", err)
	}
	token.RefreshHash = hash
	token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
	return secret, nil
}

func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
	return &AccessToken{
		User:   authCode.User,
		Client: authCode.Client,
		Scope:  authCode.Scope,
	}
	secret, err = token.Generate(accessTokenExpiration)
	return token, secret, err
}

func (token *AccessToken) columns() map[string]interface{} {
	return map[string]interface{}{
		"id":         &token.ID,
		"hash":       &token.Hash,
		"user":       &token.User,
		"client":     &token.Client,
		"scope":      (*nullString)(&token.Scope),
		"issued_at":  &token.IssuedAt,
		"expires_at": &token.ExpiresAt,
		"id":                 &token.ID,
		"hash":               &token.Hash,
		"user":               &token.User,
		"client":             &token.Client,
		"scope":              nullValue{&token.Scope},
		"issued_at":          &token.IssuedAt,
		"expires_at":         &token.ExpiresAt,
		"refresh_hash":       &token.RefreshHash,
		"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
	}
}



@@ 225,6 245,10 @@ func (token *AccessToken) VerifySecret(secret string) bool {
	return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}

func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
	return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
}

type AuthorizedClient struct {
	Client    Client
	ExpiresAt time.Time


@@ 257,8 281,8 @@ func (code *AuthCode) columns() map[string]interface{} {
		"created_at":   &code.CreatedAt,
		"user":         &code.User,
		"client":       &code.Client,
		"scope":        (*nullString)(&code.Scope),
		"redirect_uri": (*nullString)(&code.RedirectURI),
		"scope":        nullValue{&code.Scope},
		"redirect_uri": nullValue{&code.RedirectURI},
	}
}



@@ 269,8 293,9 @@ func (code *AuthCode) VerifySecret(secret string) bool {
type SecretKind byte

const (
	SecretKindAccessToken = SecretKind('a')
	SecretKindAuthCode    = SecretKind('c')
	SecretKindAccessToken  = SecretKind('a')
	SecretKindRefreshToken = SecretKind('r')
	SecretKindAuthCode     = SecretKind('c')
)

func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {


@@ 281,7 306,7 @@ func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	}

	switch SecretKind(kind[0]) {
	case SecretKindAccessToken:
	case SecretKindAccessToken, SecretKindRefreshToken:
		_, ok = interface{}(id).(ID[*AccessToken])
	case SecretKindAuthCode:
		_, ok = interface{}(id).(ID[*AuthCode])


@@ 294,19 319,20 @@ func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
	return id, secret, err
}

func MarshalSecret[T entity](id ID[T], secret string) string {
func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
	if id == 0 {
		panic("cannot marshal zero ID")
	}

	var kind SecretKind
	var ok bool
	switch interface{}(id).(type) {
	case ID[*AccessToken]:
		kind = SecretKindAccessToken
		ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
	case ID[*AuthCode]:
		kind = SecretKindAuthCode
	default:
		panic(fmt.Sprintf("unsupported secret kind for ID type %T", id))
		ok = kind == SecretKindAuthCode
	}
	if !ok {
		panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
	}

	return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)

M middleware.go => middleware.go +1 -1
@@ 49,7 49,7 @@ func newBaseContext(db *DB, tpl *Template) context.Context {
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
	http.SetCookie(w, &http.Cookie{
		Name:     loginCookieName,
		Value:    MarshalSecret(token.ID, secret),
		Value:    MarshalSecret(token.ID, SecretKindAccessToken, secret),
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Secure:   isForwardedHTTPS(req),

M oauth2.go => oauth2.go +110 -56
@@ 174,7 174,7 @@ func authorize(w http.ResponseWriter, req *http.Request) {
		return
	}

	code := MarshalSecret(authCode.ID, secret)
	code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)

	values := make(url.Values)
	values.Set("code", code)


@@ 200,16 200,9 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
	clientID := values.Get("client_id")
	grantType := oauth2.GrantType(values.Get("grant_type"))
	scope := values.Get("scope")
	redirectURI := values.Get("redirect_uri")

	authClientID, clientSecret, _ := req.BasicAuth()
	if clientID == "" && authClientID == "" {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidRequest,
			Description: "Missing client ID",
		})
		return
	} else if clientID == "" {
	if clientID == "" {
		clientID = authClientID
	} else if clientID != authClientID {
		oauthError(w, &oauth2.Error{


@@ 219,20 212,21 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
		return
	}

	client, err := db.FetchClientByClientID(ctx, clientID)
	if err == errNoDBRows {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeInvalidClient,
			Description: "Invalid client ID",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
		return
	}
	var client *Client
	if clientID != "" {
		client, err = db.FetchClientByClientID(ctx, clientID)
		if err == errNoDBRows {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidClient,
				Description: "Invalid client ID",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

	if !client.IsPublic() {
		if !client.VerifySecret(clientSecret) {
		if !client.IsPublic() && !client.VerifySecret(clientSecret) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",


@@ 241,49 235,108 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
		}
	}

	if grantType != oauth2.GrantTypeAuthorizationCode {
	var token *AccessToken
	switch grantType {
	case oauth2.GrantTypeAuthorizationCode:
		if client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeInvalidRequest,
				Description: "Missing client ID",
			})
			return
		}

		codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
		authCode, err := db.PopAuthCode(ctx, codeID)
		if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid authorization code",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
			return
		}

		if scope != authCode.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
		if values.Get("redirect_uri") != authCode.RedirectURI {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid redirect URI",
			})
			return
		}

		token = NewAccessTokenFromAuthCode(authCode)
	case oauth2.GrantTypeRefreshToken:
		tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
		token, err = db.FetchAccessToken(ctx, tokenID)
		if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		} else if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
			return
		}

		if client != nil && client.ID != token.Client {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid refresh token",
			})
			return
		}

		tokenClient, err := db.FetchClient(ctx, token.Client)
		if err != nil {
			oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
			return
		}

		if !tokenClient.IsPublic() && client == nil {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid client secret",
			})
			return
		}

		if scope != token.Scope {
			oauthError(w, &oauth2.Error{
				Code:        oauth2.ErrorCodeAccessDenied,
				Description: "Invalid scope",
			})
			return
		}
	default:
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeUnsupportedGrantType,
			Description: "Unsupported grant type",
		})
		return
	}

	codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
	authCode, err := db.PopAuthCode(ctx, codeID)
	if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid authorization code",
		})
		return
	} else if err != nil {
		oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
		return
	}

	if scope != authCode.Scope {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid scope",
		})
		return
	}
	if redirectURI != authCode.RedirectURI {
		oauthError(w, &oauth2.Error{
			Code:        oauth2.ErrorCodeAccessDenied,
			Description: "Invalid redirect URI",
		})
	secret, err := token.Generate(accessTokenExpiration)
	if err != nil {
		oauthError(w, err)
		return
	}

	token, secret, err := NewAccessTokenFromAuthCode(authCode)
	refreshSecret, err := token.GenerateRefresh()
	if err != nil {
		oauthError(w, err)
		return
	}

	if err := db.CreateAccessToken(ctx, token); err != nil {
	if err := db.StoreAccessToken(ctx, token); err != nil {
		oauthError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}


@@ 291,10 344,11 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Cache-Control", "no-store")
	json.NewEncoder(w).Encode(&oauth2.TokenResp{
		AccessToken: MarshalSecret(token.ID, secret),
		TokenType:   oauth2.TokenTypeBearer,
		ExpiresIn:   time.Until(token.ExpiresAt),
		Scope:       strings.Split(token.Scope, " "),
		AccessToken:  MarshalSecret(token.ID, SecretKindAccessToken, secret),
		TokenType:    oauth2.TokenTypeBearer,
		ExpiresIn:    time.Until(token.ExpiresAt),
		Scope:        strings.Split(token.Scope, " "),
		RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
	})
}


M schema.sql => schema.sql +2 -0
@@ 24,6 24,8 @@ CREATE TABLE AccessToken (
	scope TEXT,
	issued_at datetime NOT NULL,
	expires_at datetime NOT NULL,
	refresh_hash BLOB UNIQUE,
	refresh_expires_at datetime,
	FOREIGN KEY(user) REFERENCES User(id),
	FOREIGN KEY(client) REFERENCES Client(id)
);

M user.go => user.go +1 -1
@@ 127,7 127,7 @@ func login(w http.ResponseWriter, req *http.Request) {
		httpError(w, fmt.Errorf("failed to generate access token: %v", err))
		return
	}
	if err := db.CreateAccessToken(ctx, &token); err != nil {
	if err := db.StoreAccessToken(ctx, &token); err != nil {
		httpError(w, fmt.Errorf("failed to create access token: %v", err))
		return
	}