~emersion/soju

2d34baa439c1889501196d5dbf1ce125253a55c9 — Simon Ser 4 months ago dcfe206 external-auth
Implement HTTP basic external auth
5 files changed, 84 insertions(+), 28 deletions(-)

M cmd/soju/main.go
M config/config.go
M db.go
M downstream.go
M server.go
M cmd/soju/main.go => cmd/soju/main.go +3 -0
@@ 59,6 59,9 @@ func main() {
	srv.LogPath = cfg.LogPath
	srv.HTTPOrigins = cfg.HTTPOrigins
	srv.Debug = debug
	if cfg.ExternalAuth != nil {
		srv.ExternalAuthURL = cfg.ExternalAuth.URL
	}

	for _, listen := range cfg.Listen {
		listenURI := listen

M config/config.go => config/config.go +19 -7
@@ 13,14 13,20 @@ type TLS struct {
	CertPath, KeyPath string
}

type ExternalAuth struct {
	Mechanism string
	URL       string
}

type Server struct {
	Listen      []string
	Hostname    string
	TLS         *TLS
	SQLDriver   string
	SQLSource   string
	LogPath     string
	HTTPOrigins []string
	Listen       []string
	Hostname     string
	TLS          *TLS
	SQLDriver    string
	SQLSource    string
	LogPath      string
	HTTPOrigins  []string
	ExternalAuth *ExternalAuth
}

func Defaults() *Server {


@@ 93,6 99,12 @@ func Parse(r io.Reader) (*Server, error) {
			}
		case "http-origin":
			srv.HTTPOrigins = append(srv.HTTPOrigins, d.Params...)
		case "external-auth":
			var mech, url string
			if err := d.parseParams(&mech, &url); err != nil {
				return nil, err
			}
			srv.ExternalAuth = &ExternalAuth{mech, url}
		default:
			return nil, fmt.Errorf("unknown directive %q", d.Name)
		}

M db.go => db.go +5 -1
@@ 9,6 9,8 @@ import (
	_ "github.com/mattn/go-sqlite3"
)

var ErrNoSuchUser = fmt.Errorf("soju: no such user")

type User struct {
	Created  bool
	Username string


@@ 222,7 224,9 @@ func (db *DB) GetUser(username string) (*User, error) {

	var password *string
	row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
	if err := row.Scan(&password, &user.Admin); err != nil {
	if err := row.Scan(&password, &user.Admin); err == sql.ErrNoRows {
		return nil, ErrNoSuchUser
	} else if err != nil {
		return nil, err
	}
	user.Password = fromStringPtr(password)

M downstream.go => downstream.go +49 -13
@@ 6,6 6,7 @@ import (
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"


@@ 683,21 684,56 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
func (dc *downstreamConn) authenticate(username, password string) error {
	username, clientName, networkName := unmarshalUsername(username)

	u, err := dc.srv.db.GetUser(username)
	if err != nil {
		dc.logger.Printf("failed authentication for %q: %v", username, err)
		return errAuthFailed
	}
	if dc.srv.ExternalAuthURL != "" {
		if !strings.HasPrefix(dc.srv.ExternalAuthURL, "https://") {
			dc.logger.Printf("failed authentication for %q: unsupported external auth URL", username)
			return errAuthFailed
		}
		req, err := http.NewRequest(http.MethodGet, dc.srv.ExternalAuthURL, nil)
		if err != nil {
			dc.logger.Printf("failed authentication for %q: failed to create HTTP request: %v", username, err)
			return errAuthFailed
		}
		req.SetBasicAuth(username, password)
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			dc.logger.Printf("failed authentication for %q: failed to send HTTP request: %v", username, err)
			return errAuthFailed
		}
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			dc.logger.Printf("failed authentication for %q: HTTP error: %v", username, resp.Status)
			return errAuthFailed
		}

	// Password auth disabled
	if u.Password == "" {
		return errAuthFailed
	}
		// Insert user in DB on first login
		if _, err := dc.srv.db.GetUser(username); err == ErrNoSuchUser {
			u := User{Username: username, Password: "!!"}
			if _, err := dc.srv.createUser(&u); err != nil {
				dc.logger.Printf("failed to auto-create user %q: %v", username, err)
				return errAuthFailed
			}
		} else if err != nil {
			dc.logger.Printf("failed authentication for %q: %v", username, err)
			return errAuthFailed
		}
	} else {
		u, err := dc.srv.db.GetUser(username)
		if err != nil {
			dc.logger.Printf("failed authentication for %q: %v", username, err)
			return errAuthFailed
		}

	err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
	if err != nil {
		dc.logger.Printf("failed authentication for %q: %v", username, err)
		return errAuthFailed
		if u.Password == "" {
			dc.logger.Printf("failed authentication for %q: no password in DB", username)
			return errAuthFailed
		}

		err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
		if err != nil {
			dc.logger.Printf("failed authentication for %q: %v", username, err)
			return errAuthFailed
		}
	}

	dc.user = dc.srv.getUser(username)

M server.go => server.go +8 -7
@@ 41,13 41,14 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
}

type Server struct {
	Hostname     string
	Logger       Logger
	RingCap      int
	HistoryLimit int
	LogPath      string
	Debug        bool
	HTTPOrigins  []string
	Hostname        string
	Logger          Logger
	RingCap         int
	HistoryLimit    int
	LogPath         string
	Debug           bool
	HTTPOrigins     []string
	ExternalAuthURL string

	db *DB