package gfsmux_test
import (
"bytes"
crand "crypto/rand"
"encoding/binary"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
_ "net/http/pprof"
"strings"
"sync"
"testing"
"time"
smux "github.com/johnsonjh/gfsmux"
u "github.com/johnsonjh/leaktestfe"
)
// init starts the net/http/pprof debug endpoint in the background so the
// test binary can be profiled while it runs. It binds to the loopback
// interface only: the profiling handlers expose goroutine stacks and heap
// contents and should not be reachable from other hosts. Any error from
// ListenAndServe (for example, the port is already in use) is logged and
// otherwise ignored, so it never fails the tests.
func init() {
	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()
}
// setupServer starts an smux echo server listening on a random localhost
// port. Every connection accepted on that port is served by
// handleConnection until the returned stop function closes the listener, so
// callers may dial addr again after the first connection (TestRandomFrame
// and TestWriteFrameInternal do). It returns the listening address, the stop
// function, a client connection already dialed to the server, or an error.
func setupServer(
	tb testing.TB,
) (
	addr string,
	stopfunc func(),
	client net.Conn,
	err error,
) {
	tb.Helper()
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return "", nil, nil, err
	}
	go func() {
		// Keep serving new connections until stopfunc closes ln, which
		// makes Accept fail and ends this goroutine.
		for {
			conn, err := ln.Accept()
			if err != nil {
				return
			}
			go handleConnection(conn)
		}
	}()
	addr = ln.Addr().String()
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		ln.Close()
		return "", nil, nil, err
	}
	return addr, func() { ln.Close() }, conn, nil
}
// handleConnection runs an smux server session (default config) on Conn and
// echoes every chunk read from each accepted stream back on that stream.
// It returns once AcceptStream fails. At that point it closes the session so
// the session's background goroutines do not outlive the connection. If the
// session cannot be created, Conn is closed instead.
func handleConnection(
	Conn net.Conn,
) {
	session, err := smux.Server(Conn, nil)
	if err != nil {
		// Nothing can be served without a session; release the socket.
		Conn.Close()
		return
	}
	defer session.Close()
	for {
		stream, err := session.AcceptStream()
		if err != nil {
			return
		}
		go func(s io.ReadWriteCloser) {
			buf := make([]byte, 65536)
			for {
				n, err := s.Read(buf)
				if err != nil {
					return
				}
				if _, err := s.Write(buf[:n]); err != nil {
					return
				}
			}
		}(stream)
	}
}
// setupServerV2 is the protocol-version-2 counterpart of setupServer. It
// starts an smux echo server on a random localhost port. Every accepted
// connection is served by handleConnectionV2 until the returned stop
// function closes the listener. It returns the listening address, the stop
// function, a client connection already dialed to the server, or an error.
func setupServerV2(
	tb testing.TB,
) (
	addr string,
	stopfunc func(),
	client net.Conn,
	err error,
) {
	tb.Helper()
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return "", nil, nil, err
	}
	go func() {
		// Keep serving new connections until stopfunc closes ln.
		for {
			conn, err := ln.Accept()
			if err != nil {
				return
			}
			go handleConnectionV2(conn)
		}
	}()
	addr = ln.Addr().String()
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		ln.Close()
		return "", nil, nil, err
	}
	return addr, func() { ln.Close() }, conn, nil
}
// handleConnectionV2 runs an smux server session with Config.Version = 2 on
// Conn and echoes every chunk read from each accepted stream back on that
// stream. It returns once AcceptStream fails, closing the session on the way
// out. If the session cannot be created, Conn is closed instead.
func handleConnectionV2(
	Conn net.Conn,
) {
	config := smux.DefaultConfig()
	config.Version = 2
	session, err := smux.Server(Conn, config)
	if err != nil {
		Conn.Close()
		return
	}
	defer session.Close()
	for {
		stream, err := session.AcceptStream()
		if err != nil {
			return
		}
		go func(s io.ReadWriteCloser) {
			buf := make([]byte, 65536)
			for {
				n, err := s.Read(buf)
				if err != nil {
					return
				}
				if _, err := s.Write(buf[:n]); err != nil {
					return
				}
			}
		}(stream)
	}
}
// TestEcho writes N short messages on one stream, reading each echo back
// immediately, and checks that the concatenated echoes equal what was sent.
func TestEcho(t *testing.T) {
	defer u.Leakplug(t)
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	// Deferred so the session is torn down on every exit path,
	// including t.Fatal.
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	const N = 100
	buf := make([]byte, 10)
	var sent, received strings.Builder
	for i := 0; i < N; i++ {
		msg := fmt.Sprintf("hello%v", i)
		if _, err := stream.Write([]byte(msg)); err != nil {
			t.Fatal(err)
		}
		sent.WriteString(msg)
		n, err := stream.Read(buf)
		if err != nil {
			t.Fatal(err)
		}
		received.Write(buf[:n])
	}
	if sent.String() != received.String() {
		t.Fatal("data mismatch")
	}
}
// TestWriteTo sends N random bytes through an echo server. The server closes
// its stream after echoing exactly N bytes. The test checks that
// Stream.WriteTo copies all N bytes unchanged and then reports io.EOF.
func TestWriteTo(t *testing.T) {
	defer u.Leakplug(t)
	const N = 1 << 20
	// server
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		session, err := smux.Server(conn, nil)
		if err != nil {
			conn.Close()
			return
		}
		// Tear the server session down once it stops accepting streams
		// so its goroutines do not outlive the test.
		defer session.Close()
		for {
			stream, err := session.AcceptStream()
			if err != nil {
				return
			}
			go func(s io.ReadWriteCloser) {
				numBytes := 0
				buf := make([]byte, 65536)
				for {
					n, err := s.Read(buf)
					if err != nil {
						return
					}
					if _, err := s.Write(buf[:n]); err != nil {
						return
					}
					numBytes += n
					if numBytes == N {
						// After this close, the client's WriteTo is expected
						// to finish with io.EOF.
						s.Close()
						return
					}
				}
			}(stream)
		}
	}()
	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	// client
	session, err := smux.Client(conn, nil)
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	sndbuf := make([]byte, N)
	for i := range sndbuf {
		sndbuf[i] = byte(rand.Int())
	}
	go stream.Write(sndbuf)
	var rcvbuf bytes.Buffer
	nw, ew := stream.WriteTo(&rcvbuf)
	if ew != io.EOF {
		t.Fatal(ew)
	}
	if nw != N {
		t.Fatal("WriteTo nw mismatch", nw)
	}
	if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
		t.Fatal("mismatched echo bytes")
	}
}
// TestWriteToV2 is TestWriteTo with Config.Version = 2 on both sides. It
// sends N random bytes through an echo server that closes its stream after
// N bytes, and checks that Stream.WriteTo copies all of them and then
// reports io.EOF.
func TestWriteToV2(t *testing.T) {
	defer u.Leakplug(t)
	config := smux.DefaultConfig()
	config.Version = 2
	const N = 1 << 20
	// server
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		session, err := smux.Server(conn, config)
		if err != nil {
			conn.Close()
			return
		}
		// Tear the server session down once it stops accepting streams.
		defer session.Close()
		for {
			stream, err := session.AcceptStream()
			if err != nil {
				return
			}
			go func(s io.ReadWriteCloser) {
				numBytes := 0
				buf := make([]byte, 65536)
				for {
					n, err := s.Read(buf)
					if err != nil {
						return
					}
					if _, err := s.Write(buf[:n]); err != nil {
						return
					}
					numBytes += n
					if numBytes == N {
						s.Close()
						return
					}
				}
			}(stream)
		}
	}()
	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	// client
	session, err := smux.Client(conn, config)
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	sndbuf := make([]byte, N)
	for i := range sndbuf {
		sndbuf[i] = byte(rand.Int())
	}
	go stream.Write(sndbuf)
	var rcvbuf bytes.Buffer
	nw, ew := stream.WriteTo(&rcvbuf)
	if ew != io.EOF {
		t.Fatal(ew)
	}
	if nw != N {
		t.Fatal("WriteTo nw mismatch", nw)
	}
	if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
		t.Fatal("mismatched echo bytes")
	}
}
// TestGetDieCh checks that the channel returned by Stream.GetDieCh is closed
// within a second once both ends of a stream pair have been closed. The wait
// runs on the test goroutine because t.Fatal must not be called from any
// other goroutine. Waiting here also keeps the check from outliving the test.
func TestGetDieCh(t *testing.T) {
	defer u.Leakplug(t)
	cs, ss, err := getSmuxStreamPair()
	if err != nil {
		t.Fatal(err)
	}
	dieCh := ss.GetDieCh()
	cs.Close()
	ss.Close()
	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case <-dieCh:
	case <-timer.C:
		t.Fatal("wait die chan timeout")
	}
}
// TestSpeed writes 2048 messages of 8 KiB (16 MiB in total) on one stream
// to the echo server. A concurrent reader drains the echoes, and the test
// logs how long the round trip took.
func TestSpeed(t *testing.T) {
	defer u.Leakplug(t)
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	t.Log(stream.LocalAddr(), stream.RemoteAddr())
	const (
		msgSize  = 8192
		msgCount = 2048
		total    = msgSize * msgCount // 16 MiB == 4096*4096
	)
	start := time.Now()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		buf := make([]byte, 1024*1024)
		nrecv := 0
		for nrecv < total {
			n, err := stream.Read(buf)
			if err != nil {
				t.Error(err)
				break
			}
			nrecv += n
		}
		stream.Close()
		t.Log("time for 16MB rtt", time.Since(start))
	}()
	msg := make([]byte, msgSize)
	for i := 0; i < msgCount; i++ {
		if _, err := stream.Write(msg); err != nil {
			// The reader would otherwise wait forever for bytes that
			// will never arrive; closing the stream makes its Read fail.
			t.Error(err)
			stream.Close()
			break
		}
	}
	wg.Wait()
}
// TestParallel opens up to 1000 streams on one session. Each stream's
// goroutine does 100 write/read echo round trips and then closes its stream.
func TestParallel(t *testing.T) {
	defer u.Leakplug(t)
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	defer session.Close()
	par := 1000
	messages := 100
	var wg sync.WaitGroup
	for i := 0; i < par; i++ {
		stream, err := session.OpenStream()
		if err != nil {
			// Stop opening streams but still wait for those already running.
			t.Error(err)
			break
		}
		wg.Add(1)
		go func(s *smux.Stream) {
			defer wg.Done()
			defer s.Close()
			buf := make([]byte, 20)
			for j := 0; j < messages; j++ {
				msg := fmt.Sprintf("hello%v", j)
				if _, err := s.Write([]byte(msg)); err != nil {
					break
				}
				if _, err := s.Read(buf); err != nil {
					break
				}
			}
		}(stream)
	}
	t.Log("created", session.NumStreams(), "streams")
	wg.Wait()
}
// TestParallelV2 is TestParallel with Config.Version = 2. It opens up to
// 1000 streams, each doing 100 echo round trips against a version-2 server.
func TestParallelV2(t *testing.T) {
	defer u.Leakplug(t)
	config := smux.DefaultConfig()
	config.Version = 2
	_, stop, cli, err := setupServerV2(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, config)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	defer session.Close()
	par := 1000
	messages := 100
	var wg sync.WaitGroup
	for i := 0; i < par; i++ {
		stream, err := session.OpenStream()
		if err != nil {
			t.Error(err)
			break
		}
		wg.Add(1)
		go func(s *smux.Stream) {
			defer wg.Done()
			defer s.Close()
			buf := make([]byte, 20)
			for j := 0; j < messages; j++ {
				msg := fmt.Sprintf("hello%v", j)
				if _, err := s.Write([]byte(msg)); err != nil {
					break
				}
				if _, err := s.Read(buf); err != nil {
					break
				}
			}
		}(stream)
	}
	t.Log("created", session.NumStreams(), "streams")
	wg.Wait()
}
// TestCloseThenOpen verifies that OpenStream fails on a session that has
// already been closed.
func TestCloseThenOpen(t *testing.T) {
	defer u.Leakplug(t)
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	sess.Close()
	_, openErr := sess.OpenStream()
	if openErr == nil {
		t.Fatal("opened after close")
	}
}
// TestSessionDoubleClose verifies that closing a session a second time
// reports an error.
func TestSessionDoubleClose(t *testing.T) {
	defer u.Leakplug(t)
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	sess.Close()
	secondErr := sess.Close()
	if secondErr == nil {
		t.Fatal("session double close doesn't return error")
	}
}
// TestStreamDoubleClose verifies that closing a stream a second time reports
// an error.
func TestStreamDoubleClose(t *testing.T) {
	defer u.Leakplug(t)
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	strm, _ := sess.OpenStream()
	strm.Close()
	secondErr := strm.Close()
	if secondErr == nil {
		t.Fatal("stream double close doesn't return error")
	}
	sess.Close()
}
// TestConcurrentClose opens numStreams streams and then closes all of them
// concurrently while the session itself is being closed.
func TestConcurrentClose(t *testing.T) {
	defer u.Leakplug(t)
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	const numStreams = 100
	streams := make([]*smux.Stream, 0, numStreams)
	for i := 0; i < numStreams; i++ {
		stream, err := session.OpenStream()
		if err != nil {
			// No goroutines have been started yet, so bailing out is safe.
			session.Close()
			t.Fatal(err)
		}
		streams = append(streams, stream)
	}
	var wg sync.WaitGroup
	wg.Add(len(streams))
	for _, s := range streams {
		go func(stream *smux.Stream) {
			defer wg.Done()
			stream.Close()
		}(s)
	}
	session.Close()
	wg.Wait()
}
// TestTinyReadBuffer reads each echoed message back through a 6-byte buffer,
// so messages longer than 6 bytes take several Read calls. It checks that the
// reassembled output matches what was written.
func TestTinyReadBuffer(t *testing.T) {
	defer u.Leakplug(t)
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	strm, _ := sess.OpenStream()
	const rounds = 100
	small := make([]byte, 6)
	sent, received := "", ""
	for i := 0; i < rounds; i++ {
		payload := fmt.Sprintf("hello%v", i)
		sent += payload
		written, werr := strm.Write([]byte(payload))
		if werr != nil {
			t.Fatal("cannot write")
		}
		for got := 0; got < written; {
			n, rerr := strm.Read(small)
			if rerr != nil {
				t.Fatal("cannot read with tiny buffer")
			}
			got += n
			received += string(small[:n])
		}
	}
	if received != sent {
		t.Fatal("data mimatch")
	}
	sess.Close()
}
// TestIsClose verifies that IsClosed reports true right after Close.
func TestIsClose(t *testing.T) {
	defer u.Leakplug(t)
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	sess.Close()
	if closed := sess.IsClosed(); !closed {
		t.Fatal("still open after close")
	}
}
// TestKeepAliveTimeout connects a client session to a plain TCP peer that
// never sends anything. The accepted connection is held open but untouched.
// With KeepAliveInterval=1s and KeepAliveTimeout=2s, the session is expected
// to report itself closed within 3 seconds.
func TestKeepAliveTimeout(t *testing.T) {
	defer u.Leakplug(t)
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	accepted := make(chan net.Conn, 1)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			close(accepted)
			return
		}
		accepted <- conn
	}()
	cli, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	// Keep a reference to the server end for the whole test, then close it.
	// A dropped, unreferenced conn can be closed by the GC finalizer, which
	// would end the session through EOF instead of the keepalive timeout.
	defer func() {
		if conn, ok := <-accepted; ok {
			conn.Close()
		}
	}()
	config := smux.DefaultConfig()
	config.KeepAliveInterval = time.Second
	config.KeepAliveTimeout = 2 * time.Second
	session, err := smux.Client(cli, config)
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
	// Poll rather than sleeping a fixed 3s: same upper bound, but the test
	// finishes as soon as the session is closed.
	deadline := time.Now().Add(3 * time.Second)
	for !session.IsClosed() && time.Now().Before(deadline) {
		time.Sleep(50 * time.Millisecond)
	}
	if !session.IsClosed() {
		t.Fatal("keepalive-timeout failed")
	}
}
// blockWriteConn wraps a net.Conn so that every Write stalls for 24 hours
// before reaching the wrapped connection. The tests use it to simulate a
// peer whose writes effectively never complete. All other net.Conn methods
// are promoted unchanged from the embedded Conn.
type blockWriteConn struct {
	net.Conn
}

// Write sleeps for 24 hours and only then forwards b to the wrapped Conn,
// returning that call's result.
func (c *blockWriteConn) Write(b []byte) (n int, err error) {
	const stall = 24 * time.Hour
	time.Sleep(stall)
	n, err = c.Conn.Write(b)
	return n, err
}
// TestKeepAliveBlockWriteTimeout checks that the keepalive timeout still
// closes a session whose writes never complete. The client conn is wrapped
// in blockWriteConn, so every Write stalls, and the peer never sends data.
// The session is expected to report itself closed within 3 seconds.
func TestKeepAliveBlockWriteTimeout(t *testing.T) {
	defer u.Leakplug(t)
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	accepted := make(chan net.Conn, 1)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			close(accepted)
			return
		}
		accepted <- conn
	}()
	cli, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	// Hold the server end open until the test ends, then release it, so the
	// session cannot be closed by an EOF from a GC-finalized peer socket.
	defer func() {
		if conn, ok := <-accepted; ok {
			conn.Close()
		}
	}()
	// when WriteFrame block, keepalive in old version never timeout
	blockWriteCli := &blockWriteConn{cli}
	config := smux.DefaultConfig()
	config.KeepAliveInterval = time.Second
	config.KeepAliveTimeout = 2 * time.Second
	session, err := smux.Client(blockWriteCli, config)
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
	deadline := time.Now().Add(3 * time.Second)
	for !session.IsClosed() && time.Now().Before(deadline) {
		time.Sleep(50 * time.Millisecond)
	}
	if !session.IsClosed() {
		t.Fatal("keepalive-timeout failed")
	}
}
// TestServerEcho reverses the usual roles. The server side opens a stream and
// sends 100 messages, and the client side echoes everything back. The server
// goroutine verifies each echo and reports problems with t.Error. The test
// waits for that goroutine, so nothing is logged after the test has returned.
func TestServerEcho(t *testing.T) {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	serverDone := make(chan struct{})
	// Registered before the other defers so that ln.Close/cli.Close run
	// first and unblock the server goroutine on every exit path.
	defer func() { <-serverDone }()
	defer ln.Close()
	go func() {
		defer close(serverDone)
		err := func() error {
			conn, err := ln.Accept()
			if err != nil {
				return err
			}
			defer conn.Close()
			session, err := smux.Server(conn, nil)
			if err != nil {
				return err
			}
			defer session.Close()
			buf := make([]byte, 10)
			stream, err := session.OpenStream()
			if err != nil {
				return err
			}
			defer stream.Close()
			for i := 0; i < 100; i++ {
				msg := fmt.Sprintf("hello%v", i)
				stream.Write([]byte(msg))
				n, err := stream.Read(buf)
				if err != nil {
					return err
				}
				if got := string(buf[:n]); got != msg {
					return fmt.Errorf("got: %q, want: %q", got, msg)
				}
			}
			return nil
		}()
		if err != nil {
			t.Error(err)
		}
	}()
	cli, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	session, err := smux.Client(cli, nil)
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
	stream, err := session.AcceptStream()
	if err != nil {
		t.Fatal(err)
	}
	// Echo until the server closes its side.
	buf := make([]byte, 65536)
	for {
		n, err := stream.Read(buf)
		if err != nil {
			break
		}
		if _, err := stream.Write(buf[:n]); err != nil {
			break
		}
	}
}
// TestSendWithoutRecv writes 100 messages without reading any echoes in
// between, then checks that echoed data can still be read afterwards.
func TestSendWithoutRecv(t *testing.T) {
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	// Closing the session, not just the stream, stops its background
	// goroutines and releases cli.
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	const N = 100
	for i := 0; i < N; i++ {
		msg := fmt.Sprintf("hello%v", i)
		if _, err := stream.Write([]byte(msg)); err != nil {
			t.Fatal(err)
		}
	}
	buf := make([]byte, 1)
	if _, err := stream.Read(buf); err != nil {
		t.Fatal(err)
	}
	stream.Close()
}
// TestWriteAfterClose verifies that writing to a stream after closing it
// returns an error.
func TestWriteAfterClose(t *testing.T) {
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	stream.Close()
	if _, err := stream.Write([]byte("write after close")); err == nil {
		t.Fatal("write after close did not return an error")
	}
}
// TestReadStreamAfterSessionClose verifies that reading from a stream fails
// once its session has been closed; the resulting error is logged.
func TestReadStreamAfterSessionClose(t *testing.T) {
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	strm, _ := sess.OpenStream()
	sess.Close()
	_, readErr := strm.Read(make([]byte, 10))
	if readErr == nil {
		t.Fatal("read stream after session close succeeded")
	}
	t.Log(readErr)
}
// TestWriteStreamAfterConnectionClose closes the session's underlying
// connection directly. It then verifies that writing to a stream of that
// session returns an error.
func TestWriteStreamAfterConnectionClose(t *testing.T) {
	_, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		t.Fatal(err)
	}
	// Closing only the raw conn leaves the session's own goroutines alive;
	// close the session too when the test ends.
	defer session.Close()
	stream, err := session.OpenStream()
	if err != nil {
		t.Fatal(err)
	}
	session.Conn.Close()
	if _, err := stream.Write([]byte("write after Connection close")); err == nil {
		t.Fatal("write after Connection close did not return an error")
	}
}
// TestNumStreamAfterClose verifies NumStreams: it is 1 after opening a
// single stream, and 0 once the session has been closed.
func TestNumStreamAfterClose(t *testing.T) {
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	_, openErr := sess.OpenStream()
	if openErr != nil {
		t.Fatal(openErr)
	}
	if sess.NumStreams() != 1 {
		t.Fatal("wrong number of streams after opened")
	}
	sess.Close()
	if sess.NumStreams() != 0 {
		t.Fatal("wrong number of streams after session closed")
	}
	conn.Close()
}
// TestRandomFrame sends malformed and unusual input to addr over a series of
// fresh connections:
//   - raw random bytes
//   - the same SYN frame 100 times
//   - frames with random commands from the known set
//   - frames with random command bytes
//   - frames with random version bytes
//   - a frame whose length field is one larger than its payload
//
// It finally calls WriteFrame on a session that was already closed. The test
// makes no assertions beyond successful dialing; it only has to finish
// without panicking. Each client session is closed, not just its conn,
// before moving on, so its goroutines do not leak.
func TestRandomFrame(t *testing.T) {
	addr, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, _ := smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		rnd := make([]byte, rand.Uint32()%1024)
		io.ReadFull(crand.Reader, rnd)
		session.Conn.Write(rnd)
	}
	session.Close()
	cli.Close()
	// double syn
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, smux.CmdSyn, 1000)
		session.WriteFrame(f)
	}
	session.Close()
	cli.Close()
	// random cmds
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	allcmds := []byte{smux.CmdSyn, smux.CmdFin, smux.CmdPsh, smux.CmdNop}
	session, _ = smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
		session.WriteFrame(f)
	}
	session.Close()
	cli.Close()
	// random cmds & Sids
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
		session.WriteFrame(f)
	}
	session.Close()
	cli.Close()
	// random version
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
		f.Ver = byte(rand.Uint32())
		session.WriteFrame(f)
	}
	session.Close()
	cli.Close()
	// incorrect size
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
	rnd := make([]byte, rand.Uint32()%1024)
	io.ReadFull(crand.Reader, rnd)
	f.Data = rnd
	// Hand-encode the header so the length field can be made wrong.
	buf := make([]byte, smux.HeaderSize+len(f.Data))
	buf[0] = f.Ver
	buf[1] = f.Cmd
	binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) // incorrect size
	binary.LittleEndian.PutUint32(buf[4:], f.Sid)
	copy(buf[smux.HeaderSize:], f.Data)
	session.Conn.Write(buf)
	session.Close()
	cli.Close()
	// WriteFrame after die
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	// close first
	session.Close()
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
		session.WriteFrame(f)
	}
}
// TestWriteFrameInternal exercises Session.WriteFrameInternal in several
// situations:
//   - after the session has been closed
//   - with random commands
//   - with a deadline channel that has already fired, which must yield a
//     net.Error whose Timeout() and Temporary() are both true and whose
//     message contains "timeout"
//   - on a session whose conn blocks every Write and that is closed while
//     the write is pending, which must yield an error containing
//     "closed pipe"
//
// Errors are checked for nil before their text is inspected, so a write
// that unexpectedly succeeds fails the test instead of panicking.
func TestWriteFrameInternal(t *testing.T) {
	addr, stop, cli, err := setupServer(t)
	if err != nil {
		t.Fatal(err)
	}
	defer stop()
	session, _ := smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		rnd := make([]byte, rand.Uint32()%1024)
		io.ReadFull(crand.Reader, rnd)
		session.Conn.Write(rnd)
	}
	session.Close()
	cli.Close()
	// WriteFrame after die
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	session, _ = smux.Client(cli, nil)
	// close first
	session.Close()
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
		session.WriteFrameInternal(f, time.After(session.Config.KeepAliveTimeout), 0)
	}
	// random cmds
	cli, err = net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}
	allcmds := []byte{smux.CmdSyn, smux.CmdFin, smux.CmdPsh, smux.CmdNop}
	session, _ = smux.Client(cli, nil)
	for i := 0; i < 100; i++ {
		f := smux.NewFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
		session.WriteFrameInternal(f, time.After(session.Config.KeepAliveTimeout), 0)
	}
	// deadline occur
	{
		c := make(chan time.Time)
		close(c) // an already-expired deadline
		f := smux.NewFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
		_, err := session.WriteFrameInternal(f, c, 0)
		if err == nil || !strings.Contains(err.Error(), "timeout") {
			t.Fatal("write frame with deadline failed", err)
		}
		netErr, ok := err.(net.Error)
		if !ok {
			t.Fatal("expected net.Error for timeout")
		}
		if !netErr.Timeout() {
			t.Fatal("expected Timeout() to be true on timeout error ", err)
		}
		if !netErr.Temporary() {
			t.Fatal("expected Temporary() to be true on timeout error ", err)
		}
	}
	session.Close()
	cli.Close()
	{
		cli, err = net.Dial("tcp", addr)
		if err != nil {
			t.Fatal(err)
		}
		config := smux.DefaultConfig()
		config.KeepAliveInterval = time.Second
		config.KeepAliveTimeout = 2 * time.Second
		session, _ = smux.Client(&blockWriteConn{cli}, config)
		f := smux.NewFrame(1, byte(rand.Uint32()), rand.Uint32())
		c := make(chan time.Time)
		go func() {
			// die first, deadline second, better for coverage
			time.Sleep(time.Second)
			session.Close()
			time.Sleep(time.Second)
			close(c)
		}()
		_, err = session.WriteFrameInternal(f, c, 0)
		if err == nil || !strings.Contains(err.Error(), "closed pipe") {
			t.Fatal("write frame with to closed Conn failed ", err)
		}
	}
}
// TestReadDeadline sets a read deadline in the past and verifies that Read
// fails, with an error whose text mentions "timeout".
func TestReadDeadline(t *testing.T) {
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	strm, _ := sess.OpenStream()
	const attempts = 100
	scratch := make([]byte, 10)
	var readErr error
	// Try up to `attempts` reads, stopping at the first failure.
	for i := 0; i < attempts && readErr == nil; i++ {
		strm.SetReadDeadline(time.Now().Add(-1 * time.Minute))
		_, readErr = strm.Read(scratch)
	}
	switch {
	case readErr == nil:
		t.Fatal("No error when reading with past deadline")
	case !strings.Contains(readErr.Error(), "timeout"):
		t.Fatalf("Wrong error: %v", readErr)
	}
	sess.Close()
}
// TestWriteDeadline repeatedly writes with a write deadline in the past
// until Write fails, and verifies that the failure is a timeout.
func TestWriteDeadline(t *testing.T) {
	_, stopServer, conn, setupErr := setupServer(t)
	if setupErr != nil {
		t.Fatal(setupErr)
	}
	defer stopServer()
	sess, _ := smux.Client(conn, nil)
	strm, _ := sess.OpenStream()
	payload := make([]byte, 10)
	for {
		strm.SetWriteDeadline(time.Now().Add(-1 * time.Minute))
		_, writeErr := strm.Write(payload)
		if writeErr == nil {
			continue
		}
		if !strings.Contains(writeErr.Error(), "timeout") {
			t.Fatalf("Wrong error: %v", writeErr)
		}
		break
	}
	sess.Close()
}
// BenchmarkAcceptClose measures opening a stream on an established client
// session and immediately closing it. Server setup is excluded from the
// timed region, and the session is closed when the benchmark ends.
func BenchmarkAcceptClose(b *testing.B) {
	_, stop, cli, err := setupServer(b)
	if err != nil {
		b.Fatal(err)
	}
	defer stop()
	session, err := smux.Client(cli, nil)
	if err != nil {
		cli.Close()
		b.Fatal(err)
	}
	defer session.Close()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		stream, err := session.OpenStream()
		if err != nil {
			b.Fatal(err)
		}
		stream.Close()
	}
}
// BenchmarkConnSmux measures bulk-transfer throughput through a pair of smux
// streams layered on a loopback TCP connection.
func BenchmarkConnSmux(b *testing.B) {
	clientStream, serverStream, pairErr := getSmuxStreamPair()
	if pairErr != nil {
		b.Fatal(pairErr)
	}
	defer clientStream.Close()
	defer serverStream.Close()
	bench(b, clientStream, serverStream)
}
// BenchmarkConnTCP measures bulk-transfer throughput over a raw loopback TCP
// connection, as a baseline for BenchmarkConnSmux.
func BenchmarkConnTCP(b *testing.B) {
	serverSide, clientSide, pairErr := getTCPConnectionPair()
	if pairErr != nil {
		b.Fatal(pairErr)
	}
	defer serverSide.Close()
	defer clientSide.Close()
	bench(b, serverSide, clientSide)
}
// getSmuxStreamPair builds a connected pair of smux streams over a fresh
// loopback TCP connection (default config on both sessions). The first
// result is the stream opened by the client session; the second is the
// stream accepted by the server session. On error, every connection and
// session created so far is closed.
func getSmuxStreamPair() (
	*smux.Stream,
	*smux.Stream,
	error,
) {
	c1, c2, err := getTCPConnectionPair()
	if err != nil {
		return nil, nil, err
	}
	s, err := smux.Server(c2, nil)
	if err != nil {
		c1.Close()
		c2.Close()
		return nil, nil, err
	}
	c, err := smux.Client(c1, nil)
	if err != nil {
		s.Close()
		c1.Close()
		return nil, nil, err
	}
	var ss *smux.Stream
	// Buffered so the accepting goroutine can always deliver its result and
	// exit, even if this function returns early below.
	done := make(chan error, 1)
	go func() {
		var rerr error
		ss, rerr = s.AcceptStream()
		done <- rerr
	}()
	cs, err := c.OpenStream()
	if err != nil {
		// Closing the server session is expected to make the pending
		// AcceptStream return (NOTE(review): confirm for gfsmux).
		c.Close()
		s.Close()
		return nil, nil, err
	}
	if err := <-done; err != nil {
		c.Close()
		s.Close()
		return nil, nil, err
	}
	return cs, ss, nil
}
// getTCPConnectionPair returns both ends of a freshly established loopback
// TCP connection. The first result is the server-side conn from Accept; the
// second is the client-side conn from Dial. The temporary listener is always
// closed before returning, and on error no connection is left open.
func getTCPConnectionPair() (
	net.Conn,
	net.Conn,
	error,
) {
	lst, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return nil, nil, err
	}
	// Closing the listener also unblocks Accept if Dial fails below.
	defer lst.Close()
	var accepted net.Conn
	var acceptErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		accepted, acceptErr = lst.Accept()
	}()
	dialed, err := net.Dial("tcp", lst.Addr().String())
	if err != nil {
		return nil, nil, err
	}
	<-done
	if acceptErr != nil {
		dialed.Close()
		return nil, nil, acceptErr
	}
	return accepted, dialed, nil
}
// bench writes b.N chunks of 128 KiB to wr and waits until a concurrent
// reader has drained the same number of bytes from rd. Throughput is
// reported per chunk through b.SetBytes. A read error stops the reader and
// is reported through b.Error; without that check, a failed Read returning
// 0 bytes would spin forever. Errors from wr.Write are not checked.
func bench(
	b *testing.B,
	rd io.Reader,
	wr io.Writer,
) {
	const chunk = 128 * 1024
	buf := make([]byte, chunk)
	buf2 := make([]byte, chunk)
	b.SetBytes(chunk)
	b.ResetTimer()
	b.ReportAllocs()
	total := chunk * b.N
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		count := 0
		for count < total {
			n, err := rd.Read(buf2)
			count += n
			// A Read may return its last bytes together with an error
			// (e.g. io.EOF); only treat it as fatal if data is missing.
			if err != nil && count < total {
				b.Error(err)
				return
			}
		}
	}()
	for i := 0; i < b.N; i++ {
		wr.Write(buf)
	}
	wg.Wait()
}