commit e78f2357407a58e94460caf8bac5c0d44e9f41b1 Author: crusader Date: Wed Aug 22 17:37:12 2018 +0900 project initialized diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3733e36 --- /dev/null +++ b/.gitignore @@ -0,0 +1,68 @@ +# Created by .ignore support plugin (hsz.mobi) +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties +### Go template +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ +.idea/ +*.iml + +vendor/ +glide.lock +.DS_Store +dist/ +debug diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..2ca2b1d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,32 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug", + "type": "go", + "request": "launch", + "mode": "debug", + "remotePath": "", + "port": 2345, + "host": "127.0.0.1", + "program": "${workspaceRoot}/main.go", + "env": {}, + "args": [], + "showLog": true + }, + { + "name": "File Debug", + "type": "go", + "request": "launch", + "mode": "debug", + "remotePath": "", + "port": 2345, + "host": "127.0.0.1", + "program": "${fileDirname}", + "env": {}, + "args": [], + "showLog": true + } + + ] +} \ No newline at end of file diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000..487018d --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,96 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + branch = "master" + name = "git.loafle.net/overflow/config-go" + packages = ["."] + revision = "c4066fa55db336afc618d64156f703f59138f791" + +[[projects]] + branch = "master" + name = "git.loafle.net/overflow/log-go" + packages = ["."] + revision = "93f396a4335da70aac4dfa146bebba323424b157" + +[[projects]] + branch = "master" + name = "git.loafle.net/overflow/util-go" + packages = ["ctx"] + revision = "c6e5aef1a08d21290038e1dfd4d532d1770fce04" + +[[projects]] + name = "github.com/BurntSushi/toml" + packages = ["."] + revision = "b26d9c308763d68093482582cea63d69be07a0f0" + version = "v0.3.0" + +[[projects]] + name = "github.com/klauspost/compress" + packages = [ + "flate", + "gzip", + "zlib" + ] + revision = "b939724e787a27c0005cabe3f78e7ed7987ac74f" + version = "v1.4.0" + +[[projects]] + name = "github.com/klauspost/cpuid" + packages = ["."] + revision = "ae7887de9fa5d2db4eaa8174a7eff2c1ac00f2da" + version = "v1.1" + +[[projects]] + branch = "master" + name = "github.com/valyala/bytebufferpool" + packages = ["."] + revision = "e746df99fe4a3986f4d4f79e13c1e0117ce9c2f7" + +[[projects]] + name = "github.com/valyala/fasthttp" + packages = [ + ".", + "fasthttputil", + "stackless" + ] + revision = "e5f51c11919d4f66400334047b897ef0a94c6f3c" + version = "v20180529" + +[[projects]] + name = "go.uber.org/atomic" + packages = ["."] + revision = "1ea20fb1cbb1cc08cbd0d913a96dead89aa18289" + version = "v1.3.2" + +[[projects]] + name = "go.uber.org/multierr" + packages = ["."] + revision = "3c4937480c32f4c13a875a1829af76c98ca3d40a" + version = "v1.1.0" + +[[projects]] + name = "go.uber.org/zap" + packages = [ + ".", + "buffer", + "internal/bufferpool", + "internal/color", + "internal/exit", + "zapcore" + ] + revision = "ff33455a0e382e8a81d14dd7c922020b6b5e7982" + version = "v1.9.1" + +[[projects]] + name = "gopkg.in/yaml.v2" + packages = ["."] + revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183" + version = "v2.2.1" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "5a46d6beb0750003f59e9f3c34fcaedb4ae836d2a8a46d111ea62059d4d08480" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000..3f016d2 --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,42 @@ +# Gopkg.toml example +# +# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + branch = "master" + name = "git.loafle.net/overflow/log-go" + +[[constraint]] + branch = "master" + name = "git.loafle.net/overflow/util-go" + +[[constraint]] + name = "github.com/valyala/fasthttp" + version = "20180529.0.0" + +[prune] + go-tests = true + unused-packages = true diff --git a/connection-handler.go b/connection-handler.go new file mode 100644 index 0000000..fd99c5a --- /dev/null +++ b/connection-handler.go @@ -0,0 +1,92 @@ +package server + +import ( + "crypto/tls" + "net" + "sync/atomic" + "time" +) + +type ConnectionHandler interface { + GetConcurrency() int + GetKeepAlive() time.Duration + GetHandshakeTimeout() time.Duration + GetTLSConfig() *tls.Config + + Listener(serverCtx ServerCtx) (net.Listener, error) +} + +type ConnectionHandlers struct { + // The maximum number of concurrent connections the server may serve. + // + // DefaultConcurrency is used if not set. + Network string `json:"network,omitempty"` + Address string `json:"address,omitempty"` + Concurrency int `json:"concurrency,omitempty"` + KeepAlive time.Duration `json:"keepAlive,omitempty"` + HandshakeTimeout time.Duration `json:"handshakeTimeout,omitempty"` + TLSConfig *tls.Config `json:"-"` + validated atomic.Value +} + +func (ch *ConnectionHandlers) Listener(serverCtx ServerCtx) (net.Listener, error) { + l, err := net.Listen(ch.Network, ch.Address) + if nil != err { + return nil, err + } + + return l, nil +} + +func (ch *ConnectionHandlers) GetConcurrency() int { + return ch.Concurrency +} + +func (ch *ConnectionHandlers) GetKeepAlive() time.Duration { + return ch.KeepAlive +} + +func (ch *ConnectionHandlers) GetHandshakeTimeout() time.Duration { + return ch.HandshakeTimeout +} + +func (ch *ConnectionHandlers) GetTLSConfig() *tls.Config { + return ch.TLSConfig +} + +func (ch *ConnectionHandlers) Clone() *ConnectionHandlers { + return &ConnectionHandlers{ + Network: ch.Network, + Address: ch.Address, + Concurrency: ch.Concurrency, + KeepAlive: ch.KeepAlive, + HandshakeTimeout: ch.HandshakeTimeout, + TLSConfig: ch.TLSConfig, + validated: ch.validated, + } +} + +func (ch *ConnectionHandlers) Validate() error { + if nil != ch.validated.Load() { + return nil + } + ch.validated.Store(true) + + if ch.Concurrency <= 0 { + ch.Concurrency = DefaultConcurrency + } + + if ch.KeepAlive <= 0 { + ch.KeepAlive = DefaultKeepAlive + } else { + ch.KeepAlive = ch.KeepAlive * time.Second + } + + if ch.HandshakeTimeout <= 0 { + ch.HandshakeTimeout = DefaultHandshakeTimeout + } else { + ch.HandshakeTimeout = ch.HandshakeTimeout * time.Second + } + + return nil +} diff --git a/const.go b/const.go new file mode 100644 index 0000000..6613aab --- /dev/null +++ b/const.go @@ -0,0 +1,38 @@ +package server + +import "time" + +const ( + // DefaultConcurrency is the maximum number of concurrent connections + // the Server may serve by default (i.e. if Server.Concurrency isn't set). + DefaultConcurrency = 256 * 1024 + + DefaultKeepAlive = 0 + + // DefaultHandshakeTimeout is default value of websocket handshake Timeout + DefaultHandshakeTimeout = 0 + + // DefaultReadBufferSize is default value of Read Buffer Size + DefaultReadBufferSize = 0 + // DefaultWriteBufferSize is default value of Write Buffer Size + DefaultWriteBufferSize = 0 + // DefaultReadTimeout is default value of read timeout + DefaultReadTimeout = 0 + // DefaultWriteTimeout is default value of write timeout + DefaultWriteTimeout = 0 + // DefaultEnableCompression is default value of support compression + DefaultEnableCompression = false + // DefaultMaxMessageSize is default size for a message read from the peer + DefaultMaxMessageSize = 4096 + // DefaultPongTimeout is default value of websocket pong Timeout + DefaultPongTimeout = 60 * time.Second + // DefaultPingTimeout is default value of websocket ping Timeout + DefaultPingTimeout = 10 * time.Second + // DefaultPingPeriod is default value of send ping period + DefaultPingPeriod = (DefaultPingTimeout * 9) / 10 + + DefaultReconnectInterval = 5 * time.Second + DefaultReconnectTryTime = 10 + + DefaultCompressionThreshold = 1024 +) diff --git a/internal/grace/net/net.go b/internal/grace/net/net.go new file mode 100644 index 0000000..3670a34 --- /dev/null +++ b/internal/grace/net/net.go @@ -0,0 +1,244 @@ +package net + +import ( + "fmt" + "net" + "os" + "os/exec" + "strconv" + "strings" + "sync" +) + +const ( + // Used to indicate a graceful restart in the new process. + envCountKey = "LISTEN_FDS" + envCountKeyPrefix = envCountKey + "=" +) + +// In order to keep the working directory the same as when we started we record +// it at startup. +var originalWD, _ = os.Getwd() + +// Net provides the family of Listen functions and maintains the associated +// state. Typically you will have only once instance of Net per application. +type Net struct { + inherited []net.Listener + active []net.Listener + mutex sync.Mutex + inheritOnce sync.Once + + // used in tests to override the default behavior of starting from fd 3. + fdStart int +} + +func (n *Net) inherit() error { + var retErr error + n.inheritOnce.Do(func() { + n.mutex.Lock() + defer n.mutex.Unlock() + countStr := os.Getenv(envCountKey) + if countStr == "" { + return + } + count, err := strconv.Atoi(countStr) + if err != nil { + retErr = fmt.Errorf("found invalid count value: %s=%s", envCountKey, countStr) + return + } + + // In tests this may be overridden. + fdStart := n.fdStart + if fdStart == 0 { + // In normal operations if we are inheriting, the listeners will begin at + // fd 3. + fdStart = 3 + } + + for i := fdStart; i < fdStart+count; i++ { + file := os.NewFile(uintptr(i), "listener") + l, err := net.FileListener(file) + if err != nil { + file.Close() + retErr = fmt.Errorf("error inheriting socket fd %d: %s", i, err) + return + } + if err := file.Close(); err != nil { + retErr = fmt.Errorf("error closing inherited socket fd %d: %s", i, err) + return + } + n.inherited = append(n.inherited, l) + } + }) + return retErr +} + +// Listen announces on the local network address laddr. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It +// returns an inherited net.Listener for the matching network and address, or +// creates a new one using net.Listen. +func (n *Net) Listen(nett, laddr string) (net.Listener, error) { + switch nett { + default: + return nil, net.UnknownNetworkError(nett) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(nett, laddr) + if err != nil { + return nil, err + } + return n.ListenTCP(nett, addr) + case "unix", "unixpacket", "invalid_unix_net_for_test": + addr, err := net.ResolveUnixAddr(nett, laddr) + if err != nil { + return nil, err + } + return n.ListenUnix(nett, addr) + } +} + +// ListenTCP announces on the local network address laddr. The network net must +// be: "tcp", "tcp4" or "tcp6". It returns an inherited net.Listener for the +// matching network and address, or creates a new one using net.ListenTCP. +func (n *Net) ListenTCP(nett string, laddr *net.TCPAddr) (*net.TCPListener, error) { + if err := n.inherit(); err != nil { + return nil, err + } + + n.mutex.Lock() + defer n.mutex.Unlock() + + // look for an inherited listener + for i, l := range n.inherited { + if l == nil { // we nil used inherited listeners + continue + } + if isSameAddr(l.Addr(), laddr) { + n.inherited[i] = nil + n.active = append(n.active, l) + return l.(*net.TCPListener), nil + } + } + + // make a fresh listener + l, err := net.ListenTCP(nett, laddr) + if err != nil { + return nil, err + } + n.active = append(n.active, l) + return l, nil +} + +// ListenUnix announces on the local network address laddr. The network net +// must be a: "unix" or "unixpacket". It returns an inherited net.Listener for +// the matching network and address, or creates a new one using net.ListenUnix. +func (n *Net) ListenUnix(nett string, laddr *net.UnixAddr) (*net.UnixListener, error) { + if err := n.inherit(); err != nil { + return nil, err + } + + n.mutex.Lock() + defer n.mutex.Unlock() + + // look for an inherited listener + for i, l := range n.inherited { + if l == nil { // we nil used inherited listeners + continue + } + if isSameAddr(l.Addr(), laddr) { + n.inherited[i] = nil + n.active = append(n.active, l) + return l.(*net.UnixListener), nil + } + } + + // make a fresh listener + l, err := net.ListenUnix(nett, laddr) + if err != nil { + return nil, err + } + n.active = append(n.active, l) + return l, nil +} + +// activeListeners returns a snapshot copy of the active listeners. +func (n *Net) activeListeners() ([]net.Listener, error) { + n.mutex.Lock() + defer n.mutex.Unlock() + ls := make([]net.Listener, len(n.active)) + copy(ls, n.active) + return ls, nil +} + +func isSameAddr(a1, a2 net.Addr) bool { + if a1.Network() != a2.Network() { + return false + } + a1s := a1.String() + a2s := a2.String() + if a1s == a2s { + return true + } + + // This allows for ipv6 vs ipv4 local addresses to compare as equal. This + // scenario is common when listening on localhost. + const ipv6prefix = "[::]" + a1s = strings.TrimPrefix(a1s, ipv6prefix) + a2s = strings.TrimPrefix(a2s, ipv6prefix) + const ipv4prefix = "0.0.0.0" + a1s = strings.TrimPrefix(a1s, ipv4prefix) + a2s = strings.TrimPrefix(a2s, ipv4prefix) + return a1s == a2s +} + +// StartProcess starts a new process passing it the active listeners. It +// doesn't fork, but starts a new process using the same environment and +// arguments as when it was originally started. This allows for a newly +// deployed binary to be started. It returns the pid of the newly started +// process when successful. +func (n *Net) StartProcess() (int, error) { + listeners, err := n.activeListeners() + if err != nil { + return 0, err + } + + // Extract the fds from the listeners. + files := make([]*os.File, len(listeners)) + for i, l := range listeners { + files[i], err = l.(filer).File() + if err != nil { + return 0, err + } + defer files[i].Close() + } + + // Use the original binary location. This works with symlinks such that if + // the file it points to has been changed we will use the updated symlink. + argv0, err := exec.LookPath(os.Args[0]) + if err != nil { + return 0, err + } + + // Pass on the environment and replace the old count key with the new one. + var env []string + for _, v := range os.Environ() { + if !strings.HasPrefix(v, envCountKeyPrefix) { + env = append(env, v) + } + } + env = append(env, fmt.Sprintf("%s%d", envCountKeyPrefix, len(listeners))) + + allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) + process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ + Dir: originalWD, + Env: env, + Files: allFiles, + }) + if err != nil { + return 0, err + } + return process.Pid, nil +} + +type filer interface { + File() (*os.File, error) +} diff --git a/internal/router/params_go17.go b/internal/router/params_go17.go new file mode 100644 index 0000000..b1a1de8 --- /dev/null +++ b/internal/router/params_go17.go @@ -0,0 +1,38 @@ +// +build go1.7 + +package router + +import ( + "context" + "net/http" +) + +type paramsKey struct{} + +// ParamsKey is the request context key under which URL params are stored. +// +// This is only present from go 1.7. +var ParamsKey = paramsKey{} + +// Handler is an adapter which allows the usage of an http.Handler as a +// request handle. With go 1.7+, the Params will be available in the +// request context under ParamsKey. +func (r *Router) Handler(method, path string, handler http.Handler) { + r.Handle(method, path, + func(w http.ResponseWriter, req *http.Request, p Params) { + ctx := req.Context() + ctx = context.WithValue(ctx, ParamsKey, p) + req = req.WithContext(ctx) + handler.ServeHTTP(w, req) + }, + ) +} + +// ParamsFromContext pulls the URL parameters from a request context, +// or returns nil if none are present. +// +// This is only present from go 1.7. +func ParamsFromContext(ctx context.Context) Params { + p, _ := ctx.Value(ParamsKey).(Params) + return p +} diff --git a/internal/router/params_legacy.go b/internal/router/params_legacy.go new file mode 100644 index 0000000..1d7bf42 --- /dev/null +++ b/internal/router/params_legacy.go @@ -0,0 +1,16 @@ +// +build !go1.7 + +package router + +import "net/http" + +// Handler is an adapter which allows the usage of an http.Handler as a +// request handle. With go 1.7+, the Params will be available in the +// request context under ParamsKey. +func (r *Router) Handler(method, path string, handler http.Handler) { + r.Handle(method, path, + func(w http.ResponseWriter, req *http.Request, _ Params) { + handler.ServeHTTP(w, req) + }, + ) +} diff --git a/internal/router/path.go b/internal/router/path.go new file mode 100644 index 0000000..1aa1f25 --- /dev/null +++ b/internal/router/path.go @@ -0,0 +1,118 @@ +package router + +// CleanPath is the URL version of path.Clean, it returns a canonical URL path +// for p, eliminating . and .. elements. +// +// The following rules are applied iteratively until no further processing can +// be done: +// 1. Replace multiple slashes with a single slash. +// 2. Eliminate each . path name element (the current directory). +// 3. Eliminate each inner .. path name element (the parent directory) +// along with the non-.. element that precedes it. +// 4. Eliminate .. elements that begin a rooted path: +// that is, replace "/.." by "/" at the beginning of a path. +// +// If the result of this process is an empty string, "/" is returned +func CleanPath(p string) string { + // Turn empty string into "/" + if p == "" { + return "/" + } + + n := len(p) + var buf []byte + + // Invariants: + // reading from path; r is index of next byte to process. + // writing to buf; w is index of next byte to write. + + // path must start with '/' + r := 1 + w := 1 + + if p[0] != '/' { + r = 0 + buf = make([]byte, n+1) + buf[0] = '/' + } + + trailing := n > 2 && p[n-1] == '/' + + // A bit more clunky without a 'lazybuf' like the path package, but the loop + // gets completely inlined (bufApp). So in contrast to the path package this + // loop has no expensive function calls (except 1x make) + + for r < n { + switch { + case p[r] == '/': + // empty path element, trailing slash is added after the end + r++ + + case p[r] == '.' && r+1 == n: + trailing = true + r++ + + case p[r] == '.' && p[r+1] == '/': + // . element + r++ + + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): + // .. element: remove to last / + r += 2 + + if w > 1 { + // can backtrack + w-- + + if buf == nil { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } + } + } + + default: + // real path element. + // add slash if needed + if w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // copy element + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ + } + } + } + + // re-append trailing slash + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + if buf == nil { + return p[:w] + } + return string(buf[:w]) +} + +// internal helper to lazily create a buffer if necessary +func bufApp(buf *[]byte, s string, w int, c byte) { + if *buf == nil { + if s[w] == c { + return + } + + *buf = make([]byte, len(s)) + copy(*buf, s[:w]) + } + (*buf)[w] = c +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..7088432 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,398 @@ +// Package httprouter is a trie based high performance HTTP request router. +// +// A trivial example is: +// +// package main +// +// import ( +// "fmt" +// "github.com/julienschmidt/httprouter" +// "net/http" +// "log" +// ) +// +// func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +// fmt.Fprint(w, "Welcome!\n") +// } +// +// func Hello(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +// fmt.Fprintf(w, "hello, %s!\n", ps.ByName("name")) +// } +// +// func main() { +// router := httprouter.New() +// router.GET("/", Index) +// router.GET("/hello/:name", Hello) +// +// log.Fatal(http.ListenAndServe(":8080", router)) +// } +// +// The router matches incoming requests by the request method and the path. +// If a handle is registered for this path and method, the router delegates the +// request to that function. +// For the methods GET, POST, PUT, PATCH and DELETE shortcut functions exist to +// register handles, for all other methods router.Handle can be used. +// +// The registered path, against which the router matches incoming requests, can +// contain two types of parameters: +// Syntax Type +// :name named parameter +// *name catch-all parameter +// +// Named parameters are dynamic path segments. They match anything until the +// next '/' or the path end: +// Path: /blog/:category/:post +// +// Requests: +// /blog/go/request-routers match: category="go", post="request-routers" +// /blog/go/request-routers/ no match, but the router would redirect +// /blog/go/ no match +// /blog/go/request-routers/comments no match +// +// Catch-all parameters match anything until the path end, including the +// directory index (the '/' before the catch-all). Since they match anything +// until the end, catch-all parameters must always be the final path element. +// Path: /files/*filepath +// +// Requests: +// /files/ match: filepath="/" +// /files/LICENSE match: filepath="/LICENSE" +// /files/templates/article.html match: filepath="/templates/article.html" +// /files no match, but the router would redirect +// +// The value of parameters is saved as a slice of the Param struct, consisting +// each of a key and a value. The slice is passed to the Handle func as a third +// parameter. +// There are two ways to retrieve the value of a parameter: +// // by the name of the parameter +// user := ps.ByName("user") // defined by :user or *user +// +// // by the index of the parameter. This way you can also get the name (key) +// thirdKey := ps[2].Key // the name of the 3rd parameter +// thirdValue := ps[2].Value // the value of the 3rd parameter + +package router + +import ( + "net/http" +) + +// Handle is a function that can be registered to a route to handle HTTP +// requests. Like http.HandlerFunc, but has a third parameter for the values of +// wildcards (variables). +type Handle func(http.ResponseWriter, *http.Request, Params) + +// Param is a single URL parameter, consisting of a key and a value. +type Param struct { + Key string + Value string +} + +// Params is a Param-slice, as returned by the router. +// The slice is ordered, the first URL parameter is also the first slice value. +// It is therefore safe to read values by the index. +type Params []Param + +// ByName returns the value of the first Param which key matches the given name. +// If no matching Param is found, an empty string is returned. +func (ps Params) ByName(name string) string { + for i := range ps { + if ps[i].Key == name { + return ps[i].Value + } + } + return "" +} + +// Router is a http.Handler which can be used to dispatch requests to different +// handler functions via configurable routes +type Router struct { + trees map[string]*node + + // Enables automatic redirection if the current route can't be matched but a + // handler for the path with (without) the trailing slash exists. + // For example if /foo/ is requested but a route only exists for /foo, the + // client is redirected to /foo with http status code 301 for GET requests + // and 307 for all other request methods. + RedirectTrailingSlash bool + + // If enabled, the router tries to fix the current request path, if no + // handle is registered for it. + // First superfluous path elements like ../ or // are removed. + // Afterwards the router does a case-insensitive lookup of the cleaned path. + // If a handle can be found for this route, the router makes a redirection + // to the corrected path with status code 301 for GET requests and 307 for + // all other request methods. + // For example /FOO and /..//Foo could be redirected to /foo. + // RedirectTrailingSlash is independent of this option. + RedirectFixedPath bool + + // If enabled, the router checks if another method is allowed for the + // current route, if the current request can not be routed. + // If this is the case, the request is answered with 'Method Not Allowed' + // and HTTP status code 405. + // If no other Method is allowed, the request is delegated to the NotFound + // handler. + HandleMethodNotAllowed bool + + // If enabled, the router automatically replies to OPTIONS requests. + // Custom OPTIONS handlers take priority over automatic replies. + HandleOPTIONS bool + + // Configurable http.Handler which is called when no matching route is + // found. If it is not set, http.NotFound is used. + NotFound http.Handler + + // Configurable http.Handler which is called when a request + // cannot be routed and HandleMethodNotAllowed is true. + // If it is not set, http.Error with http.StatusMethodNotAllowed is used. + // The "Allow" header with allowed request methods is set before the handler + // is called. + MethodNotAllowed http.Handler + + // Function to handle panics recovered from http handlers. + // It should be used to generate a error page and return the http error code + // 500 (Internal Server Error). + // The handler can be used to keep your server from crashing because of + // unrecovered panics. + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) +} + +// Make sure the Router conforms with the http.Handler interface +var _ http.Handler = New() + +// New returns a new initialized Router. +// Path auto-correction, including trailing slashes, is enabled by default. +func New() *Router { + return &Router{ + RedirectTrailingSlash: true, + RedirectFixedPath: true, + HandleMethodNotAllowed: true, + HandleOPTIONS: true, + } +} + +// GET is a shortcut for router.Handle("GET", path, handle) +func (r *Router) GET(path string, handle Handle) { + r.Handle("GET", path, handle) +} + +// HEAD is a shortcut for router.Handle("HEAD", path, handle) +func (r *Router) HEAD(path string, handle Handle) { + r.Handle("HEAD", path, handle) +} + +// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle) +func (r *Router) OPTIONS(path string, handle Handle) { + r.Handle("OPTIONS", path, handle) +} + +// POST is a shortcut for router.Handle("POST", path, handle) +func (r *Router) POST(path string, handle Handle) { + r.Handle("POST", path, handle) +} + +// PUT is a shortcut for router.Handle("PUT", path, handle) +func (r *Router) PUT(path string, handle Handle) { + r.Handle("PUT", path, handle) +} + +// PATCH is a shortcut for router.Handle("PATCH", path, handle) +func (r *Router) PATCH(path string, handle Handle) { + r.Handle("PATCH", path, handle) +} + +// DELETE is a shortcut for router.Handle("DELETE", path, handle) +func (r *Router) DELETE(path string, handle Handle) { + r.Handle("DELETE", path, handle) +} + +// Handle registers a new request handle with the given path and method. +// +// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut +// functions can be used. +// +// This function is intended for bulk loading and to allow the usage of less +// frequently used, non-standardized or custom methods (e.g. for internal +// communication with a proxy). +func (r *Router) Handle(method, path string, handle Handle) { + if path[0] != '/' { + panic("path must begin with '/' in path '" + path + "'") + } + + if r.trees == nil { + r.trees = make(map[string]*node) + } + + root := r.trees[method] + if root == nil { + root = new(node) + r.trees[method] = root + } + + root.addRoute(path, handle) +} + +// HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a +// request handle. +func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) { + r.Handler(method, path, handler) +} + +// ServeFiles serves files from the given file system root. +// The path must end with "/*filepath", files are then served from the local +// path /defined/root/dir/*filepath. +// For example if root is "/etc" and *filepath is "passwd", the local file +// "/etc/passwd" would be served. +// Internally a http.FileServer is used, therefore http.NotFound is used instead +// of the Router's NotFound handler. +// To use the operating system's file system implementation, +// use http.Dir: +// router.ServeFiles("/src/*filepath", http.Dir("/var/www")) +func (r *Router) ServeFiles(path string, root http.FileSystem) { + if len(path) < 10 || path[len(path)-10:] != "/*filepath" { + panic("path must end with /*filepath in path '" + path + "'") + } + + fileServer := http.FileServer(root) + + r.GET(path, func(w http.ResponseWriter, req *http.Request, ps Params) { + req.URL.Path = ps.ByName("filepath") + fileServer.ServeHTTP(w, req) + }) +} + +func (r *Router) recv(w http.ResponseWriter, req *http.Request) { + if rcv := recover(); rcv != nil { + r.PanicHandler(w, req, rcv) + } +} + +// Lookup allows the manual lookup of a method + path combo. +// This is e.g. useful to build a framework around this router. +// If the path was found, it returns the handle function and the path parameter +// values. Otherwise the third return value indicates whether a redirection to +// the same path with an extra / without the trailing slash should be performed. +func (r *Router) Lookup(method, path string) (Handle, Params, bool) { + if root := r.trees[method]; root != nil { + return root.getValue(path) + } + return nil, nil, false +} + +func (r *Router) allowed(path, reqMethod string) (allow string) { + if path == "*" { // server-wide + for method := range r.trees { + if method == "OPTIONS" { + continue + } + + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } else { // specific path + for method := range r.trees { + // Skip the requested method - we already tried this one + if method == reqMethod || method == "OPTIONS" { + continue + } + + handle, _, _ := r.trees[method].getValue(path) + if handle != nil { + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } + } + if len(allow) > 0 { + allow += ", OPTIONS" + } + return +} + +// ServeHTTP makes the router implement the http.Handler interface. +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.PanicHandler != nil { + defer r.recv(w, req) + } + + path := req.URL.Path + + if root := r.trees[req.Method]; root != nil { + if handle, ps, tsr := root.getValue(path); handle != nil { + handle(w, req, ps) + return + } else if req.Method != "CONNECT" && path != "/" { + code := 301 // Permanent redirect, request with GET method + if req.Method != "GET" { + // Temporary redirect, request with same method + // As of Go 1.3, Go does not support status code 308. + code = 307 + } + + if tsr && r.RedirectTrailingSlash { + if len(path) > 1 && path[len(path)-1] == '/' { + req.URL.Path = path[:len(path)-1] + } else { + req.URL.Path = path + "/" + } + http.Redirect(w, req, req.URL.String(), code) + return + } + + // Try to fix the request path + if r.RedirectFixedPath { + fixedPath, found := root.findCaseInsensitivePath( + CleanPath(path), + r.RedirectTrailingSlash, + ) + if found { + req.URL.Path = string(fixedPath) + http.Redirect(w, req, req.URL.String(), code) + return + } + } + } + } + + if req.Method == "OPTIONS" { + // Handle OPTIONS requests + if r.HandleOPTIONS { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + return + } + } + } else { + // Handle 405 + if r.HandleMethodNotAllowed { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return + } + } + } + + // Handle 404 + if r.NotFound != nil { + r.NotFound.ServeHTTP(w, req) + } else { + http.NotFound(w, req) + } +} diff --git a/internal/router/tree.go b/internal/router/tree.go new file mode 100644 index 0000000..de89eaa --- /dev/null +++ b/internal/router/tree.go @@ -0,0 +1,652 @@ +package router + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +func countParams(path string) uint8 { + var n uint + for i := 0; i < len(path); i++ { + if path[i] != ':' && path[i] != '*' { + continue + } + n++ + } + if n >= 255 { + return 255 + } + return uint8(n) +} + +type nodeType uint8 + +const ( + static nodeType = iota // default + root + param + catchAll +) + +type node struct { + path string + wildChild bool + nType nodeType + maxParams uint8 + indices string + children []*node + handle Handle + priority uint32 +} + +// increments priority of the given child and reorders if necessary +func (n *node) incrementChildPrio(pos int) int { + n.children[pos].priority++ + prio := n.children[pos].priority + + // adjust position (move to front) + newPos := pos + for newPos > 0 && n.children[newPos-1].priority < prio { + // swap node positions + n.children[newPos-1], n.children[newPos] = n.children[newPos], n.children[newPos-1] + + newPos-- + } + + // build new index char string + if newPos != pos { + n.indices = n.indices[:newPos] + // unchanged prefix, might be empty + n.indices[pos:pos+1] + // the index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos' + } + + return newPos +} + +// addRoute adds a node with the given handle to the path. +// Not concurrency-safe! +func (n *node) addRoute(path string, handle Handle) { + fullPath := path + n.priority++ + numParams := countParams(path) + + // non-empty tree + if len(n.path) > 0 || len(n.children) > 0 { + walk: + for { + // Update maxParams of the current node + if numParams > n.maxParams { + n.maxParams = numParams + } + + // Find the longest common prefix. + // This also implies that the common prefix contains no ':' or '*' + // since the existing key can't contain those chars. + i := 0 + max := min(len(path), len(n.path)) + for i < max && path[i] == n.path[i] { + i++ + } + + // Split edge + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handle: n.handle, + priority: n.priority - 1, + } + + // Update maxParams (max of all children) + for i := range child.children { + if child.children[i].maxParams > child.maxParams { + child.maxParams = child.children[i].maxParams + } + } + + n.children = []*node{&child} + // []byte for proper unicode char conversion, see #65 + n.indices = string([]byte{n.path[i]}) + n.path = path[:i] + n.handle = nil + n.wildChild = false + } + + // Make new node a child of this node + if i < len(path) { + path = path[i:] + + if n.wildChild { + n = n.children[0] + n.priority++ + + // Update maxParams of the child node + if numParams > n.maxParams { + n.maxParams = numParams + } + numParams-- + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } else { + // Wildcard conflict + var pathSeg string + if n.nType == catchAll { + pathSeg = path + } else { + pathSeg = strings.SplitN(path, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + } + + c := path[0] + + // slash after param + if n.nType == param && c == '/' && len(n.children) == 1 { + n = n.children[0] + n.priority++ + continue walk + } + + // Check if a child with the next path byte exists + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk + } + } + + // Otherwise insert it + if c != ':' && c != '*' { + // []byte for proper unicode char conversion, see #65 + n.indices += string([]byte{c}) + child := &node{ + maxParams: numParams, + } + n.children = append(n.children, child) + n.incrementChildPrio(len(n.indices) - 1) + n = child + } + n.insertChild(numParams, path, fullPath, handle) + return + + } else if i == len(path) { // Make node a (in-path) leaf + if n.handle != nil { + panic("a handle is already registered for path '" + fullPath + "'") + } + n.handle = handle + } + return + } + } else { // Empty tree + n.insertChild(numParams, path, fullPath, handle) + n.nType = root + } +} + +func (n *node) insertChild(numParams uint8, path, fullPath string, handle Handle) { + var offset int // already handled bytes of the path + + // find prefix until first wildcard (beginning with ':'' or '*'') + for i, max := 0, len(path); numParams > 0; i++ { + c := path[i] + if c != ':' && c != '*' { + continue + } + + // find wildcard end (either '/' or path end) + end := i + 1 + for end < max && path[end] != '/' { + switch path[end] { + // the wildcard name must not contain ':' and '*' + case ':', '*': + panic("only one wildcard per path segment is allowed, has: '" + + path[i:] + "' in path '" + fullPath + "'") + default: + end++ + } + } + + // check if this Node existing children which would be + // unreachable if we insert the wildcard here + if len(n.children) > 0 { + panic("wildcard route '" + path[i:end] + + "' conflicts with existing children in path '" + fullPath + "'") + } + + // check if the wildcard has a name + if end-i < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + } + + if c == ':' { // param + // split path at the beginning of the wildcard + if i > 0 { + n.path = path[offset:i] + offset = i + } + + child := &node{ + nType: param, + maxParams: numParams, + } + n.children = []*node{child} + n.wildChild = true + n = child + n.priority++ + numParams-- + + // if the path doesn't end with the wildcard, then there + // will be another non-wildcard subpath starting with '/' + if end < max { + n.path = path[offset:end] + offset = end + + child := &node{ + maxParams: numParams, + priority: 1, + } + n.children = []*node{child} + n = child + } + + } else { // catchAll + if end != max || numParams > 1 { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") + } + + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") + } + + // currently fixed width 1 for '/' + i-- + if path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") + } + + n.path = path[offset:i] + + // first node: catchAll node with empty path + child := &node{ + wildChild: true, + nType: catchAll, + maxParams: 1, + } + n.children = []*node{child} + n.indices = string(path[i]) + n = child + n.priority++ + + // second node: node holding the variable + child = &node{ + path: path[i:], + nType: catchAll, + maxParams: 1, + handle: handle, + priority: 1, + } + n.children = []*node{child} + + return + } + } + + // insert remaining path part and handle to the leaf + n.path = path[offset:] + n.handle = handle +} + +// Returns the handle registered with the given path (key). The values of +// wildcards are saved to a map. +// If no handle can be found, a TSR (trailing slash redirect) recommendation is +// made if a handle exists with an extra (without the) trailing slash for the +// given path. +func (n *node) getValue(path string) (handle Handle, p Params, tsr bool) { +walk: // outer loop for walking the tree + for { + if len(path) > len(n.path) { + if path[:len(n.path)] == n.path { + path = path[len(n.path):] + // If this node does not have a wildcard (param or catchAll) + // child, we can just look up the next child node and continue + // to walk down the tree + if !n.wildChild { + c := path[0] + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + n = n.children[i] + continue walk + } + } + + // Nothing found. + // We can recommend to redirect to the same URL without a + // trailing slash if a leaf exists for that path. + tsr = (path == "/" && n.handle != nil) + return + + } + + // handle wildcard child + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[1:] + p[i].Value = path[:end] + + // we need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] + n = n.children[0] + continue walk + } + + // ... but we can't + tsr = (len(path) == end+1) + return + } + + if handle = n.handle; handle != nil { + return + } else if len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists for TSR recommendation + n = n.children[0] + tsr = (n.path == "/" && n.handle != nil) + } + + return + + case catchAll: + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[2:] + p[i].Value = path + + handle = n.handle + return + + default: + panic("invalid node type") + } + } + } else if path == n.path { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if handle = n.handle; handle != nil { + return + } + + if path == "/" && n.wildChild && n.nType != root { + tsr = true + return + } + + // No handle found. Check if a handle for this path + a + // trailing slash exists for trailing slash recommendation + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + tsr = (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) + return + } + } + + return + } + + // Nothing found. We can recommend to redirect to the same URL with an + // extra trailing slash if a leaf exists for that path + tsr = (path == "/") || + (len(n.path) == len(path)+1 && n.path[len(path)] == '/' && + path == n.path[:len(n.path)-1] && n.handle != nil) + return + } +} + +// Makes a case-insensitive lookup of the given path and tries to find a handler. +// It can optionally also fix trailing slashes. +// It returns the case-corrected path and a bool indicating whether the lookup +// was successful. +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { + return n.findCaseInsensitivePathRec( + path, + strings.ToLower(path), + make([]byte, 0, len(path)+1), // preallocate enough memory for new path + [4]byte{}, // empty rune buffer + fixTrailingSlash, + ) +} + +// shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } +} + +// recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path, loPath string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) ([]byte, bool) { + loNPath := strings.ToLower(n.path) + +walk: // outer loop for walking the tree + for len(loPath) >= len(loNPath) && (len(loNPath) == 0 || loPath[1:len(loNPath)] == loNPath[1:]) { + // add common path to result + ciPath = append(ciPath, n.path...) + + if path = path[len(n.path):]; len(path) > 0 { + loOld := loPath + loPath = loPath[len(loNPath):] + + // If this node does not have a wildcard (param or catchAll) child, + // we can just look up the next child node and continue to walk down + // the tree + if !n.wildChild { + // skip rune bytes already processed + rb = shiftNRuneBytes(rb, len(loNPath)) + + if rb[0] != 0 { + // old rune not finished + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } else { + // process a new rune + var rv rune + + // find rune start + // runes are up to 4 byte long, + // -4 would definitely be another rune + var off int + for max := min(len(loNPath), 3); off < max; off++ { + if i := len(loNPath) - off; utf8.RuneStart(loOld[i]) { + // read rune from cached lowercase path + rv, _ = utf8.DecodeRuneInString(loOld[i:]) + break + } + } + + // calculate lowercase bytes of current rune + utf8.EncodeRune(rb[:], rv) + // skipp already processed bytes + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // lowercase matches + if n.indices[i] == rb[0] { + // must use a recursive approach since both the + // uppercase byte and the lowercase byte might exist + // as an index + if out, found := n.children[i].findCaseInsensitivePathRec( + path, loPath, ciPath, rb, fixTrailingSlash, + ); found { + return out, true + } + break + } + } + + // same for uppercase rune, if it differs + if up := unicode.ToUpper(rv); up != rv { + utf8.EncodeRune(rb[:], up) + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // uppercase matches + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } + } + + // Nothing found. We can recommend to redirect to the same URL + // without a trailing slash if a leaf exists for that path + return ciPath, (fixTrailingSlash && path == "/" && n.handle != nil) + } + + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + k := 0 + for k < len(path) && path[k] != '/' { + k++ + } + + // add param value to case insensitive path + ciPath = append(ciPath, path[:k]...) + + // we need to go deeper! + if k < len(path) { + if len(n.children) > 0 { + // continue with child node + n = n.children[0] + loNPath = strings.ToLower(n.path) + loPath = loPath[k:] + path = path[k:] + continue + } + + // ... but we can't + if fixTrailingSlash && len(path) == k+1 { + return ciPath, true + } + return ciPath, false + } + + if n.handle != nil { + return ciPath, true + } else if fixTrailingSlash && len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists + n = n.children[0] + if n.path == "/" && n.handle != nil { + return append(ciPath, '/'), true + } + } + return ciPath, false + + case catchAll: + return append(ciPath, path...), true + + default: + panic("invalid node type") + } + } else { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if n.handle != nil { + return ciPath, true + } + + // No handle found. + // Try to fix the path by adding a trailing slash + if fixTrailingSlash { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + if (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) { + return append(ciPath, '/'), true + } + return ciPath, false + } + } + } + return ciPath, false + } + } + + // Nothing found. + // Try to fix the path by adding / removing a trailing slash + if fixTrailingSlash { + if path == "/" { + return ciPath, true + } + if len(loPath)+1 == len(loNPath) && loNPath[len(loPath)] == '/' && + loPath[1:] == loNPath[1:len(loPath)] && n.handle != nil { + return append(ciPath, n.path...), true + } + } + return ciPath, false +} diff --git a/readwrite-handler.go b/readwrite-handler.go new file mode 100644 index 0000000..4b3024d --- /dev/null +++ b/readwrite-handler.go @@ -0,0 +1,101 @@ +package server + +import ( + "sync/atomic" + "time" +) + +type ReadWriteHandler interface { + GetMaxMessageSize() int64 + GetReadBufferSize() int + GetWriteBufferSize() int + GetReadTimeout() time.Duration + GetWriteTimeout() time.Duration +} + +type ReadWriteHandlers struct { + MaxMessageSize int64 `json:"maxMessageSize,omitempty"` + // Per-connection buffer size for requests' reading. + // This also limits the maximum header size. + // + // Increase this buffer if your clients send multi-KB RequestURIs + // and/or multi-KB headers (for example, BIG cookies). + // + // Default buffer size is used if not set. + ReadBufferSize int `json:"readBufferSize,omitempty"` + // Per-connection buffer size for responses' writing. + // + // Default buffer size is used if not set. + WriteBufferSize int `json:"writeBufferSize,omitempty"` + // Maximum duration for reading the full request (including body). + // + // This also limits the maximum duration for idle keep-alive + // connections. + // + // By default request read timeout is unlimited. + ReadTimeout time.Duration `json:"readTimeout,omitempty"` + + // Maximum duration for writing the full response (including body). + // + // By default response write timeout is unlimited. + WriteTimeout time.Duration `json:"writeTimeout,omitempty"` + + validated atomic.Value +} + +func (rwh *ReadWriteHandlers) GetMaxMessageSize() int64 { + return rwh.MaxMessageSize +} +func (rwh *ReadWriteHandlers) GetReadBufferSize() int { + return rwh.ReadBufferSize +} +func (rwh *ReadWriteHandlers) GetWriteBufferSize() int { + return rwh.WriteBufferSize +} +func (rwh *ReadWriteHandlers) GetReadTimeout() time.Duration { + return rwh.ReadTimeout +} +func (rwh *ReadWriteHandlers) GetWriteTimeout() time.Duration { + return rwh.WriteTimeout +} + +func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { + return &ReadWriteHandlers{ + MaxMessageSize: rwh.MaxMessageSize, + ReadBufferSize: rwh.ReadBufferSize, + WriteBufferSize: rwh.WriteBufferSize, + ReadTimeout: rwh.ReadTimeout, + WriteTimeout: rwh.WriteTimeout, + validated: rwh.validated, + } +} + +func (rwh *ReadWriteHandlers) Validate() error { + if nil != rwh.validated.Load() { + return nil + } + rwh.validated.Store(true) + + if rwh.MaxMessageSize <= 0 { + rwh.MaxMessageSize = DefaultMaxMessageSize + } + if rwh.ReadBufferSize <= 0 { + rwh.ReadBufferSize = DefaultReadBufferSize + } + if rwh.WriteBufferSize <= 0 { + rwh.WriteBufferSize = DefaultWriteBufferSize + } + if rwh.ReadTimeout <= 0 { + rwh.ReadTimeout = DefaultReadTimeout + } else { + rwh.ReadTimeout = rwh.ReadTimeout * time.Second + } + + if rwh.WriteTimeout <= 0 { + rwh.WriteTimeout = DefaultWriteTimeout + } else { + rwh.WriteTimeout = rwh.WriteTimeout * time.Second + } + + return nil +} diff --git a/server-ctx.go b/server-ctx.go new file mode 100644 index 0000000..592d270 --- /dev/null +++ b/server-ctx.go @@ -0,0 +1,19 @@ +package server + +import ( + ouc "git.loafle.net/overflow/util-go/ctx" +) + +type ServerCtx interface { + ouc.Ctx +} + +func NewServerCtx(parent ouc.Ctx) ServerCtx { + return &serverCtx{ + Ctx: ouc.NewCtx(parent), + } +} + +type serverCtx struct { + ouc.Ctx +} diff --git a/server-handler.go b/server-handler.go new file mode 100644 index 0000000..5d016b4 --- /dev/null +++ b/server-handler.go @@ -0,0 +1,69 @@ +package server + +import "sync/atomic" + +type ServerHandler interface { + ConnectionHandler + + GetName() string + ServerCtx() ServerCtx + + Init(serverCtx ServerCtx) error + OnStart(serverCtx ServerCtx) error + OnStop(serverCtx ServerCtx) + Destroy(serverCtx ServerCtx) + + Validate() error +} + +type ServerHandlers struct { + ConnectionHandlers + + // Server name for sending in response headers. + // + // Default server name is used if left blank. + Name string `json:"name,omitempty"` + + validated atomic.Value +} + +func (sh *ServerHandlers) ServerCtx() ServerCtx { + return NewServerCtx(nil) +} + +func (sh *ServerHandlers) Init(serverCtx ServerCtx) error { + return nil +} + +func (sh *ServerHandlers) OnStart(serverCtx ServerCtx) error { + return nil +} + +func (sh *ServerHandlers) OnStop(serverCtx ServerCtx) { + +} + +func (sh *ServerHandlers) Destroy(serverCtx ServerCtx) { + +} + +func (sh *ServerHandlers) GetName() string { + return sh.Name +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ConnectionHandlers.Validate(); nil != err { + return err + } + + if "" == sh.Name { + sh.Name = "Server" + } + + return nil +} diff --git a/servlet-ctx.go b/servlet-ctx.go new file mode 100644 index 0000000..d8fde7b --- /dev/null +++ b/servlet-ctx.go @@ -0,0 +1,28 @@ +package server + +import ( + ouc "git.loafle.net/overflow/util-go/ctx" +) + +type ServletCtx interface { + ouc.Ctx + + ServerCtx() ServerCtx +} + +func NewServletContext(parent ouc.Ctx, serverCtx ServerCtx) ServletCtx { + return &servletCtx{ + Ctx: ouc.NewCtx(parent), + serverCtx: serverCtx, + } +} + +type servletCtx struct { + ouc.Ctx + + serverCtx ServerCtx +} + +func (sc *servletCtx) ServerCtx() ServerCtx { + return sc.serverCtx +} diff --git a/servlet.go b/servlet.go new file mode 100644 index 0000000..d67adab --- /dev/null +++ b/servlet.go @@ -0,0 +1,10 @@ +package server + +type Servlet interface { + ServletCtx(serverCtx ServerCtx) ServletCtx + + Init(serverCtx ServerCtx) error + OnStart(serverCtx ServerCtx) error + OnStop(serverCtx ServerCtx) + Destroy(serverCtx ServerCtx) +} diff --git a/socket/client-conn-handler.go b/socket/client-conn-handler.go new file mode 100644 index 0000000..10d13a7 --- /dev/null +++ b/socket/client-conn-handler.go @@ -0,0 +1,64 @@ +package socket + +import ( + "sync/atomic" + "time" + + "git.loafle.net/overflow/server-go" +) + +type ClientConnHandler interface { + server.ConnectionHandler + + GetReconnectInterval() time.Duration + GetReconnectTryTime() int +} + +type ClientConnHandlers struct { + server.ConnectionHandlers + + ReconnectInterval time.Duration `json:"reconnectInterval,omitempty"` + ReconnectTryTime int `json:"reconnectTryTime,omitempty"` + + validated atomic.Value +} + +func (cch *ClientConnHandlers) GetReconnectInterval() time.Duration { + return cch.ReconnectInterval +} + +func (cch *ClientConnHandlers) GetReconnectTryTime() int { + return cch.ReconnectTryTime +} + +func (cch *ClientConnHandlers) Clone() *ClientConnHandlers { + return &ClientConnHandlers{ + ConnectionHandlers: *cch.ConnectionHandlers.Clone(), + ReconnectInterval: cch.ReconnectInterval, + ReconnectTryTime: cch.ReconnectTryTime, + validated: cch.validated, + } +} + +func (cch *ClientConnHandlers) Validate() error { + if nil != cch.validated.Load() { + return nil + } + cch.validated.Store(true) + + if err := cch.ConnectionHandlers.Validate(); nil != err { + return err + } + + if cch.ReconnectInterval <= 0 { + cch.ReconnectInterval = server.DefaultReconnectInterval + } else { + cch.ReconnectInterval = cch.ReconnectInterval * time.Second + } + + if cch.ReconnectTryTime <= 0 { + cch.ReconnectTryTime = server.DefaultReconnectTryTime + } + + return nil +} diff --git a/socket/client-readwriter.go b/socket/client-readwriter.go new file mode 100644 index 0000000..cdd8733 --- /dev/null +++ b/socket/client-readwriter.go @@ -0,0 +1,70 @@ +package socket + +import ( + "io" + "sync" + + olog "git.loafle.net/overflow/log-go" +) + +type ClientReadWriter struct { + ReadwriteHandler ReadWriteHandler + ReadChan chan<- SocketMessage + WriteChan <-chan SocketMessage + DisconnectedChan chan<- struct{} + ReconnectedChan <-chan Conn + ClientStopChan <-chan struct{} + ClientStopWg *sync.WaitGroup +} + +func (crw *ClientReadWriter) HandleConnection(conn Conn) { + + defer func() { + if nil != conn { + conn.Close() + } + olog.Logger().Info("disconnected") + crw.ClientStopWg.Done() + }() + + olog.Logger().Info("connected") + + var err error + + for { + if nil != err { + if IsUnexpectedCloseError(err) || io.EOF == err || io.ErrUnexpectedEOF == err { + crw.DisconnectedChan <- struct{}{} + newConn := <-crw.ReconnectedChan + if nil == newConn { + return + } + conn = newConn + } else { + return + } + } + + stopChan := make(chan struct{}) + + readerDoneChan := make(chan error) + writerDoneChan := make(chan error) + + go connReadHandler(crw.ReadwriteHandler, conn, stopChan, readerDoneChan, crw.ReadChan) + go connWriteHandler(crw.ReadwriteHandler, conn, stopChan, writerDoneChan, crw.WriteChan) + + select { + case err = <-readerDoneChan: + close(stopChan) + <-writerDoneChan + case err = <-writerDoneChan: + close(stopChan) + <-readerDoneChan + case <-crw.ClientStopChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + return + } + } +} diff --git a/socket/client/client-ctx.go b/socket/client/client-ctx.go new file mode 100644 index 0000000..4b04088 --- /dev/null +++ b/socket/client/client-ctx.go @@ -0,0 +1,19 @@ +package client + +import ( + ouc "git.loafle.net/overflow/util-go/ctx" +) + +type ClientCtx interface { + ouc.Ctx +} + +func NewClientCtx(parent ouc.Ctx) ClientCtx { + return &clientCtx{ + Ctx: ouc.NewCtx(parent), + } +} + +type clientCtx struct { + ouc.Ctx +} diff --git a/socket/client/client-handler.go b/socket/client/client-handler.go new file mode 100644 index 0000000..376afe0 --- /dev/null +++ b/socket/client/client-handler.go @@ -0,0 +1,72 @@ +package client + +import ( + "fmt" + "sync/atomic" +) + +type ClientHandler interface { + GetName() string + GetConnector() Connector + + ClientCtx() ClientCtx + + Init(clientCtx ClientCtx) error + OnStart(clientCtx ClientCtx) error + OnStop(clientCtx ClientCtx) + Destroy(clientCtx ClientCtx) + + Validate() error +} + +type ClientHandlers struct { + Name string `json:"name,omitempty"` + Connector Connector `json:"-"` + + validated atomic.Value +} + +func (ch *ClientHandlers) ClientCtx() ClientCtx { + return NewClientCtx(nil) +} + +func (ch *ClientHandlers) Init(clientCtx ClientCtx) error { + return nil +} + +func (ch *ClientHandlers) OnStart(clientCtx ClientCtx) error { + return nil +} + +func (ch *ClientHandlers) OnStop(clientCtx ClientCtx) { + +} + +func (ch *ClientHandlers) Destroy(clientCtx ClientCtx) { + +} + +func (ch *ClientHandlers) GetName() string { + return ch.Name +} + +func (ch *ClientHandlers) GetConnector() Connector { + return ch.Connector +} + +func (ch *ClientHandlers) Validate() error { + if nil != ch.validated.Load() { + return nil + } + ch.validated.Store(true) + + if nil == ch.Connector { + return fmt.Errorf("Connector is not valid") + } + + if err := ch.Connector.Validate(); nil != err { + return err + } + + return nil +} diff --git a/socket/client/connector.go b/socket/client/connector.go new file mode 100644 index 0000000..5c684cc --- /dev/null +++ b/socket/client/connector.go @@ -0,0 +1,75 @@ +package client + +import ( + "sync/atomic" + + "git.loafle.net/overflow/server-go/socket" +) + +type OnDisconnectedFunc func(connector Connector) + +type Connector interface { + socket.ClientConnHandler + socket.ReadWriteHandler + + Connect() (readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage, err error) + Disconnect() error + + GetName() string + GetOnDisconnected() OnDisconnectedFunc + SetOnDisconnected(fnc OnDisconnectedFunc) + + Clone() Connector + Validate() error +} + +type Connectors struct { + socket.ClientConnHandlers + socket.ReadWriteHandlers + + Name string `json:"name,omitempty"` + OnDisconnected OnDisconnectedFunc `json:"-"` + + validated atomic.Value +} + +func (c *Connectors) GetName() string { + return c.Name +} + +func (c *Connectors) GetOnDisconnected() OnDisconnectedFunc { + return c.OnDisconnected +} + +func (c *Connectors) SetOnDisconnected(fnc OnDisconnectedFunc) { + c.OnDisconnected = fnc +} + +func (c *Connectors) Clone() *Connectors { + return &Connectors{ + Name: c.Name, + ClientConnHandlers: *c.ClientConnHandlers.Clone(), + ReadWriteHandlers: *c.ReadWriteHandlers.Clone(), + validated: c.validated, + } +} + +func (c *Connectors) Validate() error { + if nil != c.validated.Load() { + return nil + } + c.validated.Store(true) + + if err := c.ClientConnHandlers.Validate(); nil != err { + return err + } + if err := c.ReadWriteHandlers.Validate(); nil != err { + return err + } + + if "" == c.Name { + c.Name = "Connector" + } + + return nil +} diff --git a/socket/conn-compression.go b/socket/conn-compression.go new file mode 100644 index 0000000..acc279f --- /dev/null +++ b/socket/conn-compression.go @@ -0,0 +1,147 @@ +package socket + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 + defaultCompressionThreshold = 1024 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func DecompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} +} + +func IsValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func CompressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/socket/conn-error.go b/socket/conn-error.go new file mode 100644 index 0000000..64e68e8 --- /dev/null +++ b/socket/conn-error.go @@ -0,0 +1,111 @@ +package socket + +import ( + "errors" + "io" + "strconv" +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("socket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("socket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents close frame. +type CloseError struct { + + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("socket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "socket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("socket: bad write message type") + errWriteClosed = errors.New("socket: write closed") + errInvalidControlFrame = errors.New("socket: invalid control frame") + // ErrBadHandshake is returned when the server response to opening handshake is + // invalid. + ErrBadHandshake = errors.New("socket: bad handshake") + ErrInvalidCompression = errors.New("socket: invalid compression negotiation") +) diff --git a/socket/conn-json.go b/socket/conn-json.go new file mode 100644 index 0000000..5683eda --- /dev/null +++ b/socket/conn-json.go @@ -0,0 +1,55 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socket + +import ( + "encoding/json" + "io" +) + +// WriteJSON is deprecated, use c.WriteJSON instead. +func WriteJSON(c *SocketConn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v to the connection. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *SocketConn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON is deprecated, use c.ReadJSON instead. +func ReadJSON(c *SocketConn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *SocketConn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/socket/conn-mask.go b/socket/conn-mask.go new file mode 100644 index 0000000..04d394c --- /dev/null +++ b/socket/conn-mask.go @@ -0,0 +1,49 @@ +package socket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/socket/conn-prepared.go b/socket/conn-prepared.go new file mode 100644 index 0000000..30d2dff --- /dev/null +++ b/socket/conn-prepared.go @@ -0,0 +1,100 @@ +package socket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + err error + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan bool, 1) + mu <- true + var nc prepareConn + c := &SocketConn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + + if key.compress { + c.newCompressionWriter = CompressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/socket/conn-read.go b/socket/conn-read.go new file mode 100644 index 0000000..948ba70 --- /dev/null +++ b/socket/conn-read.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package socket + +import "io" + +func (c *SocketConn) read(n int) ([]byte, error) { + p, err := c.BuffReader.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.BuffReader.Discard(len(p)) + return p, err +} diff --git a/socket/conn.go b/socket/conn.go new file mode 100644 index 0000000..0626f79 --- /dev/null +++ b/socket/conn.go @@ -0,0 +1,1106 @@ +package socket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +type Conn interface { + Close() error + CloseHandler() func(code int, text string) error + EnableWriteCompression(enable bool) + LocalAddr() net.Addr + NextReader() (messageType int, r io.Reader, err error) + NextWriter(messageType int) (io.WriteCloser, error) + PingHandler() func(appData string) error + PongHandler() func(appData string) error + ReadJSON(v interface{}) error + ReadMessage() (messageType int, p []byte, err error) + RemoteAddr() net.Addr + SetCloseHandler(h func(code int, text string) error) + SetCompressionLevel(level int) error + SetPingHandler(h func(appData string) error) + SetPongHandler(h func(appData string) error) + SetReadDeadline(t time.Time) error + SetReadLimit(limit int64) + SetWriteDeadline(t time.Time) error + Subprotocol() string + UnderlyingConn() net.Conn + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteJSON(v interface{}) error + WriteMessage(messageType int, data []byte) error + WritePreparedMessage(pm *PreparedMessage) error +} + +// The SocketConn type represents a WebSocket connection. +type SocketConn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan bool // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + BuffReader *bufio.Reader + readRemaining int64 // bytes remaining in current frame. + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser +} + +func NewConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *SocketConn { + return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) +} + +type writeHook struct { + p []byte +} + +func (wh *writeHook) Write(p []byte) (int, error) { + wh.p = p + return len(p), nil +} + +func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *SocketConn { + mu := make(chan bool, 1) + mu <- true + + var br *bufio.Reader + if readBufferSize == 0 && brw != nil && brw.Reader != nil { + // Reuse the supplied bufio.Reader if the buffer has a useful size. + // This code assumes that peek on a reader returns + // bufio.Reader.buf[:0]. + brw.Reader.Reset(conn) + if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { + br = brw.Reader + } + } + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + var writeBuf []byte + if writeBufferSize == 0 && brw != nil && brw.Writer != nil { + // Use the bufio.Writer's buffer if the buffer has a useful size. This + // code assumes that bufio.Writer.buf[:1] is passed to the + // bufio.Writer's underlying writer. + var wh writeHook + brw.Writer.Reset(&wh) + brw.Writer.WriteByte(0) + brw.Flush() + if cap(wh.p) >= maxFrameHeaderSize+256 { + writeBuf = wh.p[:cap(wh.p)] + } + } + + if writeBuf == nil { + if writeBufferSize == 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) + } + + c := &SocketConn{ + isServer: isServer, + BuffReader: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// Close closes the underlying network connection without sending or waiting for a close frame. +func (c *SocketConn) Close() error { + return c.conn.Close() +} + +func (c *SocketConn) Subprotocol() string { + return c.subprotocol +} + +func (c *SocketConn) SetSubprotocol(subprotocol string) { + c.subprotocol = subprotocol +} + +// LocalAddr returns the local network address. +func (c *SocketConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *SocketConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *SocketConn) SetNewCompressionWriter(w func(io.WriteCloser, int) io.WriteCloser) { + c.newCompressionWriter = w +} +func (c *SocketConn) SetNewDecompressionReader(r func(io.Reader) io.ReadCloser) { + c.newDecompressionReader = r +} + +// Write methods + +func (c *SocketConn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *SocketConn) write(frameType int, deadline time.Time, bufs ...[]byte) error { + <-c.mu + defer func() { c.mu <- true }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + for _, buf := range bufs { + if len(buf) > 0 { + _, err := c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + } + } + + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return nil +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *SocketConn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := time.Hour * 1000 + if !deadline.IsZero() { + d = deadline.Sub(time.Now()) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- true }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err +} + +func (c *SocketConn) prepWrite(messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + return err +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { +// if err := c.prepWrite(messageType); err != nil { +// return nil, err +// } + +// mw := &messageWriter{ +// c: c, +// frameType: messageType, +// pos: maxFrameHeaderSize, +// } +// c.writer = mw +// if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { +// w := c.newCompressionWriter(c.writer, c.compressionLevel) +// mw.compress = true +// c.writer = w +// } +// return c.writer, nil +// } + +// NextWriterWithUseCompress returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +func (c *SocketConn) NextWriter(messageType int) (io.WriteCloser, error) { + if err := c.prepWrite(messageType); err != nil { + return nil, err + } + + mw := &messageWriter{ + c: c, + frameType: messageType, + pos: maxFrameHeaderSize, + } + c.writer = mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *SocketConn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) fatal(err error) error { + if w.err != nil { + w.err = err + w.c.writer = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.fatal(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.fatal(err) + } + + if final { + c.writer = nil + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + if err := w.flushFrame(true, nil); err != nil { + return err + } + w.err = errWriteClosed + return nil +} + +// WritePreparedMessage writes prepared message into connection. +func (c *SocketConn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *SocketConn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + if err := c.prepWrite(messageType); err != nil { + return err + } + mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *SocketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *SocketConn) advanceFrame() (int, error) { + + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.BuffReader, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + mask := p[1]&maskBit != 0 + c.readRemaining = int64(p[1] & 0x7f) + + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + return noFrame, c.handleProtocolError("control frame length > 125") + } + if !final { + return noFrame, c.handleProtocolError("control frame not final") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + return noFrame, c.handleProtocolError("message start before final message frame") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + return noFrame, c.handleProtocolError("continuation after final message frame") + } + c.readFinal = final + default: + return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + } + + // 3. Read and parse frame length. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint16(p)) + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + c.readRemaining = int64(binary.BigEndian.Uint64(p)) + } + + // 4. Handle frame masking. + + if mask != c.isServer { + return noFrame, c.handleProtocolError("incorrect mask flag") + } + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.readRemaining = 0 + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("invalid close code") + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *SocketConn) handleProtocolError(message string) error { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *SocketConn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *SocketConn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.BuffReader.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + c.readRemaining -= int64(n) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *SocketConn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *SocketConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size for a message read from the peer. If a +// message exceeds the limit, the connection sends a close frame to the peer +// and returns ErrReadLimit to the application. +func (c *SocketConn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *SocketConn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close frame +// back to the peer. +// +// The application must read the connection to process close messages as +// described in the section on Control Frames above. +// +// The connection read methods return a CloseError when a close frame is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close frame back to +// the peer. +func (c *SocketConn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := []byte{} + if code != CloseNoStatusReceived { + message = FormatCloseMessage(code, "") + } + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *SocketConn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING frame application data. The default +// ping handler sends a pong to the peer. +// +// The application must read the connection to process ping messages as +// described in the section on Control Frames above. +func (c *SocketConn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *SocketConn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG frame application data. The default +// pong handler does nothing. +// +// The application must read the connection to process ping messages as +// described in the section on Control Frames above. +func (c *SocketConn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +func (c *SocketConn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *SocketConn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *SocketConn) SetCompressionLevel(level int) error { + if !IsValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +func FormatCloseMessage(closeCode int, text string) []byte { + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/socket/net/client/connector.go b/socket/net/client/connector.go new file mode 100644 index 0000000..8158a58 --- /dev/null +++ b/socket/net/client/connector.go @@ -0,0 +1,219 @@ +package client + +import ( + "crypto/tls" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go/socket" + "git.loafle.net/overflow/server-go/socket/client" +) + +type Connectors struct { + client.Connectors + + Network string `json:"network,omitempty"` + Address string `json:"address,omitempty"` + LocalAddress net.Addr `json:"-"` + + stopChan chan struct{} + stopWg sync.WaitGroup + + readChan chan socket.SocketMessage + writeChan chan socket.SocketMessage + + disconnectedChan chan struct{} + reconnectedChan chan socket.Conn + + crw socket.ClientReadWriter + + validated atomic.Value +} + +func (c *Connectors) Connect() (readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage, err error) { + var ( + conn socket.Conn + ) + + if nil != c.stopChan { + return nil, nil, fmt.Errorf("%s already connected", c.logHeader()) + } + + conn, err = c.connect() + if nil != err { + return nil, nil, err + } + + c.readChan = make(chan socket.SocketMessage, 256) + c.writeChan = make(chan socket.SocketMessage, 256) + c.disconnectedChan = make(chan struct{}) + c.reconnectedChan = make(chan socket.Conn) + c.stopChan = make(chan struct{}) + + c.crw.ReadwriteHandler = c + c.crw.ReadChan = c.readChan + c.crw.WriteChan = c.writeChan + c.crw.ClientStopChan = c.stopChan + c.crw.ClientStopWg = &c.stopWg + c.crw.DisconnectedChan = c.disconnectedChan + c.crw.ReconnectedChan = c.reconnectedChan + + c.stopWg.Add(1) + go c.handleReconnect() + c.stopWg.Add(1) + go c.crw.HandleConnection(conn) + + return c.readChan, c.writeChan, nil +} + +func (c *Connectors) Disconnect() error { + if c.stopChan == nil { + return fmt.Errorf("%s must be connected before disconnection it", c.logHeader()) + } + close(c.stopChan) + c.stopWg.Wait() + + c.stopChan = nil + + return nil +} + +func (c *Connectors) logHeader() string { + return fmt.Sprintf("Connector[%s]: ", c.Name) +} + +func (c *Connectors) onDisconnected() { + close(c.readChan) + close(c.writeChan) + + c.reconnectedChan <- nil + + onDisconnected := c.OnDisconnected + if nil != onDisconnected { + go func() { + onDisconnected(c) + }() + } +} + +func (c *Connectors) handleReconnect() { + defer func() { + c.stopWg.Done() + }() + +RC_LOOP: + for { + select { + case <-c.disconnectedChan: + case <-c.stopChan: + return + } + + if 0 >= c.GetReconnectTryTime() { + c.onDisconnected() + return + } + + olog.Logger().Debugf("%s connection lost", c.logHeader()) + + for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ { + olog.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI) + + conn, err := c.connect() + if nil == err { + olog.Logger().Debugf("reconnected") + c.reconnectedChan <- conn + continue RC_LOOP + } + time.Sleep(c.GetReconnectInterval()) + } + + olog.Logger().Debugf("%s reconnecting has been failed", c.logHeader()) + c.onDisconnected() + return + } +} + +func (c *Connectors) connect() (socket.Conn, error) { + netConn, err := c.dial() + if nil != err { + return nil, err + } + + conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize()) + conn.SetCloseHandler(func(code int, text string) error { + olog.Logger().Debugf("%s close", c.logHeader()) + return nil + }) + return conn, nil +} + +func (c *Connectors) dial() (net.Conn, error) { + var deadline time.Time + if 0 != c.GetHandshakeTimeout() { + deadline = time.Now().Add(c.GetHandshakeTimeout()) + } + + d := &net.Dialer{ + KeepAlive: c.GetKeepAlive(), + Deadline: deadline, + LocalAddr: c.LocalAddress, + } + + conn, err := d.Dial(c.Network, c.Address) + if nil != err { + return nil, err + } + + if nil != c.GetTLSConfig() { + cfg := c.GetTLSConfig().Clone() + tlsConn := tls.Client(conn, cfg) + if err := tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + conn = tlsConn + } + + return conn, nil +} + +func (c *Connectors) Clone() client.Connector { + return &Connectors{ + Connectors: *c.Connectors.Clone(), + Network: c.Network, + Address: c.Address, + LocalAddress: c.LocalAddress, + validated: c.validated, + } +} + +func (c *Connectors) Validate() error { + if nil != c.validated.Load() { + return nil + } + c.validated.Store(true) + + if err := c.Connectors.Validate(); nil != err { + return err + } + + if "" == c.Network { + return fmt.Errorf("%s Network is not valid", c.logHeader()) + } + + if "" == c.Address { + return fmt.Errorf("%s Address is not valid", c.logHeader()) + } + + return nil +} diff --git a/socket/net/server-handler.go b/socket/net/server-handler.go new file mode 100644 index 0000000..c433194 --- /dev/null +++ b/socket/net/server-handler.go @@ -0,0 +1,94 @@ +package net + +import ( + "net" + "sync/atomic" + + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" +) + +type ServerHandler interface { + socket.ServerHandler + + OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) + + RegisterServlet(servlet Servlet) + Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet +} + +type ServerHandlers struct { + socket.ServerHandlers + + servlet Servlet + + validated atomic.Value +} + +func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.Init(serverCtx); nil != err { + return err + } + + if nil != sh.servlet { + if err := sh.servlet.Init(serverCtx); nil != err { + return err + } + } + + return nil +} + +func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.OnStart(serverCtx); nil != err { + return err + } + + if nil != sh.servlet { + if err := sh.servlet.OnStart(serverCtx); nil != err { + return err + } + } + + return nil +} + +func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) { + if nil != sh.servlet { + sh.servlet.OnStop(serverCtx) + } + + sh.ServerHandlers.OnStop(serverCtx) +} + +func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { + if nil != sh.servlet { + sh.servlet.Destroy(serverCtx) + } + + sh.ServerHandlers.Destroy(serverCtx) +} + +func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, conn net.Conn, status int, reason error) { +} + +func (sh *ServerHandlers) RegisterServlet(servlet Servlet) { + sh.servlet = servlet +} + +func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, conn net.Conn) Servlet { + return sh.servlet +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + + return nil +} diff --git a/socket/net/server.go b/socket/net/server.go new file mode 100644 index 0000000..ca68bae --- /dev/null +++ b/socket/net/server.go @@ -0,0 +1,167 @@ +package net + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" +) + +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + stopChan chan struct{} + stopWg sync.WaitGroup + + srw socket.ServerReadWriter +} + +func (s *Server) ListenAndServe() error { + if s.stopChan != nil { + return fmt.Errorf("%s already running. Stop it before starting it again", s.logHeader()) + } + + var ( + err error + listener net.Listener + ) + if nil == s.ServerHandler { + return fmt.Errorf("%s server handler must be specified", s.logHeader()) + } + if err = s.ServerHandler.Validate(); nil != err { + return err + } + + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf("%s ServerCtx is nil", s.logHeader()) + } + + if err = s.ServerHandler.Init(s.ctx); nil != err { + return err + } + + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + return err + } + + s.stopChan = make(chan struct{}) + + s.srw.ReadwriteHandler = s.ServerHandler + s.srw.ServerStopChan = s.stopChan + s.srw.ServerStopWg = &s.stopWg + + s.stopWg.Add(1) + return s.handleServer(listener) +} + +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf("%s must be started before stopping it", s.logHeader()) + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) logHeader() string { + return fmt.Sprintf("Server[%s]:", s.ServerHandler.GetName()) +} + +func (s *Server) handleServer(listener net.Listener) error { + var ( + stopping atomic.Value + netConn net.Conn + err error + ) + + defer func() { + if nil != listener { + listener.Close() + } + s.ServerHandler.OnStop(s.ctx) + olog.Logger().Infof("%s Stopped", s.logHeader()) + s.stopWg.Done() + }() + + if err = s.ServerHandler.OnStart(s.ctx); nil != err { + return err + } + + olog.Logger().Infof("%s Started", s.logHeader()) + + for { + acceptChan := make(chan struct{}) + + go func() { + if netConn, err = listener.Accept(); err != nil { + if nil == stopping.Load() { + olog.Logger().Errorf("%s %v", s.logHeader(), err) + } + } + close(acceptChan) + }() + + select { + case <-s.stopChan: + stopping.Store(true) + listener.Close() + <-acceptChan + listener = nil + return nil + case <-acceptChan: + } + + if nil != err { + select { + case <-s.stopChan: + return nil + case <-time.After(time.Second): + } + continue + } + + if 0 < s.ServerHandler.GetConcurrency() { + sz := s.srw.ConnectionSize() + if sz >= s.ServerHandler.GetConcurrency() { + olog.Logger().Warnf("%s max connections size %d, refuse", s.logHeader(), sz) + netConn.Close() + continue + } + } + + servlet := s.ServerHandler.(ServerHandler).Servlet(s.ctx, netConn) + if nil == servlet { + olog.Logger().Errorf("%s Servlet is nil", s.logHeader()) + continue + } + + servletCtx := servlet.ServletCtx(s.ctx) + if nil == servletCtx { + olog.Logger().Errorf("%s ServletCtx is nil", s.logHeader()) + continue + } + + if err := servlet.Handshake(servletCtx, netConn); nil != err { + olog.Logger().Infof("%s Handshaking of Client[%s] has been failed %v", s.logHeader(), netConn.RemoteAddr(), err) + continue + } + + conn := socket.NewConn(netConn, true, s.ServerHandler.GetReadBufferSize(), s.ServerHandler.GetWriteBufferSize()) + + s.stopWg.Add(1) + go s.srw.HandleConnection(servlet, servletCtx, conn) + } +} diff --git a/socket/net/servlet.go b/socket/net/servlet.go new file mode 100644 index 0000000..daa2b4b --- /dev/null +++ b/socket/net/servlet.go @@ -0,0 +1,54 @@ +package net + +import ( + "net" + + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" +) + +type Servlet interface { + socket.Servlet + + Handshake(servletCtx server.ServletCtx, conn net.Conn) error +} + +type Servlets struct { + Servlet +} + +func (s *Servlets) ServletCtx(serverCtx server.ServerCtx) server.ServletCtx { + return server.NewServletContext(nil, serverCtx) +} + +func (s *Servlets) Init(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStart(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStop(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Destroy(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Handshake(servletCtx server.ServletCtx, conn net.Conn) error { + return nil +} + +func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) { + // +} + +func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage) { + +} + +func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { + // +} diff --git a/socket/readwrite-handler.go b/socket/readwrite-handler.go new file mode 100644 index 0000000..10c2d4f --- /dev/null +++ b/socket/readwrite-handler.go @@ -0,0 +1,107 @@ +package socket + +import ( + "errors" + "sync/atomic" + "time" + + "git.loafle.net/overflow/server-go" +) + +type ReadWriteHandler interface { + server.ReadWriteHandler + GetPongTimeout() time.Duration + GetPingTimeout() time.Duration + GetPingPeriod() time.Duration + + IsEnableCompression() bool + GetCompressionLevel() int + GetCompressionThreshold() int +} + +type ReadWriteHandlers struct { + server.ReadWriteHandlers + + PongTimeout time.Duration `json:"pongTimeout,omitempty"` + PingTimeout time.Duration `json:"pingTimeout,omitempty"` + PingPeriod time.Duration `json:"pingPeriod,omitempty"` + + EnableCompression bool `json:"enableCompression,omitempty"` + CompressionLevel int `json:"compressionLevel,omitempty"` + CompressionThreshold int `json:"compressionThreshold,omitempty"` + + validated atomic.Value +} + +func (rwh *ReadWriteHandlers) GetPongTimeout() time.Duration { + return rwh.PongTimeout +} +func (rwh *ReadWriteHandlers) GetPingTimeout() time.Duration { + return rwh.PingTimeout +} +func (rwh *ReadWriteHandlers) GetPingPeriod() time.Duration { + return rwh.PingPeriod +} + +func (rwh *ReadWriteHandlers) IsEnableCompression() bool { + return rwh.EnableCompression +} +func (rwh *ReadWriteHandlers) GetCompressionLevel() int { + return rwh.CompressionLevel +} +func (rwh *ReadWriteHandlers) GetCompressionThreshold() int { + return rwh.CompressionThreshold +} + +func (rwh *ReadWriteHandlers) Clone() *ReadWriteHandlers { + return &ReadWriteHandlers{ + ReadWriteHandlers: *rwh.ReadWriteHandlers.Clone(), + PongTimeout: rwh.PongTimeout, + PingTimeout: rwh.PingTimeout, + PingPeriod: rwh.PingPeriod, + EnableCompression: rwh.EnableCompression, + CompressionLevel: rwh.CompressionLevel, + CompressionThreshold: rwh.CompressionThreshold, + validated: rwh.validated, + } +} + +func (rwh *ReadWriteHandlers) Validate() error { + if nil != rwh.validated.Load() { + return nil + } + rwh.validated.Store(true) + + if err := rwh.ReadWriteHandlers.Validate(); nil != err { + return err + } + + if rwh.PongTimeout <= 0 { + rwh.PongTimeout = server.DefaultPongTimeout + } else { + rwh.PongTimeout = rwh.PongTimeout * time.Second + } + + if rwh.PingTimeout <= 0 { + rwh.PingTimeout = server.DefaultPingTimeout + } else { + rwh.PingTimeout = rwh.PingTimeout * time.Second + } + if rwh.PingPeriod <= 0 { + rwh.PingPeriod = (rwh.PingTimeout * 9) / 10 + } else { + rwh.PingPeriod = rwh.PingPeriod * time.Second + } + + if rwh.EnableCompression { + if !IsValidCompressionLevel(rwh.CompressionLevel) { + return errors.New("Socket: invalid compression level") + } + } + + if 0 > rwh.CompressionThreshold { + rwh.CompressionThreshold = server.DefaultCompressionThreshold + } + + return nil +} diff --git a/socket/readwrite.go b/socket/readwrite.go new file mode 100644 index 0000000..3063102 --- /dev/null +++ b/socket/readwrite.go @@ -0,0 +1,123 @@ +package socket + +import ( + "fmt" + "time" + + olog "git.loafle.net/overflow/log-go" +) + +type SocketMessage func() (int, []byte) + +func MakeSocketMessage(messageType int, message []byte) SocketMessage { + return func() (int, []byte) { return messageType, message } +} + +func connReadHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, readChan chan<- SocketMessage) { + var ( + err error + ) + + defer func() { + doneChan <- err + }() + + if 0 < readWriteHandler.GetMaxMessageSize() { + conn.SetReadLimit(readWriteHandler.GetMaxMessageSize()) + } + + if 0 < readWriteHandler.GetReadTimeout() { + conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetReadTimeout())) + } else { + conn.SetReadDeadline(time.Time{}) + } + conn.SetPongHandler(func(string) error { + if 0 < readWriteHandler.GetPongTimeout() { + conn.SetReadDeadline(time.Now().Add(readWriteHandler.GetPongTimeout())) + } else { + conn.SetReadDeadline(time.Time{}) + } + return nil + }) + + for { + var message []byte + var messageType int + readMessageChan := make(chan struct{}) + + go func() { + messageType, message, err = conn.ReadMessage() + close(readMessageChan) + }() + + select { + case <-stopChan: + return + case <-readMessageChan: + } + + if nil != err { + if IsUnexpectedCloseError(err, CloseGoingAway, CloseAbnormalClosure) { + err = fmt.Errorf("Read error %v", err) + } + olog.Logger().Debug(err.Error()) + return + } + + readChan <- MakeSocketMessage(messageType, message) + } +} + +func connWriteHandler(readWriteHandler ReadWriteHandler, conn Conn, stopChan <-chan struct{}, doneChan chan<- error, writeChan <-chan SocketMessage) { + var ( + socketMessage SocketMessage + message []byte + messageType int + ok bool + err error + ) + + defer func() { + doneChan <- err + }() + + ticker := time.NewTicker(readWriteHandler.GetPingPeriod()) + defer func() { + ticker.Stop() + }() + for { + select { + case socketMessage, ok = <-writeChan: + if 0 < readWriteHandler.GetWriteTimeout() { + conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetWriteTimeout())) + } else { + conn.SetWriteDeadline(time.Time{}) + } + if !ok { + conn.WriteMessage(CloseMessage, []byte{}) + return + } + + messageType, message = socketMessage() + + err = conn.WriteMessage(messageType, message) + if err != nil { + olog.Logger().Debug(err.Error()) + return + } + + case <-ticker.C: + if 0 < readWriteHandler.GetPingTimeout() { + conn.SetWriteDeadline(time.Now().Add(readWriteHandler.GetPingTimeout())) + } else { + conn.SetWriteDeadline(time.Time{}) + } + if err = conn.WriteMessage(PingMessage, nil); nil != err { + olog.Logger().Debug(err.Error()) + return + } + case <-stopChan: + return + } + } +} diff --git a/socket/server-handler.go b/socket/server-handler.go new file mode 100644 index 0000000..70f8225 --- /dev/null +++ b/socket/server-handler.go @@ -0,0 +1,35 @@ +package socket + +import ( + "sync/atomic" + + "git.loafle.net/overflow/server-go" +) + +type ServerHandler interface { + server.ServerHandler + ReadWriteHandler +} + +type ServerHandlers struct { + server.ServerHandlers + ReadWriteHandlers + + validated atomic.Value +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + if err := sh.ReadWriteHandlers.Validate(); nil != err { + return err + } + + return nil +} diff --git a/socket/server-readwriter.go b/socket/server-readwriter.go new file mode 100644 index 0000000..7640e45 --- /dev/null +++ b/socket/server-readwriter.go @@ -0,0 +1,84 @@ +package socket + +import ( + "sync" + + "go.uber.org/zap" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go" +) + +type ServerReadWriter struct { + connections sync.Map + + ReadwriteHandler ReadWriteHandler + ServerStopChan <-chan struct{} + ServerStopWg *sync.WaitGroup +} + +func (srw *ServerReadWriter) ConnectionSize() int { + var sz int + srw.connections.Range(func(k, v interface{}) bool { + sz++ + return true + }) + return sz +} + +func (srw *ServerReadWriter) HandleConnection(servlet Servlet, servletCtx server.ServletCtx, conn Conn) { + addr := conn.RemoteAddr() + + defer func() { + if nil != conn { + conn.Close() + } + servlet.OnDisconnect(servletCtx) + olog.Logger().Info("Client has been disconnected", zap.String("Address", addr.String())) + srw.ServerStopWg.Done() + }() + + olog.Logger().Info("Client has been connected", zap.String("Address", addr.String())) + + srw.connections.Store(conn, true) + defer srw.connections.Delete(conn) + + servlet.OnConnect(servletCtx, conn) + conn.SetCloseHandler(func(code int, text string) error { + olog.Logger().Debug("close") + return nil + }) + + stopChan := make(chan struct{}) + servletDoneChan := make(chan struct{}) + + readChan := make(chan SocketMessage) + writeChan := make(chan SocketMessage) + + readerDoneChan := make(chan error) + writerDoneChan := make(chan error) + + go connReadHandler(srw.ReadwriteHandler, conn, stopChan, readerDoneChan, readChan) + go connWriteHandler(srw.ReadwriteHandler, conn, stopChan, writerDoneChan, writeChan) + go servlet.Handle(servletCtx, stopChan, servletDoneChan, readChan, writeChan) + + select { + case <-readerDoneChan: + close(stopChan) + <-writerDoneChan + <-servletDoneChan + case <-writerDoneChan: + close(stopChan) + <-readerDoneChan + <-servletDoneChan + case <-servletDoneChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + case <-srw.ServerStopChan: + close(stopChan) + <-readerDoneChan + <-writerDoneChan + <-servletDoneChan + } +} diff --git a/socket/servlet.go b/socket/servlet.go new file mode 100644 index 0000000..0bf59f8 --- /dev/null +++ b/socket/servlet.go @@ -0,0 +1,13 @@ +package socket + +import ( + "git.loafle.net/overflow/server-go" +) + +type Servlet interface { + server.Servlet + + OnConnect(servletCtx server.ServletCtx, conn Conn) + Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan SocketMessage, writeChan chan<- SocketMessage) + OnDisconnect(servletCtx server.ServletCtx) +} diff --git a/socket/web/client/connector.go b/socket/web/client/connector.go new file mode 100644 index 0000000..ec5ed17 --- /dev/null +++ b/socket/web/client/connector.go @@ -0,0 +1,534 @@ +package client + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go/socket" + "git.loafle.net/overflow/server-go/socket/client" + "git.loafle.net/overflow/server-go/socket/web" +) + +var errMalformedURL = errors.New("malformed ws or wss URL") + +type Connectors struct { + client.Connectors + + URL string `json:"url,omitempty"` + + RequestHeader func() http.Header `json:"-"` + + Subprotocols []string `json:"subprotocols,omitempty"` + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + CookieJar http.CookieJar `json:"-"` + + ResponseHandler func(*http.Response) `json:"-"` + + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dial is used. + NetDial func(network, addr string) (net.Conn, error) `json:"-"` + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) `json:"-"` + + serverURL *url.URL + + stopChan chan struct{} + stopWg sync.WaitGroup + + readChan chan socket.SocketMessage + writeChan chan socket.SocketMessage + + disconnectedChan chan struct{} + reconnectedChan chan socket.Conn + + crw socket.ClientReadWriter + + validated atomic.Value +} + +func (c *Connectors) Connect() (readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage, err error) { + var ( + conn socket.Conn + res *http.Response + ) + + if c.stopChan != nil { + return nil, nil, fmt.Errorf("%s already connected", c.logHeader()) + } + + conn, res, err = c.connect() + if nil != err { + return nil, nil, err + } + resH := c.ResponseHandler + if nil != resH { + resH(res) + } + + c.readChan = make(chan socket.SocketMessage, 256) + c.writeChan = make(chan socket.SocketMessage, 256) + c.disconnectedChan = make(chan struct{}) + c.reconnectedChan = make(chan socket.Conn) + c.stopChan = make(chan struct{}) + + c.crw.ReadwriteHandler = c + c.crw.ReadChan = c.readChan + c.crw.WriteChan = c.writeChan + c.crw.ClientStopChan = c.stopChan + c.crw.ClientStopWg = &c.stopWg + c.crw.DisconnectedChan = c.disconnectedChan + c.crw.ReconnectedChan = c.reconnectedChan + + c.stopWg.Add(1) + go c.handleReconnect() + c.stopWg.Add(1) + go c.crw.HandleConnection(conn) + + return c.readChan, c.writeChan, nil +} + +func (c *Connectors) Disconnect() error { + if c.stopChan == nil { + return fmt.Errorf("%s must be connected before disconnection it", c.logHeader()) + } + close(c.stopChan) + c.stopWg.Wait() + + c.stopChan = nil + + return nil +} + +func (c *Connectors) logHeader() string { + return fmt.Sprintf("Connector[%s]:", c.Name) +} + +func (c *Connectors) onDisconnected() { + close(c.readChan) + close(c.writeChan) + + c.reconnectedChan <- nil + + onDisconnected := c.OnDisconnected + if nil != onDisconnected { + go func() { + onDisconnected(c) + }() + } +} + +func (c *Connectors) handleReconnect() { + defer func() { + c.stopWg.Done() + }() + +RC_LOOP: + for { + select { + case <-c.disconnectedChan: + case <-c.stopChan: + return + } + + if 0 >= c.GetReconnectTryTime() { + c.onDisconnected() + return + } + + olog.Logger().Debugf("%s connection lost", c.logHeader()) + + for indexI := 0; indexI < c.GetReconnectTryTime(); indexI++ { + olog.Logger().Debugf("%s trying reconnect[%d]", c.logHeader(), indexI) + + conn, res, err := c.connect() + if nil == err { + resH := c.ResponseHandler + if nil != resH { + resH(res) + } + + olog.Logger().Debugf("%s reconnected", c.logHeader()) + c.reconnectedChan <- conn + continue RC_LOOP + } + time.Sleep(c.GetReconnectInterval()) + } + olog.Logger().Debugf("%s reconnecting has been failed", c.logHeader()) + + c.onDisconnected() + return + } +} + +func (c *Connectors) connect() (socket.Conn, *http.Response, error) { + conn, res, err := c.dial() + if nil != err { + return nil, nil, err + } + + conn.SetCloseHandler(func(code int, text string) error { + olog.Logger().Debugf("%s close", c.logHeader()) + return nil + }) + return conn, res, nil +} + +func (c *Connectors) dial() (socket.Conn, *http.Response, error) { + var ( + err error + challengeKey string + netConn net.Conn + ) + + challengeKey, err = web.GenerateChallengeKey() + if err != nil { + return nil, nil, err + } + + req := &http.Request{ + Method: "GET", + URL: c.serverURL, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: c.serverURL.Host, + } + + cookieJar := c.CookieJar + // Set the cookies present in the cookie jar of the dialer + if nil != cookieJar { + for _, cookie := range cookieJar.Cookies(c.serverURL) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + + subprotocols := c.Subprotocols + + if len(subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(subprotocols, ", ")} + } + + for k, vs := range c.RequestHeader() { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(subprotocols) > 0): + return nil, nil, fmt.Errorf("%s duplicate header not allowed: %s", c.logHeader(), k) + default: + req.Header[k] = vs + } + } + + if c.IsEnableCompression() { + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + + hostPort, hostNoPort := hostPortNoPort(c.serverURL) + + var proxyURL *url.URL + // Check wether the proxy method has been configured + proxy := c.Proxy + if nil != proxy { + proxyURL, err = proxy(req) + if err != nil { + return nil, nil, err + } + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort + } + + var deadline time.Time + handshakeTimeout := c.GetHandshakeTimeout() + if 0 != handshakeTimeout { + deadline = time.Now().Add(handshakeTimeout) + } + + netDial := c.NetDial + if netDial == nil { + netDialer := &net.Dialer{Deadline: deadline} + netDial = netDialer.Dial + } + + netConn, err = netDial("tcp", targetHostPort) + if err != nil { + return nil, nil, err + } + + defer func() { + if nil != netConn { + netConn.Close() + } + }() + + err = netConn.SetDeadline(deadline) + if nil != err { + return nil, nil, err + } + + if nil != proxyURL { + connectHeader := make(http.Header) + if user := proxyURL.User; nil != user { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: connectHeader, + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + + if "https" == c.serverURL.Scheme { + cfg := cloneTLSConfig(c.GetTLSConfig()) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + if err := tlsConn.Handshake(); err != nil { + return nil, nil, err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return nil, nil, err + } + } + } + + conn := socket.NewConn(netConn, false, c.GetReadBufferSize(), c.GetWriteBufferSize()) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(conn.BuffReader, req) + if err != nil { + return nil, nil, err + } + + if nil != cookieJar { + if rc := resp.Cookies(); len(rc) > 0 { + cookieJar.SetCookies(c.serverURL, rc) + } + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != web.ComputeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, socket.ErrBadHandshake + } + + for _, ext := range web.HttpParseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, socket.ErrInvalidCompression + } + conn.SetNewCompressionWriter(socket.CompressNoContextTakeover) + conn.SetNewDecompressionReader(socket.DecompressNoContextTakeover) + break + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.SetSubprotocol(resp.Header.Get("Sec-Websocket-Protocol")) + + netConn.SetDeadline(time.Time{}) + netConn = nil // to avoid close in defer. + + return conn, resp, nil +} + +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. +func parseURL(s string) (*url.URL, error) { + // From the RFC: + // + // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] + // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] + var u url.URL + switch { + case strings.HasPrefix(s, "ws://"): + u.Scheme = "ws" + s = s[len("ws://"):] + case strings.HasPrefix(s, "wss://"): + u.Scheme = "wss" + s = s[len("wss://"):] + default: + return nil, errMalformedURL + } + + if i := strings.Index(s, "?"); i >= 0 { + u.RawQuery = s[i+1:] + s = s[:i] + } + + if i := strings.Index(s, "/"); i >= 0 { + u.Opaque = s[i:] + s = s[:i] + } else { + u.Opaque = "/" + } + + u.Host = s + + if strings.Contains(u.Host, "@") { + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. + return nil, errMalformedURL + } + + return &u, nil +} + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +func (c *Connectors) Clone() client.Connector { + return &Connectors{ + Connectors: *c.Connectors.Clone(), + URL: c.URL, + RequestHeader: c.RequestHeader, + Subprotocols: c.Subprotocols, + CookieJar: c.CookieJar, + ResponseHandler: c.ResponseHandler, + NetDial: c.NetDial, + Proxy: c.Proxy, + serverURL: c.serverURL, + validated: c.validated, + } +} + +func (c *Connectors) Validate() error { + if nil != c.validated.Load() { + return nil + } + c.validated.Store(true) + + if err := c.Connectors.Validate(); nil != err { + return err + } + + if "" == c.URL { + return fmt.Errorf("URL is not valid") + } + + u, err := parseURL(c.URL) + if nil != err { + return err + } + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return errMalformedURL + } + if nil != u.User { + // User name and password are not allowed in websocket URIs. + return errMalformedURL + } + + c.serverURL = u + + if nil == c.Proxy { + c.Proxy = http.ProxyFromEnvironment + } + + return nil +} diff --git a/socket/web/server-handler.go b/socket/web/server-handler.go new file mode 100644 index 0000000..774dc2e --- /dev/null +++ b/socket/web/server-handler.go @@ -0,0 +1,126 @@ +package web + +import ( + "net/http" + "sync/atomic" + + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" + + "github.com/valyala/fasthttp" +) + +type ServerHandler interface { + socket.ServerHandler + + OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) + + RegisterServlet(path string, servlet Servlet) + Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet + + CheckOrigin(ctx *fasthttp.RequestCtx) bool +} + +type ServerHandlers struct { + socket.ServerHandlers + + servlets map[string]Servlet + + validated atomic.Value +} + +func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.Init(serverCtx); nil != err { + return err + } + + if nil != sh.servlets { + for _, servlet := range sh.servlets { + if err := servlet.Init(serverCtx); nil != err { + return err + } + } + } + + return nil +} + +func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.OnStart(serverCtx); nil != err { + return err + } + + if nil != sh.servlets { + for _, servlet := range sh.servlets { + if err := servlet.OnStart(serverCtx); nil != err { + return err + } + } + } + + return nil +} + +func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) { + if nil != sh.servlets { + for _, servlet := range sh.servlets { + servlet.OnStop(serverCtx) + } + } + + sh.ServerHandlers.OnStop(serverCtx) +} + +func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { + if nil != sh.servlets { + for _, servlet := range sh.servlets { + servlet.Destroy(serverCtx) + } + } + + sh.ServerHandlers.Destroy(serverCtx) +} + +func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, status int, reason error) { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.Error(http.StatusText(status), status) +} + +func (sh *ServerHandlers) RegisterServlet(path string, servlet Servlet) { + if nil == sh.servlets { + sh.servlets = make(map[string]Servlet) + } + sh.servlets[path] = servlet +} + +func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { + path := string(ctx.Path()) + + var servlet Servlet + if path == "" && len(sh.servlets) == 1 { + for _, s := range sh.servlets { + servlet = s + } + } else if servlet = sh.servlets[path]; nil == servlet { + return nil + } + + return servlet +} + +func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool { + return true +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + + return nil +} diff --git a/socket/web/server.go b/socket/web/server.go new file mode 100644 index 0000000..fd13ec8 --- /dev/null +++ b/socket/web/server.go @@ -0,0 +1,197 @@ +package web + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "sync/atomic" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" + "github.com/valyala/fasthttp" +) + +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + stopChan chan struct{} + stopWg sync.WaitGroup + + srw socket.ServerReadWriter + + hs *fasthttp.Server + upgrader *Upgrader +} + +func (s *Server) ListenAndServe() error { + var ( + err error + listener net.Listener + ) + if nil == s.ServerHandler { + return fmt.Errorf("%s server handler must be specified", s.logHeader()) + } + s.ServerHandler.Validate() + + if s.stopChan != nil { + return fmt.Errorf("%s already running. Stop it before starting it again", s.logHeader()) + } + + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf("%s ServerCtx is nil", s.logHeader()) + } + + s.hs = &fasthttp.Server{ + Handler: s.httpHandler, + Name: s.ServerHandler.GetName(), + Concurrency: s.ServerHandler.GetConcurrency(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + ReadTimeout: s.ServerHandler.GetReadTimeout(), + WriteTimeout: s.ServerHandler.GetWriteTimeout(), + } + + s.upgrader = &Upgrader{ + HandshakeTimeout: s.ServerHandler.GetHandshakeTimeout(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + CheckOrigin: s.ServerHandler.(ServerHandler).CheckOrigin, + Error: s.onError, + EnableCompression: s.ServerHandler.IsEnableCompression(), + CompressionLevel: s.ServerHandler.GetCompressionLevel(), + } + + if err = s.ServerHandler.Init(s.ctx); nil != err { + olog.Logger().Errorf("%s Init has been failed %v", s.logHeader(), err) + return err + } + + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + return err + } + + s.stopChan = make(chan struct{}) + + s.srw.ReadwriteHandler = s.ServerHandler + s.srw.ServerStopChan = s.stopChan + s.srw.ServerStopWg = &s.stopWg + + s.stopWg.Add(1) + return s.handleServer(listener) +} + +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf("%s must be started before stopping it", s.logHeader()) + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) logHeader() string { + return fmt.Sprintf("Server[%s]:", s.ServerHandler.GetName()) +} + +func (s *Server) handleServer(listener net.Listener) error { + var ( + err error + stopping atomic.Value + ) + + defer func() { + if nil != listener { + listener.Close() + } + s.ServerHandler.OnStop(s.ctx) + + olog.Logger().Infof("%s Stopped", s.logHeader()) + + s.stopWg.Done() + }() + + if err = s.ServerHandler.OnStart(s.ctx); nil != err { + olog.Logger().Errorf("%s OnStart has been failed %v", s.logHeader(), err) + return err + } + + hsCloseChan := make(chan error) + go func() { + if err := s.hs.Serve(listener); nil != err { + if nil == stopping.Load() { + hsCloseChan <- err + return + } + } + hsCloseChan <- nil + }() + + olog.Logger().Infof("%s Started", s.logHeader()) + + select { + case err, _ := <-hsCloseChan: + if nil != err { + return err + } + case <-s.stopChan: + stopping.Store(true) + listener.Close() + <-hsCloseChan + listener = nil + } + + return nil +} + +func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { + var ( + servlet Servlet + err error + ) + + if 0 < s.ServerHandler.GetConcurrency() { + sz := s.srw.ConnectionSize() + if sz >= s.ServerHandler.GetConcurrency() { + olog.Logger().Warnf("%s max connections size %d, refuse", s.logHeader(), sz) + s.onError(ctx, fasthttp.StatusServiceUnavailable, err) + return + } + } + + if servlet = s.ServerHandler.Servlet(s.ctx, ctx); nil == servlet { + s.onError(ctx, fasthttp.StatusInternalServerError, err) + return + } + + var responseHeader *fasthttp.ResponseHeader + servletCtx := servlet.ServletCtx(s.ctx) + + if responseHeader, err = servlet.Handshake(servletCtx, ctx); nil != err { + s.onError(ctx, http.StatusNotAcceptable, fmt.Errorf("Handshake err: %v", err)) + return + } + + s.upgrader.Upgrade(ctx, responseHeader, func(conn *socket.SocketConn, err error) { + if err != nil { + s.onError(ctx, fasthttp.StatusInternalServerError, err) + return + } + + s.stopWg.Add(1) + s.srw.HandleConnection(servlet, servletCtx, conn) + }) +} + +func (s *Server) onError(ctx *fasthttp.RequestCtx, status int, reason error) { + s.ServerHandler.(ServerHandler).OnError(s.ctx, ctx, status, reason) +} diff --git a/socket/web/servlet.go b/socket/web/servlet.go new file mode 100644 index 0000000..54a3cea --- /dev/null +++ b/socket/web/servlet.go @@ -0,0 +1,53 @@ +package web + +import ( + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/socket" + "github.com/valyala/fasthttp" +) + +type Servlet interface { + socket.Servlet + + Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) +} + +type Servlets struct { + Servlet +} + +func (s *Servlets) ServletCtx(serverCtx server.ServerCtx) server.ServletCtx { + return server.NewServletContext(nil, serverCtx) +} + +func (s *Servlets) Init(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStart(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStop(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Destroy(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Handshake(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) (*fasthttp.ResponseHeader, error) { + return nil, nil +} + +func (s *Servlets) OnConnect(servletCtx server.ServletCtx, conn socket.Conn) { + // +} + +func (s *Servlets) Handle(servletCtx server.ServletCtx, stopChan <-chan struct{}, doneChan chan<- struct{}, readChan <-chan socket.SocketMessage, writeChan chan<- socket.SocketMessage) { + +} + +func (s *Servlets) OnDisconnect(servletCtx server.ServletCtx) { + // +} diff --git a/socket/web/upgrade.go b/socket/web/upgrade.go new file mode 100644 index 0000000..c3c8ba2 --- /dev/null +++ b/socket/web/upgrade.go @@ -0,0 +1,279 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package web + +import ( + "net" + "net/http" + "net/url" + "strings" + "time" + + "git.loafle.net/overflow/server-go/socket" + "github.com/valyala/fasthttp" +) + +type ( + OnUpgradeFunc func(*socket.SocketConn, error) +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is set, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(ctx *fasthttp.RequestCtx, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, the host in the Origin header must not be set or + // must match the host of the request. + CheckOrigin func(ctx *fasthttp.RequestCtx) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + CompressionLevel int +} + +func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) (*socket.SocketConn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(ctx, status, err) + } else { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.Error(http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(ctx *fasthttp.RequestCtx) bool { + origin := string(ctx.Request.Header.Peek("Origin")) + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin) + if err != nil { + return false + } + return u.Host == string(ctx.Host()) +} + +func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(ctx) + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return string(responseHeader.Peek("Sec-Websocket-Protocol")) + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-Websocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, cb OnUpgradeFunc) { + if !ctx.IsGet() { + cb(u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")) + return + } + + if nil != responseHeader { + if v := responseHeader.Peek("Sec-Websocket-Extensions"); nil != v { + cb(u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")) + return + } + } + + if !tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")) + return + } + + if !tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")) + return + } + + if !tokenListContainsValue(&ctx.Request.Header, "Sec-Websocket-Version", "13") { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")) + return + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(ctx) { + cb(u.returnError(ctx, fasthttp.StatusForbidden, "websocket: 'Origin' header value not allowed")) + return + } + + challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key")) + if challengeKey == "" { + cb(u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")) + return + } + + subprotocol := u.selectSubprotocol(ctx, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(&ctx.Request.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols) + ctx.Response.Header.Set("Upgrade", "websocket") + ctx.Response.Header.Set("Connection", "Upgrade") + ctx.Response.Header.Set("Sec-Websocket-Accept", ComputeAcceptKey(challengeKey)) + if subprotocol != "" { + ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol) + } + if compress { + ctx.Response.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + if nil != responseHeader { + responseHeader.VisitAll(func(key, value []byte) { + k := string(key) + v := string(value) + if k == "Sec-Websocket-Protocol" { + return + } + ctx.Response.Header.Set(k, v) + }) + } + + h := &fasthttp.RequestHeader{} + + //copy request headers in order to have access inside the Conn after + ctx.Request.Header.CopyTo(h) + + ctx.Hijack(func(netConn net.Conn) { + c := socket.NewConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c.SetSubprotocol(subprotocol) + if compress { + c.SetCompressionLevel(u.CompressionLevel) + c.SetNewCompressionWriter(socket.CompressNoContextTakeover) + c.SetNewDecompressionReader(socket.DecompressNoContextTakeover) + } + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + } + + if u.HandshakeTimeout > 0 { + netConn.SetWriteDeadline(time.Time{}) + } + + cb(c, nil) + }) +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// This function is deprecated, use websocket.Upgrader instead. +// +// The application is responsible for checking the request origin before +// calling Upgrade. An example implementation of the same origin policy is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", 403) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(ctx *fasthttp.RequestCtx, responseHeader *fasthttp.ResponseHeader, readBufSize, writeBufSize int, cb OnUpgradeFunc) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { + // allow all connections by default + return true + } + u.Upgrade(ctx, responseHeader, cb) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(ctx *fasthttp.RequestCtx) []string { + h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol"))) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool { + return tokenListContainsValue(&ctx.Request.Header, "Connection", "upgrade") && + tokenListContainsValue(&ctx.Request.Header, "Upgrade", "websocket") +} diff --git a/socket/web/util.go b/socket/web/util.go new file mode 100644 index 0000000..7614107 --- /dev/null +++ b/socket/web/util.go @@ -0,0 +1,272 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package web + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/valyala/fasthttp" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func ComputeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func GenerateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Octet types from RFC 2616. +var octetTypes [256]byte + +const ( + isTokenOctet = 1 << iota + isSpaceOctet +) + +func init() { + // From RFC 2616 + // + // OCTET = + // CHAR = + // CTL = + // CR = + // LF = + // SP = + // HT = + // <"> = + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1* + // qdtext = > + + for c := 0; c < 256; c++ { + var t byte + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 + if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { + t |= isSpaceOctet + } + if isChar && !isCtl && !isSeparator { + t |= isTokenOctet + } + octetTypes[c] = t + } +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpaceOctet == 0 { + break + } + } + return s[i:] +} + +func nextToken(s string) (token string, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isTokenOctet == 0 { + break + } + } + return s[:i], s[i:] +} + +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains token. +func tokenListContainsValue(header *fasthttp.RequestHeader, name string, value string) bool { + s := string(header.Peek(name)) + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + break + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + break + } + if strings.EqualFold(t, value) { + return true + } + if s == "" { + break + } + s = s[1:] + } + return false +} + +// parseExtensiosn parses WebSocket extensions from a header. +func parseExtensions(header *fasthttp.RequestHeader) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + s := string(header.Peek("Sec-Websocket-Extensions")) + var result []map[string]string +headers: + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + break headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + break headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + break headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + break headers + } + result = append(result, ext) + if s == "" { + break headers + } + s = s[1:] + } + + return result +} + +// parseExtensiosn parses WebSocket extensions from a header. +func HttpParseExtensions(header http.Header) []map[string]string { + + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} diff --git a/web/error.go b/web/error.go new file mode 100644 index 0000000..611291a --- /dev/null +++ b/web/error.go @@ -0,0 +1,25 @@ +package web + +import ( + ouc "git.loafle.net/overflow/util-go/ctx" +) + +const ( + ErrorKey = ouc.CtxtKey("ErrorKey") +) + +func NewError(code int, cause error) *Error { + return &Error{ + Code: code, + Cause: cause, + } +} + +type Error struct { + Code int + Cause error +} + +func (e *Error) Error() string { + return e.Cause.Error() +} diff --git a/web/fasthttp/server-handler.go b/web/fasthttp/server-handler.go new file mode 100644 index 0000000..a572ee0 --- /dev/null +++ b/web/fasthttp/server-handler.go @@ -0,0 +1,162 @@ +package fasthttp + +import ( + "fmt" + "strings" + "sync/atomic" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/web" + + "github.com/valyala/fasthttp" +) + +type ServerHandler interface { + web.ServerHandler + + OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error) + + RegisterServlet(path string, servlet Servlet) + Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet + + CheckOrigin(ctx *fasthttp.RequestCtx) bool +} + +type ServerHandlers struct { + web.ServerHandlers + + ErrorServelt Servlet `json:"-"` + + // path = context only. + // ex) /auth => /auth, /auth/member => /auth + servlets map[string]Servlet + + validated atomic.Value +} + +func (sh *ServerHandlers) Init(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.Init(serverCtx); nil != err { + return err + } + + if nil != sh.servlets { + for _, servlet := range sh.servlets { + if err := servlet.Init(serverCtx); nil != err { + return err + } + } + } + + return nil +} + +func (sh *ServerHandlers) OnStart(serverCtx server.ServerCtx) error { + if err := sh.ServerHandlers.OnStart(serverCtx); nil != err { + return err + } + + if nil != sh.servlets { + for _, servlet := range sh.servlets { + if err := servlet.OnStart(serverCtx); nil != err { + return err + } + } + } + + return nil +} + +func (sh *ServerHandlers) OnStop(serverCtx server.ServerCtx) { + if nil != sh.servlets { + for _, servlet := range sh.servlets { + servlet.OnStop(serverCtx) + } + } + + sh.ServerHandlers.OnStop(serverCtx) +} + +func (sh *ServerHandlers) Destroy(serverCtx server.ServerCtx) { + if nil != sh.servlets { + for _, servlet := range sh.servlets { + servlet.Destroy(serverCtx) + } + } + + sh.ServerHandlers.Destroy(serverCtx) +} + +func (sh *ServerHandlers) OnError(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx, err *web.Error) { + if nil != sh.ErrorServelt { + servletCtx := sh.ErrorServelt.ServletCtx(serverCtx) + servletCtx.SetAttribute(web.ErrorKey, err) + sh.ErrorServelt.Handle(servletCtx, ctx) + return + } + + ctx.Error(err.Cause.Error(), err.Code) +} + +func (sh *ServerHandlers) RegisterServlet(contextPath string, servlet Servlet) { + if nil == sh.servlets { + sh.servlets = make(map[string]Servlet) + } + servlet.setContextPath(contextPath) + sh.servlets[contextPath] = servlet +} + +func (sh *ServerHandlers) Servlet(serverCtx server.ServerCtx, ctx *fasthttp.RequestCtx) Servlet { + path := string(ctx.Path()) + contextPath, err := getContextPath(path) + if nil != err { + olog.Logger().Warnf("Bad Request %v", err) + return nil + } + + var servlet Servlet + if servlet = sh.servlets[contextPath]; nil == servlet { + olog.Logger().Warnf("Servlet is not exist for url[%s]", path) + return nil + } + + return servlet +} + +func (sh *ServerHandlers) CheckOrigin(ctx *fasthttp.RequestCtx) bool { + return true +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + + return nil +} + +func getContextPath(path string) (string, error) { + p := strings.TrimSpace(path) + + if !strings.HasPrefix(p, "/") { + return "", fmt.Errorf("path[%s] must started /", path) + } + p = p[1:] + + if strings.HasSuffix(p, "/") { + cpl := len(p) - 1 + p = p[:cpl] + } + + components := strings.Split(p, "/") + if 0 == len(components) { + return "", fmt.Errorf("path[%s] is not invalid", path) + } + + return fmt.Sprintf("/%s", components[0]), nil +} diff --git a/web/fasthttp/server.go b/web/fasthttp/server.go new file mode 100644 index 0000000..7a43b21 --- /dev/null +++ b/web/fasthttp/server.go @@ -0,0 +1,159 @@ +package fasthttp + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + + olog "git.loafle.net/overflow/log-go" + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/web" + "github.com/valyala/fasthttp" +) + +type Server struct { + ServerHandler ServerHandler + + ctx server.ServerCtx + stopChan chan struct{} + stopWg sync.WaitGroup + + hs *fasthttp.Server +} + +func (s *Server) ListenAndServe() error { + var ( + err error + listener net.Listener + ) + if nil == s.ServerHandler { + return fmt.Errorf("%s server handler must be specified", s.logHeader()) + } + s.ServerHandler.Validate() + + if s.stopChan != nil { + return fmt.Errorf("%s already running. Stop it before starting it again", s.logHeader()) + } + + s.ctx = s.ServerHandler.ServerCtx() + if nil == s.ctx { + return fmt.Errorf("%s ServerCtx is nil", s.logHeader()) + } + + s.hs = &fasthttp.Server{ + Handler: s.httpHandler, + Name: s.ServerHandler.GetName(), + Concurrency: s.ServerHandler.GetConcurrency(), + ReadBufferSize: s.ServerHandler.GetReadBufferSize(), + WriteBufferSize: s.ServerHandler.GetWriteBufferSize(), + ReadTimeout: s.ServerHandler.GetReadTimeout(), + WriteTimeout: s.ServerHandler.GetWriteTimeout(), + } + + if err = s.ServerHandler.Init(s.ctx); nil != err { + return err + } + + if listener, err = s.ServerHandler.Listener(s.ctx); nil != err { + return err + } + + s.stopChan = make(chan struct{}) + + s.stopWg.Add(1) + return s.handleServer(listener) +} + +func (s *Server) Shutdown(ctx context.Context) error { + if s.stopChan == nil { + return fmt.Errorf("%s must be started before stopping it", s.logHeader()) + } + close(s.stopChan) + s.stopWg.Wait() + + s.ServerHandler.Destroy(s.ctx) + + s.stopChan = nil + + return nil +} + +func (s *Server) logHeader() string { + return fmt.Sprintf("Server[%s]:", s.ServerHandler.GetName()) +} + +func (s *Server) handleServer(listener net.Listener) error { + var ( + err error + stopping atomic.Value + ) + + defer func() { + if nil != listener { + listener.Close() + } + s.ServerHandler.OnStop(s.ctx) + + olog.Logger().Infof("%s Stopped", s.logHeader()) + + s.stopWg.Done() + }() + + if err = s.ServerHandler.OnStart(s.ctx); nil != err { + return err + } + + hsCloseChan := make(chan error) + go func() { + if err := s.hs.Serve(listener); nil != err { + if nil == stopping.Load() { + hsCloseChan <- err + return + } + } + hsCloseChan <- nil + }() + + olog.Logger().Infof("%s Started", s.logHeader()) + + select { + case err, _ := <-hsCloseChan: + if nil != err { + return err + } + case <-s.stopChan: + stopping.Store(true) + listener.Close() + <-hsCloseChan + listener = nil + } + + return nil +} + +func (s *Server) httpHandler(ctx *fasthttp.RequestCtx) { + var ( + servlet Servlet + ) + + if s.ServerHandler.CheckOrigin(ctx) { + return + } + + if servlet = s.ServerHandler.Servlet(s.ctx, ctx); nil == servlet { + s.onError(ctx, web.NewError(fasthttp.StatusNotFound, fmt.Errorf("Not Found"))) + return + } + + servletCtx := servlet.ServletCtx(s.ctx) + + if err := servlet.Handle(servletCtx, ctx); nil != err { + s.onError(ctx, err) + } +} + +func (s *Server) onError(ctx *fasthttp.RequestCtx, err *web.Error) { + s.ServerHandler.OnError(s.ctx, ctx, err) +} diff --git a/web/fasthttp/servlet.go b/web/fasthttp/servlet.go new file mode 100644 index 0000000..41f87fe --- /dev/null +++ b/web/fasthttp/servlet.go @@ -0,0 +1,55 @@ +package fasthttp + +import ( + "strings" + + "git.loafle.net/overflow/server-go" + "git.loafle.net/overflow/server-go/web" + "github.com/valyala/fasthttp" +) + +type Servlet interface { + web.Servlet + + Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *web.Error + RequestPath(ctx *fasthttp.RequestCtx) string + setContextPath(contextPath string) +} + +type Servlets struct { + Servlet + + ContextPath string +} + +func (s *Servlets) ServletCtx(serverCtx server.ServerCtx) server.ServletCtx { + return server.NewServletContext(nil, serverCtx) +} + +func (s *Servlets) Init(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStart(serverCtx server.ServerCtx) error { + return nil +} + +func (s *Servlets) OnStop(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Destroy(serverCtx server.ServerCtx) { + // +} + +func (s *Servlets) Handle(servletCtx server.ServletCtx, ctx *fasthttp.RequestCtx) *web.Error { + return nil +} + +func (s *Servlets) setContextPath(contextPath string) { + s.ContextPath = contextPath +} + +func (s *Servlets) RequestPath(ctx *fasthttp.RequestCtx) string { + return strings.Replace(string(ctx.Path()), s.ContextPath, "", -1) +} diff --git a/web/server-handler.go b/web/server-handler.go new file mode 100644 index 0000000..407c105 --- /dev/null +++ b/web/server-handler.go @@ -0,0 +1,35 @@ +package web + +import ( + "sync/atomic" + + "git.loafle.net/overflow/server-go" +) + +type ServerHandler interface { + server.ServerHandler + server.ReadWriteHandler +} + +type ServerHandlers struct { + server.ServerHandlers + server.ReadWriteHandlers + + validated atomic.Value +} + +func (sh *ServerHandlers) Validate() error { + if nil != sh.validated.Load() { + return nil + } + sh.validated.Store(true) + + if err := sh.ServerHandlers.Validate(); nil != err { + return err + } + if err := sh.ReadWriteHandlers.Validate(); nil != err { + return err + } + + return nil +} diff --git a/web/servlet.go b/web/servlet.go new file mode 100644 index 0000000..1d6b22d --- /dev/null +++ b/web/servlet.go @@ -0,0 +1,9 @@ +package web + +import ( + "git.loafle.net/overflow/server-go" +) + +type Servlet interface { + server.Servlet +}