M cmd/soju/main.go => cmd/soju/main.go +3 -0
@@ 59,6 59,9 @@ func main() {
srv.LogPath = cfg.LogPath
srv.HTTPOrigins = cfg.HTTPOrigins
srv.Debug = debug
+ if cfg.ExternalAuth != nil {
+ srv.ExternalAuthURL = cfg.ExternalAuth.URL
+ }
for _, listen := range cfg.Listen {
listenURI := listen
M config/config.go => config/config.go +19 -7
@@ 13,14 13,20 @@ type TLS struct {
CertPath, KeyPath string
}
+type ExternalAuth struct {
+ Mechanism string
+ URL string
+}
+
type Server struct {
- Listen []string
- Hostname string
- TLS *TLS
- SQLDriver string
- SQLSource string
- LogPath string
- HTTPOrigins []string
+ Listen []string
+ Hostname string
+ TLS *TLS
+ SQLDriver string
+ SQLSource string
+ LogPath string
+ HTTPOrigins []string
+ ExternalAuth *ExternalAuth
}
func Defaults() *Server {
@@ 93,6 99,12 @@ func Parse(r io.Reader) (*Server, error) {
}
case "http-origin":
srv.HTTPOrigins = append(srv.HTTPOrigins, d.Params...)
+ case "external-auth":
+ var mech, url string
+ if err := d.parseParams(&mech, &url); err != nil {
+ return nil, err
+ }
+ srv.ExternalAuth = &ExternalAuth{mech, url}
default:
return nil, fmt.Errorf("unknown directive %q", d.Name)
}
M db.go => db.go +5 -1
@@ 9,6 9,8 @@ import (
_ "github.com/mattn/go-sqlite3"
)
+var ErrNoSuchUser = fmt.Errorf("soju: no such user")
+
type User struct {
Created bool
Username string
@@ 222,7 224,9 @@ func (db *DB) GetUser(username string) (*User, error) {
var password *string
row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
- if err := row.Scan(&password, &user.Admin); err != nil {
+ if err := row.Scan(&password, &user.Admin); err == sql.ErrNoRows {
+ return nil, ErrNoSuchUser
+ } else if err != nil {
return nil, err
}
user.Password = fromStringPtr(password)
M downstream.go => downstream.go +49 -13
@@ 6,6 6,7 @@ import (
"fmt"
"io"
"net"
+ "net/http"
"strconv"
"strings"
"time"
@@ 683,21 684,56 @@ func unmarshalUsername(rawUsername string) (username, client, network string) {
func (dc *downstreamConn) authenticate(username, password string) error {
username, clientName, networkName := unmarshalUsername(username)
- u, err := dc.srv.db.GetUser(username)
- if err != nil {
- dc.logger.Printf("failed authentication for %q: %v", username, err)
- return errAuthFailed
- }
+ if dc.srv.ExternalAuthURL != "" {
+ if !strings.HasPrefix(dc.srv.ExternalAuthURL, "https://") {
+ dc.logger.Printf("failed authentication for %q: unsupported external auth URL", username)
+ return errAuthFailed
+ }
+ req, err := http.NewRequest(http.MethodGet, dc.srv.ExternalAuthURL, nil)
+ if err != nil {
+ dc.logger.Printf("failed authentication for %q: failed to create HTTP request: %v", username, err)
+ return errAuthFailed
+ }
+ req.SetBasicAuth(username, password)
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ dc.logger.Printf("failed authentication for %q: failed to send HTTP request: %v", username, err)
+ return errAuthFailed
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ dc.logger.Printf("failed authentication for %q: HTTP error: %v", username, resp.Status)
+ return errAuthFailed
+ }
- // Password auth disabled
- if u.Password == "" {
- return errAuthFailed
- }
+ // Insert user in DB on first login
+ if _, err := dc.srv.db.GetUser(username); err == ErrNoSuchUser {
+ u := User{Username: username, Password: "!!"}
+ if _, err := dc.srv.createUser(&u); err != nil {
+ dc.logger.Printf("failed to auto-create user %q: %v", username, err)
+ return errAuthFailed
+ }
+ } else if err != nil {
+ dc.logger.Printf("failed authentication for %q: %v", username, err)
+ return errAuthFailed
+ }
+ } else {
+ u, err := dc.srv.db.GetUser(username)
+ if err != nil {
+ dc.logger.Printf("failed authentication for %q: %v", username, err)
+ return errAuthFailed
+ }
- err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
- if err != nil {
- dc.logger.Printf("failed authentication for %q: %v", username, err)
- return errAuthFailed
+ if u.Password == "" {
+ dc.logger.Printf("failed authentication for %q: no password in DB", username)
+ return errAuthFailed
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
+ if err != nil {
+ dc.logger.Printf("failed authentication for %q: %v", username, err)
+ return errAuthFailed
+ }
}
dc.user = dc.srv.getUser(username)
M server.go => server.go +8 -7
@@ 41,13 41,14 @@ func (l *prefixLogger) Printf(format string, v ...interface{}) {
}
type Server struct {
- Hostname string
- Logger Logger
- RingCap int
- HistoryLimit int
- LogPath string
- Debug bool
- HTTPOrigins []string
+ Hostname string
+ Logger Logger
+ RingCap int
+ HistoryLimit int
+ LogPath string
+ Debug bool
+ HTTPOrigins []string
+ ExternalAuthURL string
db *DB