M bind.go => bind.go +5 -4
@@ 46,14 46,15 @@ func BindResource() StreamFeature {
return true, nil, d.DecodeElement(&parsed, start)
},
Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
- if (session.state & Received) == Received {
+ if (session.State() & Received) == Received {
panic("xmpp: bind not yet implemented")
}
conn := session.Conn()
+ d := session.Decoder()
reqID := internal.RandomID()
- if resource := session.config.Origin.Resourcepart(); resource == "" {
+ if resource := session.Config().Origin.Resourcepart(); resource == "" {
// Send a request for the server to set a resource part.
_, err = fmt.Fprintf(conn, bindIQServerGeneratedRP, reqID)
} else {
@@ 63,7 64,7 @@ func BindResource() StreamFeature {
if err != nil {
return mask, nil, err
}
- tok, err := session.in.d.Token()
+ tok, err := d.Token()
if err != nil {
return mask, nil, err
}
@@ 80,7 81,7 @@ func BindResource() StreamFeature {
}{}
switch start.Name {
case xml.Name{Space: ns.Client, Local: "iq"}:
- if err = session.in.d.DecodeElement(&resp, &start); err != nil {
+ if err = d.DecodeElement(&resp, &start); err != nil {
return mask, nil, err
}
default:
M sasl.go => sasl.go +7 -5
@@ 67,7 67,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
return true, parsed.List, err
},
Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
- if (session.state & Received) == Received {
+ if (session.State() & Received) == Received {
panic("SASL server not yet implemented")
}
@@ 122,17 122,19 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
return mask, nil, err
}
+ d := session.Decoder()
+
// If we're already done after the first step, decode the <success/> or
// <failure/> before we exit.
if !more {
- tok, err := session.in.d.Token()
+ tok, err := d.Token()
if err != nil {
return mask, nil, err
}
if t, ok := tok.(xml.StartElement); ok {
// TODO: Handle the additional data that could be returned if
// success?
- _, _, err := decodeSASLChallenge(session.in.d, t, false)
+ _, _, err := decodeSASLChallenge(d, t, false)
if err != nil {
return mask, nil, err
}
@@ 148,13 150,13 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
return mask, nil, ctx.Err()
default:
}
- tok, err := session.in.d.Token()
+ tok, err := d.Token()
if err != nil {
return mask, nil, err
}
var challenge []byte
if t, ok := tok.(xml.StartElement); ok {
- challenge, success, err = decodeSASLChallenge(session.in.d, t, true)
+ challenge, success, err = decodeSASLChallenge(d, t, true)
if err != nil {
return mask, nil, err
}
M starttls.go => starttls.go +13 -8
@@ 11,6 11,7 @@ import (
"errors"
"fmt"
"io"
+ "net"
"mellium.im/xmpp/ns"
"mellium.im/xmpp/streamerror"
@@ 56,22 57,26 @@ func StartTLS(required bool) StreamFeature {
return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
},
Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
- conn := session.conn
- if conn == nil {
+ conn, ok := session.Conn().(net.Conn)
+ if !ok || conn == nil {
return mask, nil, ErrTLSUpgradeFailed
}
+ config := session.Config()
+ state := session.State()
+ d := session.Decoder()
+
// Fetch or create a TLSConfig to use.
var tlsconf *tls.Config
- if session.config.TLSConfig == nil {
+ if config.TLSConfig == nil {
tlsconf = &tls.Config{
ServerName: session.LocalAddr().Domain().String(),
}
} else {
- tlsconf = session.config.TLSConfig
+ tlsconf = config.TLSConfig
}
- if (session.state & Received) == Received {
+ if (state & Received) == Received {
fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
rw = tls.Server(conn, tlsconf)
} else {
@@ 79,7 84,7 @@ func StartTLS(required bool) StreamFeature {
fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
// Receive a <proceed/> or <failure/> response from the server.
- t, err := session.in.d.Token()
+ t, err := d.Token()
if err != nil {
return mask, nil, err
}
@@ 90,13 95,13 @@ func StartTLS(required bool) StreamFeature {
return mask, nil, streamerror.UnsupportedStanzaType
case tok.Name.Local == "proceed":
// Skip the </proceed> token.
- if err = session.in.d.Skip(); err != nil {
+ if err = d.Skip(); err != nil {
return mask, nil, streamerror.InvalidXML
}
rw = tls.Client(conn, tlsconf)
case tok.Name.Local == "failure":
// Skip the </failure> token.
- if err = session.in.d.Skip(); err != nil {
+ if err = d.Skip(); err != nil {
err = streamerror.InvalidXML
}
// Failure is not an "error", it's expected behavior. Immediately