M db.go => db.go +21 -5
@@ 15,6 15,11 @@ var schema string
var migrations = []string{
"", // migration #0 is reserved for schema initialization
+ `
+ ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
+ ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
+ CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
+ `,
}
var errNoDBRows = sql.ErrNoRows
@@ 210,7 215,7 @@ func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]Auth
SELECT id, client_id, client_name, client_uri, token.expires_at
FROM Client,
(
- SELECT client, MAX(expires_at) as expires_at
+ SELECT client, MAX(COALESCE(refresh_expires_at, expires_at)) as expires_at
FROM AccessToken
WHERE user = ?
GROUP BY client
@@ 255,10 260,21 @@ func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*Acces
return &token, err
}
-func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
+func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
- INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
- VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
+ INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
+ expires_at, refresh_hash, refresh_expires_at)
+ VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
+ :refresh_hash, :refresh_expires_at)
+ ON CONFLICT(id) DO UPDATE SET
+ hash = :hash,
+ user = :user,
+ client = :client,
+ scope = :scope,
+ issued_at = :issued_at,
+ expires_at = :expires_at,
+ refresh_hash = :refresh_hash,
+ refresh_expires_at = :refresh_expires_at
RETURNING id
`, entityArgs(token)...).Scan(&token.ID)
}
@@ 301,7 317,7 @@ func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, err
func (db *DB) Maintain(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
- WHERE timediff('now', expires_at) > 0
+ WHERE timediff('now', COALESCE(refresh_expires_at, expires_at)) > 0
`)
if err != nil {
return err
M entity.go => entity.go +67 -41
@@ 8,6 8,7 @@ import (
"database/sql/driver"
"encoding/base64"
"fmt"
+ "reflect"
"strconv"
"strings"
"time"
@@ 16,8 17,9 @@ import (
)
const (
- accessTokenExpiration = 30 * 24 * time.Hour
- authCodeExpiration = 10 * time.Minute
+ accessTokenExpiration = 30 * 24 * time.Hour
+ refreshTokenExpiration = 2 * accessTokenExpiration
+ authCodeExpiration = 10 * time.Minute
)
type entity interface {
@@ 67,32 69,37 @@ func (id ID[T]) Value() (driver.Value, error) {
}
}
-type nullString string
+type nullValue struct {
+ ptr interface{}
+}
var (
- _ sql.Scanner = (*nullString)(nil)
- _ driver.Valuer = (*nullString)(nil)
+ _ sql.Scanner = nullValue{nil}
+ _ driver.Valuer = nullValue{nil}
)
-func (ptr *nullString) Scan(v interface{}) error {
+func (nv nullValue) Scan(v interface{}) error {
+ out := reflect.ValueOf(nv.ptr).Elem()
if v == nil {
- *ptr = ""
+ out.SetZero()
return nil
}
- s, ok := v.(string)
- if !ok {
- return fmt.Errorf("cannot scan nullStringPtr from %T", v)
+
+ rv := reflect.ValueOf(v)
+ if rv.Type() != out.Type() {
+ return fmt.Errorf("cannot scan %v into %v", rv.Type(), out.Type())
}
- *ptr = nullString(s)
+
+ out.Set(rv)
return nil
}
-func (ptr *nullString) Value() (driver.Value, error) {
- if *ptr == "" {
+func (nv nullValue) Value() (driver.Value, error) {
+ in := reflect.ValueOf(nv.ptr).Elem()
+ if in.IsZero() {
return nil, nil
- } else {
- return string(*ptr), nil
}
+ return in.Interface(), nil
}
type User struct {
@@ 106,7 113,7 @@ func (user *User) columns() map[string]interface{} {
return map[string]interface{}{
"id": &user.ID,
"username": &user.Username,
- "password_hash": (*nullString)(&user.PasswordHash),
+ "password_hash": nullValue{&user.PasswordHash},
"admin": &user.Admin,
}
}
@@ 164,9 171,9 @@ func (client *Client) columns() map[string]interface{} {
"client_id": &client.ClientID,
"client_secret_hash": &client.ClientSecretHash,
"owner": &client.Owner,
- "redirect_uris": (*nullString)(&client.RedirectURIs),
- "client_name": (*nullString)(&client.ClientName),
- "client_uri": (*nullString)(&client.ClientURI),
+ "redirect_uris": nullValue{&client.RedirectURIs},
+ "client_name": nullValue{&client.ClientName},
+ "client_uri": nullValue{&client.ClientURI},
}
}
@@ 186,6 193,9 @@ type AccessToken struct {
Scope string
IssuedAt time.Time
ExpiresAt time.Time
+
+ RefreshHash []byte
+ RefreshExpiresAt time.Time
}
func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
@@ 199,25 209,35 @@ func (token *AccessToken) Generate(expiration time.Duration) (secret string, err
return secret, nil
}
-func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
- token = &AccessToken{
+func (token *AccessToken) GenerateRefresh() (secret string, err error) {
+ secret, hash, err := generateSecret()
+ if err != nil {
+ return "", fmt.Errorf("failed to generate refresh token secret: %v", err)
+ }
+ token.RefreshHash = hash
+ token.RefreshExpiresAt = time.Now().Add(refreshTokenExpiration)
+ return secret, nil
+}
+
+func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
+ return &AccessToken{
User: authCode.User,
Client: authCode.Client,
Scope: authCode.Scope,
}
- secret, err = token.Generate(accessTokenExpiration)
- return token, secret, err
}
func (token *AccessToken) columns() map[string]interface{} {
return map[string]interface{}{
- "id": &token.ID,
- "hash": &token.Hash,
- "user": &token.User,
- "client": &token.Client,
- "scope": (*nullString)(&token.Scope),
- "issued_at": &token.IssuedAt,
- "expires_at": &token.ExpiresAt,
+ "id": &token.ID,
+ "hash": &token.Hash,
+ "user": &token.User,
+ "client": &token.Client,
+ "scope": nullValue{&token.Scope},
+ "issued_at": &token.IssuedAt,
+ "expires_at": &token.ExpiresAt,
+ "refresh_hash": &token.RefreshHash,
+ "refresh_expires_at": nullValue{&token.RefreshExpiresAt},
}
}
@@ 225,6 245,10 @@ func (token *AccessToken) VerifySecret(secret string) bool {
return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}
+func (token *AccessToken) VerifyRefreshSecret(secret string) bool {
+ return verifyHash(token.RefreshHash, secret) && verifyExpiration(token.RefreshExpiresAt)
+}
+
type AuthorizedClient struct {
Client Client
ExpiresAt time.Time
@@ 257,8 281,8 @@ func (code *AuthCode) columns() map[string]interface{} {
"created_at": &code.CreatedAt,
"user": &code.User,
"client": &code.Client,
- "scope": (*nullString)(&code.Scope),
- "redirect_uri": (*nullString)(&code.RedirectURI),
+ "scope": nullValue{&code.Scope},
+ "redirect_uri": nullValue{&code.RedirectURI},
}
}
@@ 269,8 293,9 @@ func (code *AuthCode) VerifySecret(secret string) bool {
type SecretKind byte
const (
- SecretKindAccessToken = SecretKind('a')
- SecretKindAuthCode = SecretKind('c')
+ SecretKindAccessToken = SecretKind('a')
+ SecretKindRefreshToken = SecretKind('r')
+ SecretKindAuthCode = SecretKind('c')
)
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
@@ 281,7 306,7 @@ func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
}
switch SecretKind(kind[0]) {
- case SecretKindAccessToken:
+ case SecretKindAccessToken, SecretKindRefreshToken:
_, ok = interface{}(id).(ID[*AccessToken])
case SecretKindAuthCode:
_, ok = interface{}(id).(ID[*AuthCode])
@@ 294,19 319,20 @@ func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
return id, secret, err
}
-func MarshalSecret[T entity](id ID[T], secret string) string {
+func MarshalSecret[T entity](id ID[T], kind SecretKind, secret string) string {
if id == 0 {
panic("cannot marshal zero ID")
}
- var kind SecretKind
+ var ok bool
switch interface{}(id).(type) {
case ID[*AccessToken]:
- kind = SecretKindAccessToken
+ ok = kind == SecretKindAccessToken || kind == SecretKindRefreshToken
case ID[*AuthCode]:
- kind = SecretKindAuthCode
- default:
- panic(fmt.Sprintf("unsupported secret kind for ID type %T", id))
+ ok = kind == SecretKindAuthCode
+ }
+ if !ok {
+ panic(fmt.Sprintf("unsupported secret kind %q for ID type %T", string(kind), id))
}
return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
M middleware.go => middleware.go +1 -1
@@ 49,7 49,7 @@ func newBaseContext(db *DB, tpl *Template) context.Context {
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
http.SetCookie(w, &http.Cookie{
Name: loginCookieName,
- Value: MarshalSecret(token.ID, secret),
+ Value: MarshalSecret(token.ID, SecretKindAccessToken, secret),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: isForwardedHTTPS(req),
M oauth2.go => oauth2.go +110 -56
@@ 174,7 174,7 @@ func authorize(w http.ResponseWriter, req *http.Request) {
return
}
- code := MarshalSecret(authCode.ID, secret)
+ code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)
values := make(url.Values)
values.Set("code", code)
@@ 200,16 200,9 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
- redirectURI := values.Get("redirect_uri")
authClientID, clientSecret, _ := req.BasicAuth()
- if clientID == "" && authClientID == "" {
- oauthError(w, &oauth2.Error{
- Code: oauth2.ErrorCodeInvalidRequest,
- Description: "Missing client ID",
- })
- return
- } else if clientID == "" {
+ if clientID == "" {
clientID = authClientID
} else if clientID != authClientID {
oauthError(w, &oauth2.Error{
@@ 219,20 212,21 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
return
}
- client, err := db.FetchClientByClientID(ctx, clientID)
- if err == errNoDBRows {
- oauthError(w, &oauth2.Error{
- Code: oauth2.ErrorCodeInvalidClient,
- Description: "Invalid client ID",
- })
- return
- } else if err != nil {
- oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
- return
- }
+ var client *Client
+ if clientID != "" {
+ client, err = db.FetchClientByClientID(ctx, clientID)
+ if err == errNoDBRows {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidClient,
+ Description: "Invalid client ID",
+ })
+ return
+ } else if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
+ return
+ }
- if !client.IsPublic() {
- if !client.VerifySecret(clientSecret) {
+ if !client.IsPublic() && !client.VerifySecret(clientSecret) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
@@ 241,49 235,108 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
}
}
- if grantType != oauth2.GrantTypeAuthorizationCode {
+ var token *AccessToken
+ switch grantType {
+ case oauth2.GrantTypeAuthorizationCode:
+ if client == nil {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "Missing client ID",
+ })
+ return
+ }
+
+ codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
+ authCode, err := db.PopAuthCode(ctx, codeID)
+ if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid authorization code",
+ })
+ return
+ } else if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
+ return
+ }
+
+ if scope != authCode.Scope {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid scope",
+ })
+ return
+ }
+ if values.Get("redirect_uri") != authCode.RedirectURI {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid redirect URI",
+ })
+ return
+ }
+
+ token = NewAccessTokenFromAuthCode(authCode)
+ case oauth2.GrantTypeRefreshToken:
+ tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
+ token, err = db.FetchAccessToken(ctx, tokenID)
+ if err == errNoDBRows || (err == nil && !token.VerifyRefreshSecret(refreshSecret)) {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid refresh token",
+ })
+ return
+ } else if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
+ return
+ }
+
+ if client != nil && client.ID != token.Client {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid refresh token",
+ })
+ return
+ }
+
+ tokenClient, err := db.FetchClient(ctx, token.Client)
+ if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
+ return
+ }
+
+ if !tokenClient.IsPublic() && client == nil {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid client secret",
+ })
+ return
+ }
+
+ if scope != token.Scope {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeAccessDenied,
+ Description: "Invalid scope",
+ })
+ return
+ }
+ default:
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedGrantType,
Description: "Unsupported grant type",
})
- return
}
- codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
- authCode, err := db.PopAuthCode(ctx, codeID)
- if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
- oauthError(w, &oauth2.Error{
- Code: oauth2.ErrorCodeAccessDenied,
- Description: "Invalid authorization code",
- })
- return
- } else if err != nil {
- oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
- return
- }
-
- if scope != authCode.Scope {
- oauthError(w, &oauth2.Error{
- Code: oauth2.ErrorCodeAccessDenied,
- Description: "Invalid scope",
- })
- return
- }
- if redirectURI != authCode.RedirectURI {
- oauthError(w, &oauth2.Error{
- Code: oauth2.ErrorCodeAccessDenied,
- Description: "Invalid redirect URI",
- })
+ secret, err := token.Generate(accessTokenExpiration)
+ if err != nil {
+ oauthError(w, err)
return
}
-
- token, secret, err := NewAccessTokenFromAuthCode(authCode)
+ refreshSecret, err := token.GenerateRefresh()
if err != nil {
oauthError(w, err)
return
}
- if err := db.CreateAccessToken(ctx, token); err != nil {
+ if err := db.StoreAccessToken(ctx, token); err != nil {
oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
@@ 291,10 344,11 @@ func exchangeToken(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(&oauth2.TokenResp{
- AccessToken: MarshalSecret(token.ID, secret),
- TokenType: oauth2.TokenTypeBearer,
- ExpiresIn: time.Until(token.ExpiresAt),
- Scope: strings.Split(token.Scope, " "),
+ AccessToken: MarshalSecret(token.ID, SecretKindAccessToken, secret),
+ TokenType: oauth2.TokenTypeBearer,
+ ExpiresIn: time.Until(token.ExpiresAt),
+ Scope: strings.Split(token.Scope, " "),
+ RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
})
}
M schema.sql => schema.sql +2 -0
@@ 24,6 24,8 @@ CREATE TABLE AccessToken (
scope TEXT,
issued_at datetime NOT NULL,
expires_at datetime NOT NULL,
+ refresh_hash BLOB UNIQUE,
+ refresh_expires_at datetime,
FOREIGN KEY(user) REFERENCES User(id),
FOREIGN KEY(client) REFERENCES Client(id)
);
M user.go => user.go +1 -1
@@ 127,7 127,7 @@ func login(w http.ResponseWriter, req *http.Request) {
httpError(w, fmt.Errorf("failed to generate access token: %v", err))
return
}
- if err := db.CreateAccessToken(ctx, &token); err != nil {
+ if err := db.StoreAccessToken(ctx, &token); err != nil {
httpError(w, fmt.Errorf("failed to create access token: %v", err))
return
}