~ashkeel/kilovolt-client-go

8c6f8f0882f5f68277e587698028e209557f4281 — Ash Keel 3 years ago 982fcb5 v1.1.1
Use sync map for solving race conditions
3 files changed, 26 insertions(+), 21 deletions(-)

M client.go
M go.mod
M go.sum
M client.go => client.go +23 -21
@@ 11,6 11,7 @@ import (

	"github.com/gorilla/websocket"
	jsoniter "github.com/json-iterator/go"
	cmap "github.com/orcaman/concurrent-map"
	"github.com/sirupsen/logrus"

	kv "github.com/strimertul/kilovolt/v3"


@@ 28,9 29,9 @@ type Client struct {

	headers       http.Header
	ws            *websocket.Conn
	mu            sync.Mutex // Used to avoid concurrent writes to socket
	requests      map[string]chan<- string
	subscriptions map[string][]chan<- string
	mu            sync.Mutex         // Used to avoid concurrent writes to socket
	requests      cmap.ConcurrentMap // map[string]chan<- string
	subscriptions cmap.ConcurrentMap // map[string][]chan<- string
}

type ClientOptions struct {


@@ 49,8 50,8 @@ func NewClient(endpoint string, options ClientOptions) (*Client, error) {
		headers:       options.Headers,
		ws:            nil,
		mu:            sync.Mutex{},
		requests:      make(map[string]chan<- string),
		subscriptions: make(map[string][]chan<- string),
		requests:      cmap.New(), // make(map[string]chan<- string),
		subscriptions: cmap.New(), // make(map[string][]chan<- string),
	}

	err := client.ConnectToWebsocket()


@@ 103,9 104,10 @@ func (s *Client) ConnectToWebsocket() error {
				// Check message
				if response.RequestID != "" {
					// We have a request ID, send byte chunk over to channel
					if chn, ok := s.requests[response.RequestID]; ok {
					if chn, ok := s.requests.Get(response.RequestID); ok {
						s.Logger.WithField("rid", response.RequestID).Trace("recv response")
						chn <- msg
						chn.(chan string) <- msg
						s.requests.Remove(response.RequestID)
					} else {
						s.Logger.WithField("rid", response.RequestID).Error("received response for unknown RID")
					}


@@ 121,15 123,12 @@ func (s *Client) ConnectToWebsocket() error {
							continue
						}
						// Deliver to subscriptions
						for sub, chans := range s.subscriptions {
							if push.Key != sub {
								continue
							}

							for _, chann := range chans {
						if subs, ok := s.subscriptions.Get(push.Key); ok {
							for _, chann := range subs.([]chan string) {
								chann <- push.NewValue
							}
						}

					}
				}
			}


@@ 202,10 201,11 @@ func (s *Client) SetJSON(key string, data interface{}) error {
func (s *Client) Subscribe(key string) (chan string, error) {
	chn := make(chan string)

	subs, ok := s.subscriptions[key]
	data, ok := s.subscriptions.Get(key)
	subs := data.([]chan string)

	needsAPISubscription := !ok || len(subs) < 1
	s.subscriptions[key] = append(subs, chn)
	s.subscriptions.Set(key, append(subs, chn))

	var err error
	// If this is the first time we subscribe to this key, ask server to push updates


@@ 222,14 222,16 @@ func (s *Client) Subscribe(key string) (chan string, error) {
}

func (s *Client) Unsubscribe(key string, chn chan string) error {
	if _, ok := s.subscriptions[key]; !ok {
	data, ok := s.subscriptions.Get(key)
	if !ok {
		return nil
	}
	chans := data.([]chan string)

	found := false
	for idx, sub := range s.subscriptions[key] {
	for idx, sub := range chans {
		if sub == chn {
			s.subscriptions[key] = append(s.subscriptions[key][:idx], s.subscriptions[key][idx+1:]...)
			s.subscriptions.Set(key, append(chans[:idx], chans[idx+1:]...))
			found = true
		}
	}


@@ 239,7 241,7 @@ func (s *Client) Unsubscribe(key string, chn chan string) error {
	}

	// If we removed all subscribers, ask server to not push updates to us anymore
	if len(s.subscriptions[key]) < 1 {
	if len(chans) < 1 {
		_, err := s.makeRequest(kv.Request{
			CmdName: kv.CmdUnsubscribeKey,
			Data: map[string]interface{}{


@@ 256,14 258,14 @@ func (s *Client) makeRequest(request kv.Request) (kv.Response, error) {
	rid := ""
	for {
		rid = fmt.Sprintf("%x", rand.Int63())
		if _, ok := s.requests[rid]; ok {
		if s.requests.Has(rid) {
			continue
		}
		break
	}

	responseChannel := make(chan string)
	s.requests[rid] = responseChannel
	s.requests.Set(rid, responseChannel)

	request.RequestID = rid
	err := s.send(request)

M go.mod => go.mod +1 -0
@@ 5,6 5,7 @@ go 1.16
require (
	github.com/gorilla/websocket v1.4.2
	github.com/json-iterator/go v1.1.11
	github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
	github.com/sirupsen/logrus v1.8.1
	github.com/strimertul/kilovolt/v3 v3.0.0
)

M go.sum => go.sum +2 -0
@@ 62,6 62,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ=
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/peterh/liner v0.0.0-20170317030525-88609521dc4b/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=