@@ 11,6 11,7 @@ import (
"github.com/gorilla/websocket"
jsoniter "github.com/json-iterator/go"
+ cmap "github.com/orcaman/concurrent-map"
"github.com/sirupsen/logrus"
kv "github.com/strimertul/kilovolt/v3"
@@ 28,9 29,9 @@ type Client struct {
headers http.Header
ws *websocket.Conn
- mu sync.Mutex // Used to avoid concurrent writes to socket
- requests map[string]chan<- string
- subscriptions map[string][]chan<- string
+ mu sync.Mutex // Used to avoid concurrent writes to socket
+ requests cmap.ConcurrentMap // map[string]chan<- string
+ subscriptions cmap.ConcurrentMap // map[string][]chan<- string
}
type ClientOptions struct {
@@ 49,8 50,8 @@ func NewClient(endpoint string, options ClientOptions) (*Client, error) {
headers: options.Headers,
ws: nil,
mu: sync.Mutex{},
- requests: make(map[string]chan<- string),
- subscriptions: make(map[string][]chan<- string),
+ requests: cmap.New(), // make(map[string]chan<- string),
+ subscriptions: cmap.New(), // make(map[string][]chan<- string),
}
err := client.ConnectToWebsocket()
@@ 103,9 104,10 @@ func (s *Client) ConnectToWebsocket() error {
// Check message
if response.RequestID != "" {
// We have a request ID, send byte chunk over to channel
- if chn, ok := s.requests[response.RequestID]; ok {
+ if chn, ok := s.requests.Get(response.RequestID); ok {
s.Logger.WithField("rid", response.RequestID).Trace("recv response")
- chn <- msg
+ chn.(chan string) <- msg
+ s.requests.Remove(response.RequestID)
} else {
s.Logger.WithField("rid", response.RequestID).Error("received response for unknown RID")
}
@@ 121,15 123,12 @@ func (s *Client) ConnectToWebsocket() error {
continue
}
// Deliver to subscriptions
- for sub, chans := range s.subscriptions {
- if push.Key != sub {
- continue
- }
-
- for _, chann := range chans {
+ if subs, ok := s.subscriptions.Get(push.Key); ok {
+ for _, chann := range subs.([]chan string) {
chann <- push.NewValue
}
}
+
}
}
}
@@ 202,10 201,11 @@ func (s *Client) SetJSON(key string, data interface{}) error {
func (s *Client) Subscribe(key string) (chan string, error) {
chn := make(chan string)
- subs, ok := s.subscriptions[key]
+ data, ok := s.subscriptions.Get(key)
+ subs := data.([]chan string)
needsAPISubscription := !ok || len(subs) < 1
- s.subscriptions[key] = append(subs, chn)
+ s.subscriptions.Set(key, append(subs, chn))
var err error
// If this is the first time we subscribe to this key, ask server to push updates
@@ 222,14 222,16 @@ func (s *Client) Subscribe(key string) (chan string, error) {
}
func (s *Client) Unsubscribe(key string, chn chan string) error {
- if _, ok := s.subscriptions[key]; !ok {
+ data, ok := s.subscriptions.Get(key)
+ if !ok {
return nil
}
+ chans := data.([]chan string)
found := false
- for idx, sub := range s.subscriptions[key] {
+ for idx, sub := range chans {
if sub == chn {
- s.subscriptions[key] = append(s.subscriptions[key][:idx], s.subscriptions[key][idx+1:]...)
+ s.subscriptions.Set(key, append(chans[:idx], chans[idx+1:]...))
found = true
}
}
@@ 239,7 241,7 @@ func (s *Client) Unsubscribe(key string, chn chan string) error {
}
// If we removed all subscribers, ask server to not push updates to us anymore
- if len(s.subscriptions[key]) < 1 {
+ if len(chans) < 1 {
_, err := s.makeRequest(kv.Request{
CmdName: kv.CmdUnsubscribeKey,
Data: map[string]interface{}{
@@ 256,14 258,14 @@ func (s *Client) makeRequest(request kv.Request) (kv.Response, error) {
rid := ""
for {
rid = fmt.Sprintf("%x", rand.Int63())
- if _, ok := s.requests[rid]; ok {
+ if s.requests.Has(rid) {
continue
}
break
}
responseChannel := make(chan string)
- s.requests[rid] = responseChannel
+ s.requests.Set(rid, responseChannel)
request.RequestID = rid
err := s.send(request)