// Package client implements an RPC client that exchanges JSON messages with a
// server over a websocket connection.
package client

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sync"
	"time"

	"git.loafle.net/commons_go/logging"
	"git.loafle.net/overflow/overflow_probes/central/client/protocol"
	"github.com/gorilla/websocket"
)

const (
	// ProtocolName is the protocol identifier stamped on every outgoing
	// request and notification.
	ProtocolName = "RPC/1.0"

	// closeWriteWait bounds how long echoing a close frame to the peer may
	// block.
	closeWriteWait = time.Second
)

type (
	// OnNotifyFunc is invoked for each notification received for the method
	// it was registered with.
	OnNotifyFunc func(method string, params interface{})
	// OnCloseFunc is invoked when the peer sends a close frame.
	OnCloseFunc func(code int, text string)
)

// ServerError is an error reported by the remote side in a response.
type ServerError string

// Error implements the error interface.
func (e ServerError) Error() string {
	return string(e)
}

// ErrShutdown is returned for calls attempted or pending after the client has
// been shut down or its read loop has terminated.
var ErrShutdown = errors.New("connection is shut down")

// Call represents an active RPC.
type Call struct {
	Method string      // The name of the service and method to call.
	Args   interface{} // The argument to the function (*struct).
	Result interface{} // The reply from the function (*struct).
	Error  error       // After completion, the error status.
	Done   chan *Call  // Strobes when call is complete.
}

// done signals completion of the call on its Done channel without blocking.
func (c *Call) done() {
	select {
	case c.Done <- c:
		// ok
	default:
		// We don't want to block here. It is the caller's responsibility to make
		// sure the channel has enough buffer space. See comment in goCall().
		logging.Logger.Debug("Client: discarding Call reply due to insufficient Done chan capacity")
	}
}

// Client is an RPC client bound to a single websocket connection.
type Client interface {
	// Dial opens the websocket connection and starts the read loop.
	Dial(url string, header http.Header, readBufSize int, writeBufSize int) (*http.Response, error)
	// Call sends a request and blocks until its response arrives or the
	// connection fails.
	Call(method string, args interface{}, result interface{}) error
	// Notify sends a message for which no response is awaited.
	Notify(method string, args interface{}) error
	// OnNotify registers cb for notifications carrying the given method.
	OnNotify(method string, cb OnNotifyFunc)
	// OnClose registers cb to run when the peer sends a close frame.
	OnClose(cb OnCloseFunc)
	// Shutdown closes the connection.
	Shutdown(ctx context.Context) error
}

type client struct {
	conn *websocket.Conn

	// sendMutex serializes writes on conn and guards the two reusable
	// message buffers below.
	sendMutex    sync.Mutex
	request      protocol.Request
	notification protocol.Notification

	// mutex guards all of the following fields.
	mutex     sync.Mutex
	requestID uint64
	pending   map[uint64]*Call
	// closing is set once the user has called Shutdown.
	closing bool
	// shutdown is set once the read loop has terminated.
	shutdown         bool
	onNotifyHandlers map[string][]OnNotifyFunc
	onCloseHandlers  []OnCloseFunc
}

// New returns a Client that is not yet connected; call Dial before use.
func New() Client {
	// onCloseHandlers is deliberately left nil: pre-sizing it with a length
	// would seed it with nil funcs that panic when invoked.
	return &client{
		pending:          make(map[uint64]*Call),
		onNotifyHandlers: make(map[string][]OnNotifyFunc),
	}
}

// Dial connects to url using the given handshake header and buffer sizes,
// installs the close handler and starts the read loop in a new goroutine.
//
// The handshake response is returned even when dialing fails, so that the
// caller can inspect it; it may be nil in that case.
func (c *client) Dial(url string, header http.Header, readBufSize int, writeBufSize int) (*http.Response, error) {
	dialer := websocket.Dialer{
		ReadBufferSize:  readBufSize,
		WriteBufferSize: writeBufSize,
	}

	conn, res, err := dialer.Dial(url, header)
	if err != nil {
		return res, err
	}

	c.conn = conn
	c.conn.SetCloseHandler(c.connCloseHandler)
	go c.input()

	return res, nil
}

// Call invokes method with args, waits for completion and returns the call's
// error status. On success the response is unmarshalled into result.
func (c *client) Call(method string, args interface{}, result interface{}) error {
	call := <-c.goCall(method, args, result, make(chan *Call, 1)).Done
	return call.Error
}

// Notify sends a notification for method with args as its parameters. It
// returns the error of the underlying write, if any.
func (c *client) Notify(method string, args interface{}) error {
	c.sendMutex.Lock()
	defer c.sendMutex.Unlock()

	c.notification.Protocol = ProtocolName
	c.notification.Method = method
	c.notification.Params = args

	return c.conn.WriteJSON(c.notification)
}

// OnNotify registers cb to be called for every received notification whose
// method equals method. Multiple callbacks per method run in registration
// order.
func (c *client) OnNotify(method string, cb OnNotifyFunc) {
	// The read loop looks handlers up concurrently, hence the lock.
	c.mutex.Lock()
	defer c.mutex.Unlock()

	// append handles the missing-key case (nil slice); the result must be
	// stored back for the registration to take effect.
	c.onNotifyHandlers[method] = append(c.onNotifyHandlers[method], cb)
}

// OnClose registers cb to be called when the peer sends a close frame.
func (c *client) OnClose(cb OnCloseFunc) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	c.onCloseHandlers = append(c.onCloseHandlers, cb)
}

// Shutdown marks the client as closing and closes the connection. A second
// call returns ErrShutdown. ctx is currently unused.
func (c *client) Shutdown(ctx context.Context) error {
	c.mutex.Lock()
	if c.closing {
		c.mutex.Unlock()
		return ErrShutdown
	}
	c.closing = true
	c.mutex.Unlock()

	return c.conn.Close()
}

// goCall invokes the function asynchronously. It returns the Call structure
// representing the invocation. The done channel will signal when the call is
// complete by returning the same Call object. If done is nil, goCall will
// allocate a new channel. If non-nil, done must be buffered or goCall will
// deliberately crash.
func (c *client) goCall(method string, args interface{}, result interface{}, done chan *Call) *Call {
	call := &Call{
		Method: method,
		Args:   args,
		Result: result,
	}

	if done == nil {
		done = make(chan *Call, 10) // buffered.
	} else if cap(done) == 0 {
		// If caller passes done != nil, it must arrange that
		// done has enough buffer for the number of simultaneous
		// RPCs that will be using that channel. If the channel
		// is totally unbuffered, it's best not to run at all.
		logging.Logger.Panic("Client: done channel is unbuffered")
	}
	call.Done = done

	c.sendCall(call)
	return call
}

// sendCall registers call as pending under a fresh request ID and writes the
// request. If the client is shut down, or the write fails, the call is
// completed immediately with the corresponding error.
func (c *client) sendCall(call *Call) {
	c.sendMutex.Lock()
	defer c.sendMutex.Unlock()

	// Register this call.
	c.mutex.Lock()
	if c.shutdown || c.closing {
		call.Error = ErrShutdown
		c.mutex.Unlock()
		call.done()
		return
	}
	c.requestID++
	id := c.requestID
	c.pending[id] = call
	c.mutex.Unlock()

	// Encode and send the request.
	c.request.Protocol = ProtocolName
	c.request.Method = call.Method
	c.request.Params = call.Args
	c.request.ID = id

	if err := c.conn.WriteJSON(c.request); err != nil {
		// The read loop may already have completed the call; only finish it
		// here if it is still registered.
		c.mutex.Lock()
		call = c.pending[id]
		delete(c.pending, id)
		c.mutex.Unlock()

		if call != nil {
			call.Error = err
			call.done()
		}
	}
}

// input is the read loop. It reads one websocket message at a time and
// dispatches it as a notification or a response until a read, decode or
// handler error occurs, then fails every pending call with that error.
func (c *client) input() {
	var err error

	for err == nil {
		var (
			messageType int
			data        []byte
		)
		// Read the whole frame once so it can be decoded more than once.
		if messageType, data, err = c.conn.ReadMessage(); err != nil {
			break
		}
		logging.Logger.Debug(fmt.Sprintf("Client: messageType:%d", messageType))

		// NOTE(review): a message is treated as a notification when it
		// decodes with a non-empty Method; this assumes responses never
		// carry a method field - confirm against the protocol package.
		var noti protocol.Notification
		if nerr := json.Unmarshal(data, &noti); nerr == nil && noti.Method != "" {
			err = c.onNotification(noti)
			continue
		}

		var res protocol.Response
		if err = json.Unmarshal(data, &res); err != nil {
			break
		}
		err = c.onResponse(res)
	}

	// Terminate pending calls.
	c.sendMutex.Lock()
	c.mutex.Lock()
	c.shutdown = true
	closing := c.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range c.pending {
		call.Error = err
		call.done()
	}
	c.mutex.Unlock()
	c.sendMutex.Unlock()

	if err != io.EOF && !closing {
		logging.Logger.Debug(fmt.Sprintf("Client: client protocol error:%v", err))
	}
}

// onResponse completes the pending call matching res.ID. A response without a
// matching call is ignored. A non-nil return value (a failure to unmarshal
// the result) makes the read loop terminate.
func (c *client) onResponse(res protocol.Response) error {
	id := res.ID

	c.mutex.Lock()
	call := c.pending[id]
	delete(c.pending, id)
	c.mutex.Unlock()

	if call == nil {
		// Nobody is waiting for this ID, e.g. the request write failed and
		// the call was already removed.
		return nil
	}

	var err error
	switch {
	case res.Error != nil:
		// We've got an error response. Always surface it to the caller so
		// that a failed call can never look like a success.
		if res.Error.Message != nil {
			call.Error = ServerError(*res.Error.Message)
		} else {
			call.Error = ServerError(fmt.Sprintf("server error (code %v)", res.Error.Code))
		}
	case res.Result != nil && call.Result != nil:
		// Skipped when the server sent no result or the caller passed no
		// destination; either would otherwise fault.
		if err = json.Unmarshal(*res.Result, call.Result); err != nil {
			call.Error = errors.New("reading body " + err.Error())
		}
	}
	call.done()

	return err
}

// onNotification invokes, in registration order, every handler registered for
// the notification's method. It always returns nil.
func (c *client) onNotification(noti protocol.Notification) error {
	// Snapshot under the lock and call handlers outside it, so a handler may
	// itself register further handlers.
	c.mutex.Lock()
	hs := c.onNotifyHandlers[noti.Method]
	c.mutex.Unlock()

	for _, h := range hs {
		h(noti.Method, noti.Params)
	}

	return nil
}

// connCloseHandler is installed as the connection's close handler. It echoes
// the close frame back to the peer and then runs the registered OnClose
// callbacks.
func (c *client) connCloseHandler(code int, text string) error {
	// Installing a custom close handler replaces the library default, which
	// is what normally echoes the close frame; do it here instead. The write
	// is best effort, as the peer may already be gone.
	_ = c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(code, ""), time.Now().Add(closeWriteWait))

	c.mutex.Lock()
	hs := c.onCloseHandlers
	c.mutex.Unlock()

	for _, h := range hs {
		h(code, text)
	}

	return nil
}