138 lines
2.6 KiB
Go
138 lines
2.6 KiB
Go
package websocket
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"sync"
|
|
|
|
protocol "git.loafle.net/overflow/overflow_probe/websocket/protocol"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// protocolName is the protocol version string stamped on every outgoing
// request built by Call and Send.
const protocolName string = "1.0"
|
|
|
|
// CallCallback receives the outcome of a Call: the result field and the
// protocol error field of the response whose ID matches the request.
type CallCallback func(result interface{}, rpcErr *protocol.Error)
|
|
|
|
type WebSocketRPC interface {
|
|
Call(cb CallCallback, method string, params []string)
|
|
Send(method string, params []string)
|
|
}
|
|
|
|
type webSocketRPC struct {
|
|
conn *websocket.Conn
|
|
messageType int
|
|
writeMTX sync.Mutex
|
|
requestID int64
|
|
requestQueue map[int64]CallCallback
|
|
}
|
|
|
|
// New creates a websocket rpc client and returns it
|
|
func New(serverURL string) WebSocketRPC {
|
|
return newInstance(serverURL)
|
|
}
|
|
|
|
func newInstance(serverURL string) WebSocketRPC {
|
|
var dialer *websocket.Dialer
|
|
|
|
conn, _, err := dialer.Dial(serverURL, nil)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
w := &webSocketRPC{
|
|
conn: conn,
|
|
requestID: 0,
|
|
requestQueue: make(map[int64]CallCallback),
|
|
}
|
|
|
|
return w
|
|
}
|
|
|
|
// readHandler reads messages from conn and hands each one to
// onMessageReceived until a read fails.
//
// Unexpected close errors (any close code other than CloseGoingAway) are
// logged. Any other read error ends the loop without logging.
//
// NOTE(review): nothing in this file starts readHandler, so responses are
// never dispatched unless something runs it (e.g. `go w.readHandler()`).
// Confirm where it is meant to be started.
func (w *webSocketRPC) readHandler() {
	for {
		messageType, r, err := w.conn.NextReader()
		if err != nil {
			// gorilla read errors are permanent, so the loop must stop here.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
				log.Printf("websocket: read: %v", err)
			}
			return
		}
		w.onMessageReceived(messageType, r)
	}
}
|
|
|
|
func (w *webSocketRPC) onMessageReceived(messageType int, r io.Reader) {
|
|
res := new(protocol.Response)
|
|
err := json.NewDecoder(r).Decode(res)
|
|
if err != nil {
|
|
// Check message is Notification
|
|
noti := new(protocol.Notification)
|
|
err = json.NewDecoder(r).Decode(noti)
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
w.onNotificationReceived(noti)
|
|
}
|
|
|
|
w.onResponseReceived(res)
|
|
}
|
|
|
|
// onResponseReceived looks up, removes, and invokes the callback registered
// for r.ID.
//
// A response whose ID has no pending entry (unknown, duplicate, or already
// handled) is logged and dropped instead of panicking on a nil callback.
func (w *webSocketRPC) onResponseReceived(r *protocol.Response) {
	// requestQueue is shared with Call. Guard it with writeMTX, the mutex
	// Call holds while issuing requests.
	// NOTE(review): a dedicated mutex for the map would avoid waiting on
	// in-flight writes.
	w.writeMTX.Lock()
	cb, ok := w.requestQueue[r.ID]
	delete(w.requestQueue, r.ID)
	w.writeMTX.Unlock()

	if !ok {
		log.Printf("websocket: no pending call for response id %d", r.ID)
		return
	}
	// Invoke outside the lock so the callback may itself issue Calls. A nil
	// callback means the caller did not want the result.
	if cb != nil {
		cb(r.Result, r.Error)
	}
}
|
|
|
|
func (w *webSocketRPC) onNotificationReceived(n *protocol.Notification) {
|
|
}
|
|
|
|
// Call sends method/params as a request with a fresh ID and registers cb
// under that ID, for the response path to invoke.
//
// The callback is registered before the request is written, while writeMTX
// is held, so a fast response cannot arrive ahead of its registration. If
// the client has no connection, or the request cannot be encoded or written,
// the failure is logged. In that case cb is not registered (or is removed
// again) and is never invoked.
func (w *webSocketRPC) Call(cb CallCallback, method string, params []string) {
	w.writeMTX.Lock()
	defer w.writeMTX.Unlock()

	if w.conn == nil {
		log.Printf("websocket: call %s: not connected", method)
		return
	}

	w.requestID++
	rID := w.requestID

	req := new(protocol.Request)
	req.Protocol = protocolName
	req.Method = method
	req.Params = params
	req.ID = rID

	jReq, err := json.Marshal(req)
	if err != nil {
		log.Println(fmt.Errorf("websocket: encode %s request %d: %v", method, rID, err))
		return
	}

	w.requestQueue[rID] = cb
	if err = w.conn.WriteMessage(w.messageType, jReq); err != nil {
		// The request never left, so no response will arrive for this ID.
		delete(w.requestQueue, rID)
		log.Println(fmt.Errorf("websocket: write %s request %d: %v", method, rID, err))
	}
}
|
|
|
|
// Send writes method/params as a request without assigning a request ID or
// registering a callback.
//
// If the client has no connection, or the request cannot be encoded or
// written, the failure is logged and nothing is sent.
func (w *webSocketRPC) Send(method string, params []string) {
	w.writeMTX.Lock()
	defer w.writeMTX.Unlock()

	if w.conn == nil {
		log.Printf("websocket: send %s: not connected", method)
		return
	}

	req := new(protocol.Request)
	req.Protocol = protocolName
	req.Method = method
	req.Params = params

	jReq, err := json.Marshal(req)
	if err != nil {
		// Writing the nil buffer would send an empty frame; give up instead.
		log.Println(fmt.Errorf("websocket: encode %s request: %v", method, err))
		return
	}
	if err = w.conn.WriteMessage(w.messageType, jReq); err != nil {
		log.Println(fmt.Errorf("websocket: write %s request: %v", method, err))
	}
}
|