M scgi.go => scgi.go +9 -8
@@ 125,6 125,7 @@ func headerMap(r io.Reader) (map[string]string, error) {
}
// Request reads an SCGI request from an io.Reader, returning a http.Request.
+// The returned Request's Body field is not populated.
func Request(r io.Reader) (*http.Request, error) {
headers, err := headerMap(r)
if err != nil {
@@ 136,10 137,6 @@ func Request(r io.Reader) (*http.Request, error) {
return nil, err
}
- if req.ContentLength > 0 {
- req.Body = io.NopCloser(io.LimitReader(r, req.ContentLength))
- }
-
return req, nil
}
@@ 156,17 153,21 @@ func NewConn(c net.Conn) (*Conn, error) {
return nil, err
}
- return &Conn{Conn: c, buf: &buf}, nil
+ _, err = io.WriteString(&buf, "\r\n")
+ if err != nil {
+ return nil, err
+ }
+
+ return &Conn{Conn: c, r: io.MultiReader(&buf, io.LimitReader(c, req.ContentLength))}, nil
}
// Conn converts an incoming SCGI-connection into a connection presenting a http.Request.
-// Note that the full request is read into a buffer which is then served from.
type Conn struct {
- buf io.Reader
+ r io.Reader
net.Conn
}
// Read from the buffer containing the converted SCGI request.
func (c *Conn) Read(b []byte) (n int, err error) {
- return c.buf.Read(b)
+ return c.r.Read(b)
}
M scgi_test.go => scgi_test.go +8 -8
@@ 19,6 19,7 @@ import (
"testing"
"strings"
"fmt"
+ "net/http"
)
var netstringTests = []struct{
@@ 80,17 81,19 @@ func toNetstring(s string) string {
var requestTests = []struct{
headers string
- body string
+ contentLength int
+ method string
}{
{
"SCGI\x001\x00CONTENT_LENGTH\x006\x00REQUEST_METHOD\x00GET\x00SERVER_PROTOCOL\x00HTTP/1.1\x00",
- "foobar",
+ 6,
+ http.MethodGet,
},
}
func TestRequest(t *testing.T) {
for _, v := range requestTests {
- r := strings.NewReader(toNetstring(v.headers) + v.body)
+ r := strings.NewReader(toNetstring(v.headers))
req, err := Request(r)
if err != nil {
@@ 98,16 101,13 @@ func TestRequest(t *testing.T) {
t.FailNow()
}
- buf := make([]byte, len(v.body))
- _, err = req.Body.Read(buf)
-
if err != nil {
t.Log(err)
t.Fail()
}
- if string(buf) != v.body {
- t.Logf("wrong body contents: %v", string(buf))
+ if req.ContentLength != 6 {
+ t.Logf("wrong content length, got: %v, want: %v", req.ContentLength, v.contentLength)
t.Fail()
}
}