~apreiml/go-wormhole

46e8dd0976986da1f02b748b11144a1729f579e8 — Armin Preiml 2 years ago ebda5d5
implement send file
4 files changed, 162 insertions(+), 35 deletions(-)

M apps/transfer.go
M main.go
M wormhole/client.go
M wormhole/transit/transit.go
M apps/transfer.go => apps/transfer.go +73 -2
@@ 8,7 8,10 @@ import (
	"fmt"
	"hash"
	"io"
	"log"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"git.sr.ht/~apreiml/go-wormhole/passphrase"


@@ 55,6 58,7 @@ type directory struct {

type answer struct {
	MessageAck string `json:"message_ack"`
	FileAck    string `json:"file_ack"`
}

const (


@@ 110,7 114,69 @@ func (t *Transfer) SendMessage(msg string) error {
	return nil
}

func (t *Transfer) SendFile() {
func (t *Transfer) SendFile(f *os.File) error {
	t.trans = transit.New(t.w, transit.RoleSender)
	t.transferType = TransferTypeFile
	var msg message

	stats, err := f.Stat()
	if err != nil {
		return err
	}
	t.w.SendMessageJson(map[string]transit.Config{"transit": t.trans.OurAbilities()})
	t.w.SendMessageJson(map[string]map[string]file{
		"offer": map[string]file{
			"file": file{
				Name: filepath.Base(f.Name()),
				Size: stats.Size(),
			},
		},
	})

MessageLoop:
	for {
		println("wait for message..")
		msg, err = t.nextMessage()
		if err != nil {
			println("fail")
			return err
		}
		//	fmt.Printf("%#v\n", msg)
		switch {
		case len(msg.Transit.Abilities) > 0:
			println("negotiate abilities")
			t.trans.TheirAbilities(msg.Transit)
		case msg.Answer.FileAck == "ok":
			t.w.Close()
			break MessageLoop
		default:
			println("unknown message")
			err = errors.New("unknown message")
			return err
		}
	}

	t.trans.Connect()

	log.Println("start write file")
	_, err = io.Copy(t.trans, f)
	if err != nil {
		return err
	}
	t.trans.Flush()

	log.Println("end write file")

	b := make([]byte, 4096)
	n, err := t.trans.Read(b)
	if err != nil {
		return err
	}

	log.Println(string(b[:n]))

	// TODO CHECK ack
	return nil
}

func (t *Transfer) SendDirectory() {


@@ 210,10 276,15 @@ func (t *Transfer) Read(p []byte) (n int, err error) {
	return
}

func (t *Transfer) Write(p []byte) (n int, err error) {
	n, err = t.trans.Write(p)
	return
}

// TODO only ack if file read fully
func (t *Transfer) Close() {
	fmt.Printf("CLOSE %v %d\n", t.complete, t.transferType)
	if t.complete && t.transferType != TransferTypeMessage {
	if t.transferType != TransferTypeMessage && t.trans.Role() == transit.RoleReceiver && t.complete {
		sum := t.hashSum.Sum(nil)
		ack, _ := json.Marshal(map[string]string{
			"ack":    "ok",

M main.go => main.go +32 -0
@@ 34,6 34,13 @@ func main() {
				os.Exit(1)
			}
		}

		err := sendFile(args[1])
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}

	case "receive":
		if len(args) != 2 {
			printHelp()


@@ 130,3 137,28 @@ func sendMessage() error {

	return t.SendMessage(string(data))
}

func sendFile(path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	u := url.URL{Scheme: "ws", Host: *addr, Path: "/v1"}

	t, err := apps.NewTransfer(u)
	if err != nil {
		return err
	}
	defer t.Close()

	println("make password")
	pw, err := t.MakePassword()
	if err != nil {
		return err
	}
	println("password")

	fmt.Println("wormhole receive", pw)

	return t.SendFile(f)
}

M wormhole/client.go => wormhole/client.go +0 -4
@@ 236,8 236,6 @@ func (w *Client) consume(b []byte) error {
		return ServerErr{msg.Error}
	}

	log.Print(string(b))

	switch msg.Mtype {
	case "ack":
		// ignore


@@ 318,7 316,6 @@ func (w Client) sendCommand(cmd interface{}) error {
	if err != nil {
		return err
	}
	log.Print("SEND: ", string(json))
	w.send <- json
	return nil
}


@@ 399,7 396,6 @@ func (w *Client) addUnencryptedMessage(id, phase string, body []byte) {
}

func (w *Client) addMessage(id, phase string, body []byte) error {
	log.Print("Send ENC: ", string(body))
	encrypted, err := w.encrypt(body, phase)
	if err != nil {
		return w.stopWithError(err)

M wormhole/transit/transit.go => wormhole/transit/transit.go +57 -29
@@ 12,6 12,7 @@ import (
	"net"
	"strconv"
	"strings"
	"sync"

	"git.sr.ht/~apreiml/go-wormhole/wormhole"
	"golang.org/x/crypto/nacl/secretbox"


@@ 77,6 78,7 @@ type Transit struct {
	rw            *bufio.ReadWriter
	conn          io.ReadWriteCloser
	connSet       chan struct{}
	connMutex     sync.Mutex
	buf           []byte

	abilities map[string]bool


@@ 90,16 92,7 @@ func New(w *wormhole.Client, role string) *Transit {
	t := &Transit{role: role}
	w.DeriveKey(t.key[:], w.AppId()+"/transit-key")

	var otherRole string
	switch role {
	case RoleReceiver:
		otherRole = RoleSender
	case RoleSender:
		otherRole = RoleReceiver
	default:
		panic("invalid role: use transit.RoleSender or transit.RoleReceiver")
	}

	otherRole := getOtherRole(role)
	wormhole.DeriveSubKey(t.key[:], t.sendHandshake[:], "transit_"+role)
	wormhole.DeriveSubKey(t.key[:], t.recvHandshake[:], "transit_"+otherRole)
	wormhole.DeriveSubKey(t.key[:], t.sendKey[:], "transit_record_"+role+"_key")


@@ 113,6 106,21 @@ func New(w *wormhole.Client, role string) *Transit {
	return t
}

func getOtherRole(r string) string {
	switch r {
	case RoleReceiver:
		return RoleSender
	case RoleSender:
		return RoleReceiver
	default:
		panic("invalid role: use transit.RoleSender or transit.RoleReceiver")
	}
}

func (t Transit) Role() string {
	return t.role
}

// TODO and hints
func (t *Transit) TheirAbilities(c Config) {
	filtered := make(map[string]bool)


@@ 181,9 189,14 @@ func (t *Transit) listenDirectTcp() []Hint {

		if t.handshake(rw) {
			log.Print("Listen TCP: go")
			t.setConnection(conn, rw)
			if !t.setConnection(conn, rw) {
				log.Print("Listen TCP: nevermind")
				conn.Close()
				ln.Close()
			}
		} else {
			log.Print("Listen TCP: nevermind")
			log.Print("Listen TCP: handshake failed")
			conn.Close()
			ln.Close()
		}
	}()


@@ 199,12 212,25 @@ func (t *Transit) listenDirectTcp() []Hint {
}

func (t *Transit) setConnection(c io.ReadWriteCloser, rw *bufio.ReadWriter) bool {
	// TODO mutex
	t.connMutex.Lock()
	defer t.connMutex.Unlock()

	if t.conn != nil {
		if t.role == RoleSender {
			rw.WriteString("nevermind\n")
			rw.Flush()
		}
		return false
	}

	t.conn = c
	t.rw = rw

	if t.role == RoleSender {
		rw.WriteString("go\n")
		rw.Flush()
	}

	close(t.connSet)
	return true
}


@@ 273,7 299,11 @@ func (t *Transit) tryFindChannel() {
				}

				log.Printf("Direct conn established: %#v\n", h)
				t.setConnection(conn, rw)
				if !t.setConnection(conn, rw) {
					conn.Close()
					log.Print("Direct too late")
					return
				}
			}(h)
		}
	}


@@ 289,7 319,7 @@ func (t *Transit) handshake(rw *bufio.ReadWriter) bool {
	}

	tokens := strings.Split(msg, " ")
	if len(tokens) != 4 || tokens[0] != "transit" || tokens[1] != "sender" ||
	if len(tokens) != 4 || tokens[0] != "transit" || tokens[1] != getOtherRole(t.role) ||
		tokens[3] != "ready\n" {
		return false
	}


@@ 303,9 333,11 @@ func (t *Transit) handshake(rw *bufio.ReadWriter) bool {
		return false
	}

	msg, err = rw.ReadString('\n')
	if err != nil || msg != "go\n" {
		return false
	if t.role == RoleReceiver {
		msg, err = rw.ReadString('\n')
		if err != nil || msg != "go\n" {
			return false
		}
	}

	return true


@@ 381,23 413,19 @@ func (t *Transit) Write(p []byte) (n int, err error) {
	binary.BigEndian.PutUint32(size[:], uint32(len(encrypt)))

	t.rw.Write(size[:])
	t.rw.Write(encrypt)
	_, err = t.rw.Write(encrypt)
	// TODO err
	return len(p), nil
	return len(p), err
}

func (t *Transit) Close() error {
	fmt.Printf("LAST: %#v\n", t)
	connPending := true
	select {
	case _, connPending = <-t.connSet:
	default:
	}
func (t *Transit) Flush() error {
	return t.rw.Flush()
}

	if !connPending {
func (t *Transit) Close() error {
	if t.conn != nil {
		t.rw.Flush()
		t.conn.Close()
	}
	fmt.Println("transit close")
	return nil
}