From 51cb76d86112c32388feb34b32ddcb964dc22e30 Mon Sep 17 00:00:00 2001 From: crusader Date: Fri, 25 Aug 2017 18:19:53 +0900 Subject: [PATCH] ing --- socket.go | 64 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/socket.go b/socket.go index b8cdd40..5d0443e 100644 --- a/socket.go +++ b/socket.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "log" - "sync" "time" "git.loafle.net/overflow/overflow_gateway_websocket/websocket" @@ -18,21 +17,24 @@ type Socket interface { } type socket struct { - id string - o *SocketOptions - conn *websocket.Conn - path string - messageType int - writeMTX sync.Mutex + id string + o *SocketOptions + conn *websocket.Conn + path string + messageType int + writeCh chan []byte + disconnectCh chan bool } func NewSocket(id string, path string, o *SocketOptions, conn *websocket.Conn) Socket { c := &socket{ - id: id, - o: o, - conn: conn, - path: path, - messageType: websocket.TextMessage, + id: id, + o: o, + conn: conn, + path: path, + writeCh: make(chan []byte), + disconnectCh: make(chan bool), + messageType: websocket.TextMessage, } return c @@ -55,9 +57,11 @@ func (soc *socket) run() { soc.conn.SetReadLimit(soc.o.MaxMessageSize) defer func() { - soc.o.onDisconnected(soc) + soc.onDisconnected() }() + go soc.listenWrite() + for { if hasReadTimeout { soc.conn.SetReadDeadline(time.Now().Add(soc.o.ReadTimeout)) @@ -77,23 +81,37 @@ func (soc *socket) run() { } } +func (soc *socket) onDisconnected() { + soc.disconnectCh <- true + soc.o.onDisconnected(soc) +} + func (soc *socket) onMessage(messageType int, r io.Reader) { result := soc.o.Handler.OnMessage(soc, messageType, r) if nil == result { return } + soc.writeCh <- result +} - soc.writeMTX.Lock() - if writeTimeout := soc.o.WriteTimeout; writeTimeout > 0 { - err := soc.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - log.Println(fmt.Errorf("%v", err)) - return - } +func (soc *socket) listenWrite() { + for { + select { + // send message to the client + case w := <-soc.writeCh: + if writeTimeout := soc.o.WriteTimeout; writeTimeout > 0 { + err := soc.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + log.Println(fmt.Errorf("%v", err)) + } - err := soc.conn.WriteMessage(soc.messageType, result) + err := soc.conn.WriteMessage(soc.messageType, w) + if nil != err { + log.Println(fmt.Errorf("%v", err)) + } - soc.writeMTX.Unlock() - - if nil != err { + // receive done request + case <-soc.disconnectCh: + return + } } }