194 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			194 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package servlet
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/rsa"
 | 
						|
	"fmt"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"git.loafle.net/commons/logging-go"
 | 
						|
	"git.loafle.net/commons/server-go"
 | 
						|
	"git.loafle.net/commons/server-go/socket"
 | 
						|
	og "git.loafle.net/overflow/gateway"
 | 
						|
	ogs "git.loafle.net/overflow/gateway/subscribe"
 | 
						|
	ogrs "git.loafle.net/overflow/gateway_rpc/servlet"
 | 
						|
	"git.loafle.net/overflow/member_gateway_rpc/subscribe"
 | 
						|
 | 
						|
	"github.com/dgrijalva/jwt-go"
 | 
						|
	"github.com/satori/go.uuid"
 | 
						|
	"github.com/valyala/fasthttp"
 | 
						|
)
 | 
						|
 | 
						|
type WebappServlet interface {
 | 
						|
	ogrs.RPCServlet
 | 
						|
}
 | 
						|
 | 
						|
type WebappServlets struct {
 | 
						|
	ogrs.RPCServlets
 | 
						|
 | 
						|
	VerifyKey *rsa.PublicKey
 | 
						|
	SignKey   *rsa.PrivateKey
 | 
						|
 | 
						|
	sessions sync.Map
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) Init(serverCtx server.ServerCtx) error {
 | 
						|
	if err := s.RPCServlets.Init(serverCtx); nil != err {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) OnStart(serverCtx server.ServerCtx) error {
 | 
						|
	if err := s.RPCServlets.OnStart(serverCtx); nil != err {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	subscribeChan, err := subscribe.Subscriber.Subscribe("/webapp")
 | 
						|
	if nil != err {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	go s.handleSubscribe(serverCtx, subscribeChan)
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) OnStop(serverCtx server.ServerCtx) {
 | 
						|
	if err := subscribe.Subscriber.Unsubscribe("/webapp"); nil != err {
 | 
						|
		logging.Logger().Warn(err)
 | 
						|
	}
 | 
						|
 | 
						|
	s.RPCServlets.OnStop(serverCtx)
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) Destroy(serverCtx server.ServerCtx) {
 | 
						|
 | 
						|
	s.RPCServlets.Destroy(serverCtx)
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) {
 | 
						|
	var ok bool
 | 
						|
 | 
						|
	tokenString := string(ctx.QueryArgs().Peek("authToken"))
 | 
						|
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 | 
						|
		// Don't forget to validate the alg is what you expect:
 | 
						|
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
 | 
						|
			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
 | 
						|
		}
 | 
						|
		// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
 | 
						|
		return s.VerifyKey, nil
 | 
						|
	})
 | 
						|
 | 
						|
	if nil != err {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	var claims jwt.MapClaims
 | 
						|
	if claims, ok = token.Claims.(jwt.MapClaims); !ok || !token.Valid {
 | 
						|
		return nil, fmt.Errorf("Token is not valid %v", token)
 | 
						|
	}
 | 
						|
 | 
						|
	userEmail := claims["sub"].(string)
 | 
						|
	sessionID := uuid.NewV4().String()
 | 
						|
 | 
						|
	servletCtx.SetAttribute(og.SessionIDKey, sessionID)
 | 
						|
	servletCtx.SetAttribute(og.SessionClientTypeKey, og.MEMBER)
 | 
						|
	servletCtx.SetAttribute(og.SessionTargetIDKey, userEmail)
 | 
						|
 | 
						|
	return nil, nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) {
 | 
						|
	s.RPCServlets.OnConnect(servletCtx, conn)
 | 
						|
 | 
						|
	sessionID := servletCtx.GetAttribute(og.SessionIDKey)
 | 
						|
	targetID := servletCtx.GetAttribute(og.SessionTargetIDKey)
 | 
						|
	if nil != sessionID && nil != targetID {
 | 
						|
		s.sessions.Store(sessionID.(string), ogrs.RetainSession(targetID.(string), servletCtx))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) OnDisconnect(servletCtx server.ServletCtx) {
 | 
						|
	s.RPCServlets.OnDisconnect(servletCtx)
 | 
						|
 | 
						|
	sessionID := servletCtx.GetAttribute(og.SessionIDKey)
 | 
						|
	if nil != sessionID {
 | 
						|
		s.sessions.Delete(sessionID.(string))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) handleSubscribe(serverCtx server.ServerCtx, subscribeChan <-chan *ogs.Message) {
 | 
						|
	var sessions []*ogrs.Session
 | 
						|
 | 
						|
LOOP:
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case msg, ok := <-subscribeChan:
 | 
						|
			if !ok {
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			switch msg.TargetType {
 | 
						|
			case ogs.MEMBER:
 | 
						|
				sessions = s.getMemberSessionsByTargetIDs(msg.Targets)
 | 
						|
			case ogs.MEMBER_SESSION:
 | 
						|
				sessions = s.getMemberSessions(msg.Targets)
 | 
						|
			default:
 | 
						|
				logging.Logger().Warnf("Subscriber: Unknown TargetType %s", msg.TargetType)
 | 
						|
				continue LOOP
 | 
						|
			}
 | 
						|
 | 
						|
			if nil == sessions || 0 == len(sessions) {
 | 
						|
				continue LOOP
 | 
						|
			}
 | 
						|
 | 
						|
			for _, session := range sessions {
 | 
						|
				_writeChan := session.ServletCtx.GetAttribute(og.SessionWriteChanKey)
 | 
						|
				if nil != _writeChan {
 | 
						|
					writeChan := _writeChan.(chan<- []byte)
 | 
						|
					writeChan <- *msg.Message
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) getMemberSessions(sessionIDs []string) []*ogrs.Session {
 | 
						|
	var sessions []*ogrs.Session
 | 
						|
 | 
						|
	if nil == sessionIDs || 0 == len(sessionIDs) {
 | 
						|
		return sessions
 | 
						|
	}
 | 
						|
 | 
						|
	for _, sessionID := range sessionIDs {
 | 
						|
		session, ok := s.sessions.Load(sessionID)
 | 
						|
		if ok {
 | 
						|
			sessions = append(sessions, session.(*ogrs.Session))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return sessions
 | 
						|
}
 | 
						|
 | 
						|
func (s *WebappServlets) getMemberSessionsByTargetIDs(targetIDs []string) []*ogrs.Session {
 | 
						|
	var sessions []*ogrs.Session
 | 
						|
	if nil == targetIDs || 0 == len(targetIDs) {
 | 
						|
		return sessions
 | 
						|
	}
 | 
						|
 | 
						|
	s.sessions.Range(func(k, v interface{}) bool {
 | 
						|
		session := v.(*ogrs.Session)
 | 
						|
 | 
						|
		for _, targetID := range targetIDs {
 | 
						|
			if session.TargetID == targetID {
 | 
						|
				sessions = append(sessions, session)
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		return true
 | 
						|
	})
 | 
						|
 | 
						|
	return sessions
 | 
						|
}
 |