package gfsmux import ( "container/heap" "encoding/binary" "errors" "io" "net" "sync" "sync/atomic" "time" ) const ( defaultAcceptBacklog = 2048 ) var ( // ErrInvalidProtocol version or bag negotiation. ErrInvalidProtocol = errors.New( "invalid protocol", ) // ErrConsumed protocol error, indicates desync ErrConsumed = errors.New( "peer consumed more than sent", ) // ErrGoAway overflow condition, restart it all. ErrGoAway = errors.New( "stream id overflows, should start a new Connection", ) // ErrTimeout ... ErrTimeout = &timeoutError{} // ErrWouldBlock error for invalid blocking I/O operating ErrWouldBlock = errors.New( "operation would block on IO", ) ) var _ net.Error = &timeoutError{} type timeoutError struct{} func ( e *timeoutError, ) Error() string { return "timeout" } func ( e *timeoutError, ) Timeout() bool { return true } func ( e *timeoutError, ) Temporary() bool { return true } // WriteRequest ... type WriteRequest struct { Prio uint64 frame Frame result chan writeResult } type writeResult struct { n int err error } type buffersWriter interface { WriteBuffers( v [][]byte, ) ( n int, err error, ) } // Session defines a multiplexed Connection for streams type Session struct { Conn io.ReadWriteCloser Config *Config nextStreamID uint32 // next stream identifier nextStreamIDLock sync.Mutex bucket int32 // token bucket bucketNotify chan struct{} // used for waiting for tokens streams map[uint32]*Stream // all streams in this session streamLock sync.Mutex // locks streams die chan struct{} // flag session has died dieOnce sync.Once // socket error handling socketReadError atomic.Value socketWriteError atomic.Value chSocketReadError chan struct{} chSocketWriteError chan struct{} socketReadErrorOnce sync.Once socketWriteErrorOnce sync.Once // smux protocol errors protoError atomic.Value chProtoError chan struct{} protoErrorOnce sync.Once chAccepts chan *Stream dataReady int32 // flag data has arrived goAway int32 // flag id exhausted deadline atomic.Value shaper chan WriteRequest // a shaper for writing writes chan WriteRequest } func newSession( Config *Config, Conn io.ReadWriteCloser, client bool, ) *Session { s := new( Session, ) s.die = make( chan struct{}, ) s.Conn = Conn s.Config = Config s.streams = make( map[uint32]*Stream, ) s.chAccepts = make( chan *Stream, defaultAcceptBacklog, ) s.bucket = int32( Config.MaxReceiveBuffer, ) s.bucketNotify = make( chan struct{}, 1, ) s.shaper = make( chan WriteRequest, ) s.writes = make( chan WriteRequest, ) s.chSocketReadError = make( chan struct{}, ) s.chSocketWriteError = make( chan struct{}, ) s.chProtoError = make( chan struct{}, ) if client { s.nextStreamID = 1 } else { s.nextStreamID = 0 } go s.shaperLoop() go s.recvLoop() go s.sendLoop() if !Config.KeepAliveDisabled { go s.keepalive() } return s } // OpenStream is used to create a new stream func ( s *Session, ) OpenStream() ( *Stream, error, ) { if s.IsClosed() { return nil, io.ErrClosedPipe } // generate stream id s.nextStreamIDLock.Lock() if s.goAway > 0 { s.nextStreamIDLock.Unlock() return nil, ErrGoAway } s.nextStreamID += 2 Sid := s.nextStreamID if Sid == Sid%2 { // stream-id overflows s.goAway = 1 s.nextStreamIDLock.Unlock() return nil, ErrGoAway } s.nextStreamIDLock.Unlock() stream := newStream( Sid, s.Config.MaxFrameSize, s, ) if _, err := s.WriteFrame( NewFrame( byte(s.Config.Version), CmdSyn, Sid, ), ); err != nil { return nil, err } s.streamLock.Lock() defer s.streamLock.Unlock() select { case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chSocketWriteError: return nil, s.socketWriteError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe default: s.streams[Sid] = stream return stream, nil } } // Open returns a generic ReadWriteCloser func ( s *Session, ) Open() ( io.ReadWriteCloser, error, ) { return s.OpenStream() } // AcceptStream is used to block until the next available stream // is ready to be accepted. func ( s *Session, ) AcceptStream() ( *Stream, error, ) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer( time.Until( d, ), ) defer timer.Stop() deadline = timer.C } select { case stream := <-s.chAccepts: return stream, nil case <-deadline: return nil, ErrTimeout case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chProtoError: return nil, s.protoError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe } } // Accept Returns a generic ReadWriteCloser instead of smux.Stream func ( s *Session, ) Accept() ( io.ReadWriteCloser, error, ) { return s.AcceptStream() } // Close is used to close the session and all streams. func ( s *Session, ) Close() error { var once bool s.dieOnce.Do(func() { close( s.die, ) once = true }) if once { s.streamLock.Lock() for k := range s.streams { s.streams[k].sessionClose() } s.streamLock.Unlock() return s.Conn.Close() } return io.ErrClosedPipe } // notifyBucket notifies recvLoop that bucket is available func ( s *Session, ) notifyBucket() { select { case s.bucketNotify <- struct{}{}: default: } } func ( s *Session, ) notifyReadError( err error, ) { s.socketReadErrorOnce.Do(func() { s.socketReadError.Store( err, ) close( s.chSocketReadError, ) }) } func ( s *Session, ) notifyWriteError( err error, ) { s.socketWriteErrorOnce.Do(func() { s.socketWriteError.Store( err, ) close( s.chSocketWriteError, ) }) } func ( s *Session, ) notifyProtoError( err error, ) { s.protoErrorOnce.Do(func() { s.protoError.Store( err, ) close( s.chProtoError, ) }) } // IsClosed does a safe check to see if we have shutdown func ( s *Session, ) IsClosed() bool { select { case <-s.die: return true default: return false } } // NumStreams returns the number of currently open streams func ( s *Session, ) NumStreams() int { if s.IsClosed() { return 0 } s.streamLock.Lock() defer s.streamLock.Unlock() return len( s.streams, ) } // SetDeadline sets a deadline used by Accept* calls. // A zero time value disables the deadline. func ( s *Session, ) SetDeadline( t time.Time, ) error { s.deadline.Store( t, ) return nil } // LocalAddr satisfies net.Conn interface func ( s *Session, ) LocalAddr() net.Addr { if ts, ok := s.Conn.(interface { LocalAddr() net.Addr }); ok { return ts.LocalAddr() } return nil } // RemoteAddr satisfies net.Conn interface func ( s *Session, ) RemoteAddr() net.Addr { if ts, ok := s.Conn.(interface { RemoteAddr() net.Addr }); ok { return ts.RemoteAddr() } return nil } // notify the session that a stream has closed func ( s *Session, ) streamClosed( Sid uint32, ) { s.streamLock.Lock() // return remaining tokens to the bucket if n := s.streams[Sid].recycleTokens(); n > 0 { if atomic.AddInt32( &s.bucket, int32(n), ) > 0 { s.notifyBucket() } } delete( s.streams, Sid, ) s.streamLock.Unlock() } // returnTokens is called by stream to return token after read func ( s *Session, ) returnTokens( n int, ) { if atomic.AddInt32( &s.bucket, int32(n), ) > 0 { s.notifyBucket() } } // recvLoop keeps on reading from underlying Connection if tokens are available func ( s *Session, ) recvLoop() { var hdr rawHeader var updHdr updHeader for { for atomic.LoadInt32( &s.bucket, ) <= 0 && !s.IsClosed() { select { case <-s.bucketNotify: case <-s.die: return } } // read header first if _, err := io.ReadFull( s.Conn, hdr[:], ); err == nil { atomic.StoreInt32( &s.dataReady, 1, ) if hdr.Version() != byte( s.Config.Version, ) { s.notifyProtoError( ErrInvalidProtocol, ) return } Sid := hdr.StreamID() switch hdr.Cmd() { case CmdNop: case CmdSyn: s.streamLock.Lock() if _, ok := s.streams[Sid]; !ok { stream := newStream( Sid, s.Config.MaxFrameSize, s, ) s.streams[Sid] = stream select { case s.chAccepts <- stream: case <-s.die: } } s.streamLock.Unlock() case CmdFin: s.streamLock.Lock() if stream, ok := s.streams[Sid]; ok { stream.fin() stream.notifyReadEvent() } s.streamLock.Unlock() case CmdPsh: if hdr.Length() > 0 { newbuf := defaultAllocator.Get( int(hdr.Length()), ) if written, err := io.ReadFull( s.Conn, newbuf, ); err == nil { s.streamLock.Lock() if stream, ok := s.streams[Sid]; ok { stream.pushBytes( newbuf, ) atomic.AddInt32( &s.bucket, -int32(written), ) stream.notifyReadEvent() } s.streamLock.Unlock() } else { s.notifyReadError( err, ) return } } case CmdUpd: if _, err := io.ReadFull( s.Conn, updHdr[:], ); err == nil { s.streamLock.Lock() if stream, ok := s.streams[Sid]; ok { stream.update( updHdr.Consumed(), updHdr.Window(), ) } s.streamLock.Unlock() } else { s.notifyReadError( err, ) return } default: s.notifyProtoError( ErrInvalidProtocol, ) return } } else { s.notifyReadError( err, ) return } } } func ( s *Session, ) keepalive() { tickerPing := time.NewTicker( s.Config.KeepAliveInterval, ) tickerTimeout := time.NewTicker( s.Config.KeepAliveTimeout, ) defer tickerPing.Stop() defer tickerTimeout.Stop() for { select { case <-tickerPing.C: s.WriteFrameInternal( NewFrame( byte(s.Config.Version), CmdNop, 0, ), tickerPing.C, 0, ) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32( &s.dataReady, 1, 0, ) { // recvLoop may block while bucket is 0, in this case, // session should not be closed. if atomic.LoadInt32( &s.bucket, ) > 0 { s.Close() return } } case <-s.die: return } } } // shaper shapes the sending sequence among streams func ( s *Session, ) shaperLoop() { var reqs ShaperHeap var next WriteRequest var chWrite chan WriteRequest for { if len( reqs, ) > 0 { chWrite = s.writes next = heap.Pop(&reqs).(WriteRequest) } else { chWrite = nil } select { case <-s.die: return case r := <-s.shaper: if chWrite != nil { // next is valid, reshape heap.Push( &reqs, next, ) } heap.Push( &reqs, r, ) case chWrite <- next: } } } func ( s *Session, ) sendLoop() { var buf []byte var n int var err error var vec [][]byte // vector for writeBuffers bw, ok := s.Conn.(buffersWriter) if ok { buf = make([]byte, HeaderSize) vec = make([][]byte, 2) } else { buf = make([]byte, (1<<16)+HeaderSize) } for { select { case <-s.die: return case request := <-s.writes: buf[0] = request.frame.Ver buf[1] = request.frame.Cmd binary.LittleEndian.PutUint16( buf[2:], uint16( len( request.frame.Data, ), ), ) binary.LittleEndian.PutUint32( buf[4:], request.frame.Sid, ) if len( vec, ) > 0 { vec[0] = buf[:HeaderSize] vec[1] = request.frame.Data n, err = bw.WriteBuffers( vec, ) } else { copy( buf[HeaderSize:], request.frame.Data, ) n, err = s.Conn.Write( buf[:HeaderSize+len(request.frame.Data)], ) } n -= HeaderSize if n < 0 { n = 0 } result := writeResult{ n: n, err: err, } request.result <- result close( request.result, ) // store Conn error if err != nil { s.notifyWriteError( err, ) return } } } } // WriteFrame writes the frame to the underlying Connection // and returns the number of bytes written if successful func ( s *Session, ) WriteFrame( f Frame, ) ( n int, err error, ) { return s.WriteFrameInternal( f, nil, 0, ) } // WriteFrameInternal is to support deadline used in keepalive func ( s *Session, ) WriteFrameInternal( f Frame, deadline <-chan time.Time, Prio uint64, ) ( int, error, ) { req := WriteRequest{ Prio: Prio, frame: f, result: make( chan writeResult, 1, ), } select { case s.shaper <- req: case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } select { case result := <-req.result: return result.n, result.err case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } }