~fnux/yggdrasil-go-coap

195895436775d7216eea1480cdbad9ee4cf506f2 — Arunprabhu Kandasamy 2 years ago a643abf
validate MaxMessageSize in blockwise and tcp session (#42)

* validate MaxMessageSize in blockwise and tcp session
M blockwise.go => blockwise.go +41 -0
@@ 341,7 341,29 @@ func (b *blockWiseSession) WriteMsg(msg Message) error {
	return b.WriteMsgWithContext(context.Background(), msg)
}

func (b *blockWiseSession) validateMessageSize(msg Message) error {
	size, err := msg.ToBytesLength()
	if err != nil {
		return err
	}
	session, ok := b.networkSession.(*sessionTCP)
	if !ok {
		// Not supported for UDP session
		return nil
	}

	if session.peerMaxMessageSize != 0 &&
		uint32(size) > session.peerMaxMessageSize {
		return ErrMaxMessageSizeLimitExceeded
	}

	return nil
}

func (b *blockWiseSession) WriteMsgWithContext(ctx context.Context, msg Message) error {
	if err := b.validateMessageSize(msg); err != nil {
		return err
	}
	switch msg.Code() {
	case CSM, Ping, Pong, Release, Abort, Empty, GET:
		return b.networkSession.WriteMsgWithContext(ctx, msg)


@@ 532,7 554,26 @@ func (r *blockWiseReceiver) exchange(ctx context.Context, b *blockWiseSession, r
	return resp, err
}

func (r *blockWiseReceiver) validateMessageSize(msg Message, b *blockWiseSession) error {
	size, err := msg.ToBytesLength()
	if err != nil {
		return err
	}

	session, ok := b.networkSession.(*sessionTCP)
	if ok {
		if session.srv.MaxMessageSize != 0 &&
			uint32(size) > session.srv.MaxMessageSize {
			return ErrMaxMessageSizeLimitExceeded
		}
	}
	return nil
}

func (r *blockWiseReceiver) processResp(b *blockWiseSession, req Message, resp Message) (Message, error) {
	if err := r.validateMessageSize(req, b); err != nil {
		return nil, err
	}
	if respBlock, ok := resp.Option(r.blockType).(uint32); ok {
		szx, num, more, err := UnmarshalBlockOption(respBlock)
		if err != nil {

M client_test.go => client_test.go +1 -0
@@ 52,6 52,7 @@ func testServingObservation(t *testing.T, net string, addrstr string, BlockWiseT
		Net:                  net,
		BlockWiseTransfer:    &BlockWiseTransfer,
		BlockWiseTransferSzx: &BlockWiseTransferSzx,
		MaxMessageSize:       ^uint32(0),
	}

	conn, err := client.Dial(addrstr)

M error.go => error.go +3 -0
@@ 106,3 106,6 @@ const ErrUnexpectedReponseCode = Error("unexpected response code")

// ErrMessageNotInterested message is not of interest to the client
const ErrMessageNotInterested = Error("message not to be sent due to disinterest")

// ErrMaxMessageSizeLimitExceeded message size bigger thab maximum message size limit
const ErrMaxMessageSizeLimitExceeded = Error("maximum message size limit exceeded")

M message.go => message.go +1 -0
@@ 503,6 503,7 @@ type Message interface {
	UnmarshalBinary(data []byte) error
	SetToken(t []byte)
	SetMessageID(messageID uint16)
	ToBytesLength() (int, error)
}

// MessageParams params to create COAP message

M message_test.go => message_test.go +21 -0
@@ 1132,3 1132,24 @@ func TestDecodeMessageWithNoResponseOption(t *testing.T) {
		t.Fatalf("parsedMsg.Option(NoResponse): %v", parsedMsg.Option(NoResponse).(uint32))
	}
}

func TestToBytesLength(t *testing.T) {
	data := []byte{
		0x40, 0x1, 0x30, 0x39, 0x46, 0x77,
		0x65, 0x65, 0x74, 0x61, 0x67, 0xa1, 0x3,
	}

	msg, err := ParseDgramMessage(data)
	if err != nil {
		t.Fatalf("Error parsing request: %v", err)
	}

	bytesLength, err := msg.ToBytesLength()
	if err != nil {
		t.Fatalf("Error parsing request: %v", err)
	}

	if len(data) != bytesLength {
		t.Errorf("Expected Length = %d, got %d", len(data), bytesLength)
	}
}

M messagedgram.go => messagedgram.go +11 -0
@@ 1,6 1,7 @@
package coap

import (
	"bytes"
	"encoding/binary"
	"io"
	"sort"


@@ 114,3 115,13 @@ func ParseDgramMessage(data []byte) (*DgramMessage, error) {
	rv := &DgramMessage{}
	return rv, rv.UnmarshalBinary(data)
}

// ToBytesLength gets the length of the message
func (m *DgramMessage) ToBytesLength() (int, error) {
	buf := bytes.NewBuffer(make([]byte, 0, 1024))
	if err := m.MarshalBinary(buf); err != nil {
		return 0, err
	}

	return len(buf.Bytes()), nil
}

M messagetcp.go => messagetcp.go +9 -0
@@ 332,6 332,15 @@ func (m *TcpMessage) fill(mti msgTcpInfo, o options, p []byte) {
	m.MessageBase.payload = p
}

func (m *TcpMessage) ToBytesLength() (int, error) {
	buf := bytes.NewBuffer(make([]byte, 0, 1024))
	if err := m.MarshalBinary(buf); err != nil {
		return 0, err
	}

	return len(buf.Bytes()), nil
}

type contextBytesReader struct {
	reader io.Reader
}

M messagetcp_test.go => messagetcp_test.go +33 -0
@@ 31,4 31,37 @@ func TestTCPDecodeMessageSmallWithPayload(t *testing.T) {
	if !bytes.Equal(msg.Payload(), []byte("hi")) {
		t.Errorf("Incorrect payload: %q", msg.Payload())
	}

}

func TestMessageTCPToBytesLength(t *testing.T) {
	msgParams := MessageParams{
		Code:    COAPCode(02),
		Token:   []byte{0xab},
		Payload: []byte("hi"),
	}

	msg := NewTcpMessage(msgParams)
	msg.AddOption(MaxMessageSize, maxMessageSize)

	buf := &bytes.Buffer{}
	err := msg.MarshalBinary(buf)
	if err != nil {
		t.Fatalf("Error encoding request: %v", err)
	}

	bytesLength, err := msg.ToBytesLength()
	if err != nil {
		t.Fatalf("Error parsing request: %v", err)
	}

	lenTkl := 1
	lenCode := 1
	maxMessageSizeOptionLength := 3
	payloadMarker := []byte{0xff}

	expectedLength := lenTkl + lenCode + len(msgParams.Token) + maxMessageSizeOptionLength + len(payloadMarker) + len(msgParams.Payload)
	if expectedLength != bytesLength {
		t.Errorf("Expected Length  = %d, got %d", expectedLength, bytesLength)
	}
}

M networksession.go => networksession.go +16 -0
@@ 423,8 423,24 @@ func (s *sessionUDP) ExchangeWithContext(ctx context.Context, req Message) (Mess
	}
}

func (s *sessionTCP) validateMessageSize(msg Message) error {
	size, err := msg.ToBytesLength()
	if err != nil {
		return err
	}

	if uint32(size) > s.peerMaxMessageSize {
		return ErrMaxMessageSizeLimitExceeded
	}

	return nil
}

// Write implements the networkSession.Write method.
func (s *sessionTCP) WriteMsgWithContext(ctx context.Context, req Message) error {
	if err := s.validateMessageSize(req); err != nil {
		return err
	}
	buffer := bytes.NewBuffer(make([]byte, 0, 1500))
	err := req.MarshalBinary(buffer)
	if err != nil {

M server.go => server.go +5 -0
@@ 438,6 438,11 @@ func (srv *Server) serveTCPconnection(ctx *shutdownContext, netConn net.Conn) er
			return session.closeWithError(fmt.Errorf("cannot serve tcp connection: %v", err))
		}

		if srv.MaxMessageSize != 0 &&
			uint32(mti.totLen) > srv.MaxMessageSize {
			return session.closeWithError(fmt.Errorf("cannot serve tcp connection: %v", ErrMaxMessageSizeLimitExceeded))
		}

		body := make([]byte, mti.BodyLen())
		//ctx, cancel := context.WithTimeout(srv.ctx, srv.readTimeout())
		err = conn.ReadFullWithContext(ctx, body)

M server_test.go => server_test.go +3 -1
@@ 161,7 161,9 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan 
			fmt.Printf("networkSession start %v\n", s.RemoteAddr())
		}, NotifySessionEndFunc: func(w *ClientConn, err error) {
			fmt.Printf("networkSession end %v: %v\n", w.RemoteAddr(), err)
		}}
		},
		MaxMessageSize: ^uint32(0),
	}

	// fin must be buffered so the goroutine below won't block
	// forever if fin is never read from. This always happens