diff --git a/redis/subscribers.go b/redis/subscribers.go index 953bb01..2a30c9d 100644 --- a/redis/subscribers.go +++ b/redis/subscribers.go @@ -1,85 +1,133 @@ package redis import ( - "context" "encoding/json" "fmt" + "sync" "git.loafle.net/commons_go/logging" - uch "git.loafle.net/commons_go/util/channel" - ofs "git.loafle.net/overflow/overflow_subscriber" + cuc "git.loafle.net/commons_go/util/channel" + oos "git.loafle.net/overflow/overflow_subscriber" "github.com/garyburd/redigo/redis" ) type channelAction struct { - uch.Action - h ofs.SubscriberHandler + cuc.Action + h oos.SubscriberHandler } type subscribers struct { - ctx context.Context conn redis.PubSubConn - subHandlers map[string]ofs.SubscriberHandler + subHandlers map[string]oos.SubscriberHandler isListen bool subCh chan channelAction - isRunning bool + + stopChan chan struct{} + stopWg sync.WaitGroup } -func New(ctx context.Context, conn redis.Conn) ofs.Subscriber { +func New(conn redis.Conn) oos.Subscriber { s := &subscribers{ - ctx: ctx, - subHandlers: make(map[string]ofs.SubscriberHandler), - isListen: false, - subCh: make(chan channelAction), + isListen: false, } s.conn = redis.PubSubConn{Conn: conn} - go s.listen() - - s.isRunning = true - return s } -func (s *subscribers) listen() { +func (s *subscribers) Start() error { + if s.stopChan != nil { + panic("Redis Subscriber: subscriber is already running. Stop it before starting it again") + } + + s.stopChan = make(chan struct{}) + s.subHandlers = make(map[string]oos.SubscriberHandler) + s.subCh = make(chan channelAction) + + s.stopWg.Add(1) + go handleSubscriber(s) + + return nil +} + +func (s *subscribers) Stop() { + if s.stopChan == nil { + panic("Redis Subscriber: subscriber must be started before stopping it") + } + close(s.stopChan) + s.stopWg.Wait() + s.stopChan = nil +} + +func (s *subscribers) Subscribe(h oos.SubscriberHandler) error { + if _, ok := s.subHandlers[h.GetChannel()]; ok { + return oos.ChannelExistError{Channel: h.GetChannel()} + } + + ca := channelAction{ + h: h, + } + ca.Type = cuc.ActionTypeCreate + + s.subCh <- ca + + return nil +} + +func (s *subscribers) Unsubscribe(h oos.SubscriberHandler) error { + if _, ok := s.subHandlers[h.GetChannel()]; !ok { + return oos.ChannelIsNotExistError{Channel: h.GetChannel()} + } + + ca := channelAction{ + h: h, + } + ca.Type = cuc.ActionTypeDelete + + s.subCh <- ca + + return nil +} + +func handleSubscriber(s *subscribers) { + defer s.stopWg.Done() + for { select { case ca := <-s.subCh: switch ca.Type { - case uch.ActionTypeCreate: + case cuc.ActionTypeCreate: s.subHandlers[ca.h.GetChannel()] = ca.h s.conn.Subscribe(ca.h.GetChannel()) - s.listenSubscriptions() + listenSubscriptions(s) break - case uch.ActionTypeDelete: + case cuc.ActionTypeDelete: s.conn.Unsubscribe(ca.h.GetChannel()) delete(s.subHandlers, ca.h.GetChannel()) break } - case <-s.ctx.Done(): - s.destroy() + case <-s.stopChan: + s.conn.Close() return } } } -func (s *subscribers) destroy() { - s.isRunning = false - s.conn.Close() -} - -func (s *subscribers) listenSubscriptions() { +func listenSubscriptions(s *subscribers) { if s.isListen { return } + s.stopWg.Add(1) go func() { + defer s.stopWg.Done() + for { switch v := s.conn.Receive().(type) { case redis.Message: if h, ok := s.subHandlers[v.Channel]; ok { - if message, err := s.unmarshalMessage(v.Data); nil != err { - logging.Logger.Error(fmt.Sprintf("Subscriber Unmarshal error:%v", err)) + if message, err := unmarshalMessage(v.Data); nil != err { + logging.Logger().Error(fmt.Sprintf("Subscriber Unmarshal error:%v", err)) break } else { h.OnSubscribe(v.Channel, message) @@ -89,7 +137,7 @@ func (s *subscribers) listenSubscriptions() { case redis.Subscription: break case error: - s.destroy() + s.Stop() return default: } @@ -99,9 +147,9 @@ func (s *subscribers) listenSubscriptions() { s.isListen = true } -func (s *subscribers) unmarshalMessage(data []byte) (ofs.SubscribeMessage, error) { +func unmarshalMessage(data []byte) (oos.SubscribeMessage, error) { var err error - var message ofs.SubscribeMessage + var message oos.SubscribeMessage if err = json.Unmarshal(data, &message); nil != err { return message, err } @@ -112,33 +160,3 @@ func (s *subscribers) unmarshalMessage(data []byte) (ofs.SubscribeMessage, error return message, nil } - -func (s *subscribers) Subscribe(h ofs.SubscriberHandler) error { - if _, ok := s.subHandlers[h.GetChannel()]; ok { - return ofs.ChannelExistError{Channel: h.GetChannel()} - } - - ca := channelAction{ - h: h, - } - ca.Type = uch.ActionTypeCreate - - s.subCh <- ca - - return nil -} - -func (s *subscribers) Unsubscribe(h ofs.SubscriberHandler) error { - if _, ok := s.subHandlers[h.GetChannel()]; !ok { - return ofs.ChannelIsNotExistError{Channel: h.GetChannel()} - } - - ca := channelAction{ - h: h, - } - ca.Type = uch.ActionTypeDelete - - s.subCh <- ca - - return nil -} diff --git a/subscriber.go b/subscriber.go index a62ceab..f539f72 100644 --- a/subscriber.go +++ b/subscriber.go @@ -21,6 +21,9 @@ func (cinee ChannelIsNotExistError) Error() string { } type Subscriber interface { + Start() error + Stop() + Subscribe(h SubscriberHandler) error Unsubscribe(h SubscriberHandler) error } diff --git a/subscriber_handler.go b/subscriber_handler.go index b4cb9ed..bc6b5a8 100644 --- a/subscriber_handler.go +++ b/subscriber_handler.go @@ -3,4 +3,6 @@ package overflow_subscriber type SubscriberHandler interface { GetChannel() string OnSubscribe(channel string, message SubscribeMessage) + + Validate() } diff --git a/subscriber_handlers.go b/subscriber_handlers.go index 540bc24..cc54c65 100644 --- a/subscriber_handlers.go +++ b/subscriber_handlers.go @@ -4,8 +4,11 @@ type SubscriberHandlers struct { Channel string } -func (h *SubscriberHandlers) GetChannel() string { - return h.Channel +func (sh *SubscriberHandlers) GetChannel() string { + return sh.Channel } -func (h *SubscriberHandlers) OnSubscribe(channel string, message SubscribeMessage) { +func (sh *SubscriberHandlers) OnSubscribe(channel string, message SubscribeMessage) { +} + +func (sh *SubscriberHandlers) Validate() { }