~elektito/hodhod

f33ad68bdbd10fffbf36832430e1d92026377ff9 — Mostafa Razavi 1 year, 3 months ago 645dc91
Fix CGI

Requests were not being passed to CGI scripts
5 files changed, 74 insertions(+), 12 deletions(-)

M main.go
M pkg/hodhod/cgi.go
M pkg/hodhod/error_resp.go
M pkg/hodhod/response.go
M pkg/hodhod/static.go
M main.go => main.go +17 -2
@@ 4,6 4,7 @@ import (
	"bufio"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"flag"
	"fmt"
	"io"


@@ 34,6 35,11 @@ func (e ErrNotFound) Error() string {
	return fmt.Sprintf("URL %s not found: %s", e.Url, e.Reason)
}

func (e ErrNotFound) Is(err error) bool {
	_, ok := err.(ErrNotFound)
	return ok
}

var _ error = (*ErrNotFound)(nil)

func errNotFound(url string, reason string) ErrNotFound {


@@ 49,7 55,7 @@ func fail(whileDoing string, err error) {
}

func getResponseForRequest(req hodhod.Request, cfg *hodhod.Config) (resp hodhod.Response, err error) {
	backend, unmatched := cfg.GetBackendByUrl(req.Url)
	backend, unmatched := cfg.GetBackendByUrl(*req.Url)
	if backend == nil {
		err = errNotFound(req.Url.String(), "no route")
		return


@@ 109,11 115,20 @@ func handleConn(conn net.Conn, cfg *hodhod.Config) {
		RemoteAddr: conn.RemoteAddr().String(),
	}
	resp, err := getResponseForRequest(req, cfg)
	if err != nil {
	if errors.Is(err, ErrNotFound{}) {
		conn.Write([]byte("51 Not Found\r\n"))
		return
	} else if err != nil {
		log.Println("Could not find response for the request:", err)
		return
	}

	err = resp.Init(&req)
	if err != nil {
		conn.Write([]byte("40 Internal error\r\n"))
		return
	}

	var wg sync.WaitGroup
	wg.Add(2)


M pkg/hodhod/cgi.go => pkg/hodhod/cgi.go +45 -10
@@ 11,7 11,9 @@ import (

type CgiResponse struct {
	cmd          *exec.Cmd
	stdin        io.WriteCloser
	stdout       io.Reader
	stderr       io.Reader
	cancelScript func()
}



@@ 33,6 35,14 @@ func cgiError(exitCode int) CgiError {
	}
}

func (resp CgiResponse) Init(req *Request) (err error) {
	reqLine := []byte(req.Url.String())
	reqLine = append(reqLine, '\r', '\n')
	_, err = resp.stdin.Write(reqLine)
	resp.stdin.Close()
	return
}

func (resp CgiResponse) Read(p []byte) (n int, err error) {
	if resp.cmd.ProcessState != nil {
		if resp.cmd.ProcessState.ExitCode() != 0 {


@@ 44,7 54,8 @@ func (resp CgiResponse) Read(p []byte) (n int, err error) {
		return
	}

	return resp.stdout.Read(p)
	n, err = resp.stdout.Read(p)
	return
}

func (resp CgiResponse) Close() {


@@ 57,10 68,9 @@ func NewCgiResp(req Request, scriptPath string, cfg *Config) (resp Response) {
	ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(cfg.CgiTimeout)*time.Second)
	cmd := exec.CommandContext(ctx, scriptPath)

	// create a pipe to connect to the script's stdout; we set the writer as the
	// script's stdout writer (where its output is written to), and keep the
	// reader side so we can read from and send the response to the client.
	r, w := io.Pipe()
	rStdin, wStdin := io.Pipe()
	rStdout, wStdout := io.Pipe()
	rStderr, wStderr := io.Pipe()

	cmd.Env = []string{
		"GATEWAY_INTERFACE=CGI/1.1",


@@ 76,7 86,9 @@ func NewCgiResp(req Request, scriptPath string, cfg *Config) (resp Response) {
		fmt.Sprintf("REMOTE_ADDR=%s", req.RemoteAddr),
		fmt.Sprintf("REMOTE_HOST=%s", req.RemoteAddr),
	}
	cmd.Stdout = w
	cmd.Stdin = rStdin
	cmd.Stdout = wStdout
	cmd.Stderr = wStderr
	cmd.WaitDelay = 5 * time.Second

	err := cmd.Start()


@@ 92,18 104,41 @@ func NewCgiResp(req Request, scriptPath string, cfg *Config) (resp Response) {
	}

	go func() {
		// This function can be useful for debugging CGI scripts. We can read
		// the stderr here and log it.
		//
		// TODO: We could have an option to log these (maybe to a separate file,
		// and/or when there was a CGI error)
		//
		// stderr, err := io.ReadAll(rStderr)
		// if err == nil {
		//    log.Println("CGI stderr:", string(stderr))
		// } else {
		// 	  log.Println("Error reading CGI stderr:", err)
		// }

		io.Copy(io.Discard, rStderr)
	}()

	go func() {
		err := cmd.Wait()
		if err != nil {
			log.Println("CGI script timeout:", scriptPath)
			w.CloseWithError(fmt.Errorf("CGI timeout"))
			log.Printf("CGI script (%s) timeout (error: %s)\n", scriptPath, err)
			rStdin.CloseWithError(fmt.Errorf("CGI timeout"))
			wStdout.CloseWithError(fmt.Errorf("CGI timeout"))
			wStderr.CloseWithError(fmt.Errorf("CGI timeout"))
		} else {
			w.Close()
			rStdin.Close()
			wStdout.Close()
			wStderr.Close()
		}
	}()

	resp = CgiResponse{
		cmd:          cmd,
		stdout:       r,
		stdin:        wStdin,
		stdout:       rStdout,
		stderr:       rStderr,
		cancelScript: cancelFunc,
	}
	return

M pkg/hodhod/error_resp.go => pkg/hodhod/error_resp.go +4 -0
@@ 11,6 11,10 @@ type ErrorResponse struct {
	returnedStatusLine bool
}

func (resp *ErrorResponse) Init(req *Request) (err error) {
	return
}

func (resp *ErrorResponse) Read(p []byte) (n int, err error) {
	if resp.returnedStatusLine {
		return 0, io.EOF

M pkg/hodhod/response.go => pkg/hodhod/response.go +4 -0
@@ 5,6 5,10 @@ import (
)

type Response interface {
	// Called before response body is read, in order to perform any needed
	// initialization.
	Init(req *Request) (err error)

	// Read a part of the response body. Implementes the io.Reader interface.
	Read(p []byte) (n int, err error)


M pkg/hodhod/static.go => pkg/hodhod/static.go +4 -0
@@ 13,6 13,10 @@ type StaticResponse struct {
	returnedStatusLine bool
}

func (resp *StaticResponse) Init(req *Request) (err error) {
	return
}

func (resp *StaticResponse) Read(p []byte) (n int, err error) {
	if !resp.returnedStatusLine {
		status := []byte(fmt.Sprintf("20 %s\r\n", resp.contentType))