commit 044bb8093372359fda1fb15b56a9b1ab3063f32c Author: crusader Date: Thu Aug 31 16:30:44 2017 +0900 ing diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3733e36 --- /dev/null +++ b/.gitignore @@ -0,0 +1,68 @@ +# Created by .ignore support plugin (hsz.mobi) +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties +### Go template +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ +.idea/ +*.iml + +vendor/ +glide.lock +.DS_Store +dist/ +debug diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..2ca2b1d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,32 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug", + "type": "go", + "request": "launch", + "mode": "debug", + "remotePath": "", + "port": 2345, + "host": "127.0.0.1", + "program": "${workspaceRoot}/main.go", + "env": {}, + "args": [], + "showLog": true + }, + { + "name": "File Debug", + "type": "go", + "request": "launch", + "mode": "debug", + "remotePath": "", + "port": 2345, + "host": "127.0.0.1", + "program": "${fileDirname}", + "env": {}, + "args": [], + "showLog": true + } + + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..20af2f6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +// Place your settings in this file to overwrite default and user settings. +{ +} \ No newline at end of file diff --git a/cors.go b/cors.go new file mode 100644 index 0000000..9045fce --- /dev/null +++ b/cors.go @@ -0,0 +1,312 @@ +package cors_fasthttp + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/valyala/fasthttp" + "go.uber.org/zap" + + "git.loafle.net/commons_go/logging" +) + +// Cors http handler +type Cors interface { + Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler +} + +// Cors http handler +type cors struct { + ctx context.Context + logger *zap.Logger + co CorsOptions + // Set to true when allowed origins contains a "*" + allowedOriginsAll bool + // Normalized list of plain allowed origins + allowedOrigins []string + // List of allowed origins containing wildcards + allowedWOrigins []wildcard + // Set to true when allowed headers contains a "*" + allowedHeadersAll bool + // Normalized list of allowed headers + allowedHeaders []string + // Normalized list of allowed methods + allowedMethods []string + // Normalized list of exposed headers + exposedHeaders []string +} + +// New creates a new Cors handler with the provided options. +func New(ctx context.Context, co CorsOptions) Cors { + c := &cors{ + ctx: ctx, + logger: logging.WithContext(ctx), + co: co, + exposedHeaders: convert(co.ExposedHeaders, http.CanonicalHeaderKey), + } + + // Normalize options + // Note: for origins and methods matching, the spec requires a case-sensitive matching. + // As it may error prone, we chose to ignore the spec here. + + // Allowed Origins + // Allowed Origins + if len(co.AllowedOrigins) == 0 { + if co.AllowOriginFunc == nil { + // Default is all origins + c.allowedOriginsAll = true + } + } else { + c.allowedOrigins = []string{} + c.allowedWOrigins = []wildcard{} + for _, origin := range co.AllowedOrigins { + // Normalize + origin = strings.ToLower(origin) + if origin == "*" { + // If "*" is present in the list, turn the whole list into a match all + c.allowedOriginsAll = true + c.allowedOrigins = nil + c.allowedWOrigins = nil + break + } else if i := strings.IndexByte(origin, '*'); i >= 0 { + // Split the origin in two: start and end string without the * + w := wildcard{origin[0:i], origin[i+1 : len(origin)]} + c.allowedWOrigins = append(c.allowedWOrigins, w) + } else { + c.allowedOrigins = append(c.allowedOrigins, origin) + } + } + } + + // Allowed Headers + if len(co.AllowedHeaders) == 0 { + // Use sensible defaults + c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"} + } else { + // Origin is always appended as some browsers will always request for this header at preflight + c.allowedHeaders = convert(append(co.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) + for _, h := range co.AllowedHeaders { + if h == "*" { + c.allowedHeadersAll = true + c.allowedHeaders = nil + break + } + } + } + + // Allowed Methods + if len(co.AllowedMethods) == 0 { + // Default is spec's "simple" methods + c.allowedMethods = []string{"GET", "POST", "HEAD"} + } else { + c.allowedMethods = convert(co.AllowedMethods, strings.ToUpper) + } + + return c +} + +// Default creates a new Cors handler with default options. +func Default(ctx context.Context) Cors { + return New(ctx, CorsOptions{}) +} + +// AllowAll create a new Cors handler with permissive configuration allowing all +// origins with all standard methods with any header and credentials. +func AllowAll(ctx context.Context) Cors { + return New(ctx, CorsOptions{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"}, + AllowedHeaders: []string{"*"}, + AllowCredentials: true, + }) +} + +// Handler apply the CORS specification on the request, and add relevant CORS headers +// as necessary. +func (c *cors) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler { + return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + if string(ctx.Method()) == "OPTIONS" && ctx.Request.Header.Peek("Access-Control-Request-Method") != nil { + c.logger.Info("Handler: Preflight request") + c.handlePreflight(ctx) + // Preflight requests are standalone and should stop the chain as some other + // middleware may not handle OPTIONS requests correctly. One typical example + // is authentication middleware ; OPTIONS requests won't carry authentication + // headers (see #1) + if c.co.OptionsPassthrough { + h(ctx) + } else { + ctx.SetStatusCode(fasthttp.StatusOK) + } + } else { + c.logger.Info("Handler: Actual request") + c.handleActualRequest(ctx) + h(ctx) + } + }) +} + +// handlePreflight handles pre-flight CORS requests +func (c *cors) handlePreflight(ctx *fasthttp.RequestCtx) { + origin := string(ctx.Request.Header.Peek("Origin")) + + if string(ctx.Method()) != "OPTIONS" { + c.logger.Info(fmt.Sprintf(" Preflight aborted: %s!=OPTIONS", string(ctx.Method()))) + return + } + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + ctx.Response.Header.Add("Vary", "Origin") + ctx.Response.Header.Add("Vary", "Access-Control-Request-Method") + ctx.Response.Header.Add("Vary", "Access-Control-Request-Headers") + + if origin == "" { + c.logger.Info(" Preflight aborted: empty origin") + return + } + if !c.isOriginAllowed(origin) { + c.logger.Info(fmt.Sprintf(" Preflight aborted: origin '%s' not allowed", origin)) + return + } + + reqMethod := string(ctx.Request.Header.Peek("Access-Control-Request-Method")) + if !c.isMethodAllowed(reqMethod) { + c.logger.Info(fmt.Sprintf(" Preflight aborted: method '%s' not allowed", reqMethod)) + return + } + reqHeaders := parseHeaderList(string(ctx.Request.Header.Peek("Access-Control-Request-Headers"))) + if !c.areHeadersAllowed(reqHeaders) { + c.logger.Info(fmt.Sprintf(" Preflight aborted: headers '%v' not allowed", reqHeaders)) + return + } + if c.allowedOriginsAll && !c.co.AllowCredentials { + ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") + } else { + ctx.Response.Header.Set("Access-Control-Allow-Origin", origin) + } + // Spec says: Since the list of methods can be unbounded, simply returning the method indicated + // by Access-Control-Request-Method (if supported) can be enough + ctx.Response.Header.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) + if len(reqHeaders) > 0 { + // Spec says: Since the list of headers can be unbounded, simply returning supported headers + // from Access-Control-Request-Headers can be enough + ctx.Response.Header.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) + } + if c.co.AllowCredentials { + ctx.Response.Header.Set("Access-Control-Allow-Credentials", "true") + } + if c.co.MaxAge > 0 { + ctx.Response.Header.Set("Access-Control-Max-Age", strconv.Itoa(c.co.MaxAge)) + } + // c.logger.Info(fmt.Sprintf(" Preflight response headers: %v", ctx.Response.Header.)) +} + +// handleActualRequest handles simple cross-origin requests, actual request or redirects +func (c *cors) handleActualRequest(ctx *fasthttp.RequestCtx) { + origin := string(ctx.Request.Header.Peek("Origin")) + method := string(ctx.Method()) + + if method == "OPTIONS" { + c.logger.Info(fmt.Sprintf(" Actual request no headers added: method == %s", method)) + return + } + // Always set Vary, see https://github.com/rs/cors/issues/10 + ctx.Response.Header.Add("Vary", "Origin") + if origin == "" { + c.logger.Info(" Actual request no headers added: missing origin") + return + } + if !c.isOriginAllowed(origin) { + c.logger.Info(fmt.Sprintf(" Actual request no headers added: origin '%s' not allowed", origin)) + return + } + + // Note that spec does define a way to specifically disallow a simple method like GET or + // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the + // spec doesn't instruct to check the allowed methods for simple cross-origin requests. + // We think it's a nice feature to be able to have control on those methods though. + if !c.isMethodAllowed(method) { + c.logger.Info(fmt.Sprintf(" Actual request no headers added: method '%s' not allowed", method)) + + return + } + if c.allowedOriginsAll && !c.co.AllowCredentials { + ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") + } else { + ctx.Response.Header.Set("Access-Control-Allow-Origin", origin) + } + if len(c.exposedHeaders) > 0 { + ctx.Response.Header.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) + } + if c.co.AllowCredentials { + ctx.Response.Header.Set("Access-Control-Allow-Credentials", "true") + } + // c.logf(" Actual response added headers: %v", headers) +} + +// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests +// on the endpoint +func (c *cors) isOriginAllowed(origin string) bool { + if c.co.AllowOriginFunc != nil { + return c.co.AllowOriginFunc(origin) + } + if c.allowedOriginsAll { + return true + } + origin = strings.ToLower(origin) + for _, o := range c.allowedOrigins { + if o == origin { + return true + } + } + for _, w := range c.allowedWOrigins { + if w.match(origin) { + return true + } + } + return false +} + +// isMethodAllowed checks if a given method can be used as part of a cross-domain request +// on the endpoing +func (c *cors) isMethodAllowed(method string) bool { + if len(c.allowedMethods) == 0 { + // If no method allowed, always return false, even for preflight request + return false + } + method = strings.ToUpper(method) + if method == "OPTIONS" { + // Always allow preflight requests + return true + } + for _, m := range c.allowedMethods { + if m == method { + return true + } + } + return false +} + +// areHeadersAllowed checks if a given list of headers are allowed to used within +// a cross-domain request. +func (c *cors) areHeadersAllowed(requestedHeaders []string) bool { + if c.allowedHeadersAll || len(requestedHeaders) == 0 { + return true + } + for _, header := range requestedHeaders { + header = http.CanonicalHeaderKey(header) + found := false + for _, h := range c.allowedHeaders { + if h == header { + found = true + } + } + if !found { + return false + } + } + return true +} diff --git a/cors_options.go b/cors_options.go new file mode 100644 index 0000000..d20a421 --- /dev/null +++ b/cors_options.go @@ -0,0 +1,35 @@ +package cors_fasthttp + +type CorsOptions struct { + // AllowedOrigins is a list of origins a cross-domain request can be executed from. + // If the special "*" value is present in the list, all origins will be allowed. + // An origin may contain a wildcard (*) to replace 0 or more characters + // (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty. + // Only one wildcard can be used per origin. + // Default value is ["*"] + AllowedOrigins []string + // AllowOriginFunc is a custom function to validate the origin. It take the origin + // as argument and returns true if allowed or false otherwise. If this option is + // set, the content of AllowedOrigins is ignored. + AllowOriginFunc func(origin string) bool + // AllowedMethods is a list of methods the client is allowed to use with + // cross-domain requests. Default value is simple methods (HEAD, GET and POST). + AllowedMethods []string + // AllowedHeaders is list of non simple headers the client is allowed to use with + // cross-domain requests. + // If the special "*" value is present in the list, all headers will be allowed. + // Default value is [] but "Origin" is always appended to the list. + AllowedHeaders []string + // ExposedHeaders indicates which headers are safe to expose to the API of a CORS + // API specification + ExposedHeaders []string + // AllowCredentials indicates whether the request can include user credentials like + // cookies, HTTP authentication or client side SSL certificates. + AllowCredentials bool + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached + MaxAge int + // OptionsPassthrough instructs preflight to let other potential next handlers to + // process the OPTIONS method. Turn this on if your application handles OPTIONS. + OptionsPassthrough bool +} diff --git a/examples/server.go b/examples/server.go new file mode 100644 index 0000000..d5e5a45 --- /dev/null +++ b/examples/server.go @@ -0,0 +1,20 @@ +package main + +import ( + "context" + + "github.com/valyala/fasthttp" + + "git.loafle.net/commons_go/cors_fasthttp" +) + +func main() { + ctx := context.Background() + c := cors_fasthttp.New(ctx, cors_fasthttp.CorsOptions{}) + + fasthttp.ListenAndServe(":8080", c.Handler(Handler)) +} + +func Handler(ctx *fasthttp.RequestCtx) { + +} diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..d9aef45 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,5 @@ +package: git.loafle.net/commons_go/cors_fasthttp +import: +- package: git.loafle.net/commons_go/logging +- package: github.com/valyala/fasthttp + version: v20160617 diff --git a/util.go b/util.go new file mode 100644 index 0000000..3261dd1 --- /dev/null +++ b/util.go @@ -0,0 +1,70 @@ +package cors_fasthttp + +import "strings" + +const toLower = 'a' - 'A' + +type converter func(string) string + +type wildcard struct { + prefix string + suffix string +} + +func (w wildcard) match(s string) bool { + return len(s) >= len(w.prefix+w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) +} + +// convert converts a list of string using the passed converter function +func convert(s []string, c converter) []string { + out := []string{} + for _, i := range s { + out = append(out, c(i)) + } + return out +} + +// parseHeaderList tokenize + normalize a string containing a list of headers +func parseHeaderList(headerList string) []string { + l := len(headerList) + h := make([]byte, 0, l) + upper := true + // Estimate the number headers in order to allocate the right splice size + t := 0 + for i := 0; i < l; i++ { + if headerList[i] == ',' { + t++ + } + } + headers := make([]string, 0, t) + for i := 0; i < l; i++ { + b := headerList[i] + if b >= 'a' && b <= 'z' { + if upper { + h = append(h, b-toLower) + } else { + h = append(h, b) + } + } else if b >= 'A' && b <= 'Z' { + if !upper { + h = append(h, b+toLower) + } else { + h = append(h, b) + } + } else if b == '-' || b == '_' || (b >= '0' && b <= '9') { + h = append(h, b) + } + + if b == ' ' || b == ',' || i == l-1 { + if len(h) > 0 { + // Flush the found header + headers = append(headers, string(h)) + h = h[:0] + upper = true + } + } else { + upper = b == '-' || b == '_' + } + } + return headers +}