This commit is contained in:
crusader 2017-11-23 16:42:44 +09:00
parent 2a40c39f08
commit c1655c3e72
2 changed files with 21 additions and 30 deletions

46
cors.go
View File

@ -1,16 +1,14 @@
package cors_fasthttp package cors_fasthttp
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
"git.loafle.net/commons_go/logging" "git.loafle.net/commons_go/logging"
"github.com/valyala/fasthttp"
) )
// Cors http handler // Cors http handler
@ -20,9 +18,7 @@ type Cors interface {
// Cors http handler // Cors http handler
type cors struct { type cors struct {
ctx context.Context co CorsOptions
logger *zap.Logger
co CorsOptions
// Set to true when allowed origins contains a "*" // Set to true when allowed origins contains a "*"
allowedOriginsAll bool allowedOriginsAll bool
// Normalized list of plain allowed origins // Normalized list of plain allowed origins
@ -40,10 +36,8 @@ type cors struct {
} }
// New creates a new Cors handler with the provided options. // New creates a new Cors handler with the provided options.
func New(ctx context.Context, co CorsOptions) Cors { func New(co CorsOptions) Cors {
c := &cors{ c := &cors{
ctx: ctx,
logger: logging.WithContext(ctx),
co: co, co: co,
exposedHeaders: convert(co.ExposedHeaders, http.CanonicalHeaderKey), exposedHeaders: convert(co.ExposedHeaders, http.CanonicalHeaderKey),
} }
@ -109,14 +103,14 @@ func New(ctx context.Context, co CorsOptions) Cors {
} }
// Default creates a new Cors handler with default options. // Default creates a new Cors handler with default options.
func Default(ctx context.Context) Cors { func Default() Cors {
return New(ctx, CorsOptions{}) return New(CorsOptions{})
} }
// AllowAll create a new Cors handler with permissive configuration allowing all // AllowAll create a new Cors handler with permissive configuration allowing all
// origins with all standard methods with any header and credentials. // origins with all standard methods with any header and credentials.
func AllowAll(ctx context.Context) Cors { func AllowAll() Cors {
return New(ctx, CorsOptions{ return New(CorsOptions{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"}, AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"},
AllowedHeaders: []string{"*"}, AllowedHeaders: []string{"*"},
@ -129,7 +123,7 @@ func AllowAll(ctx context.Context) Cors {
func (c *cors) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler { func (c *cors) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) {
if string(ctx.Method()) == "OPTIONS" && ctx.Request.Header.Peek("Access-Control-Request-Method") != nil { if string(ctx.Method()) == "OPTIONS" && ctx.Request.Header.Peek("Access-Control-Request-Method") != nil {
c.logger.Info("Handler: Preflight request") logging.Logger().Info("Handler: Preflight request")
c.handlePreflight(ctx) c.handlePreflight(ctx)
// Preflight requests are standalone and should stop the chain as some other // Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example // middleware may not handle OPTIONS requests correctly. One typical example
@ -141,7 +135,7 @@ func (c *cors) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler {
ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetStatusCode(fasthttp.StatusOK)
} }
} else { } else {
c.logger.Info("Handler: Actual request") logging.Logger().Info("Handler: Actual request")
c.handleActualRequest(ctx) c.handleActualRequest(ctx)
h(ctx) h(ctx)
} }
@ -153,7 +147,7 @@ func (c *cors) handlePreflight(ctx *fasthttp.RequestCtx) {
origin := string(ctx.Request.Header.Peek("Origin")) origin := string(ctx.Request.Header.Peek("Origin"))
if string(ctx.Method()) != "OPTIONS" { if string(ctx.Method()) != "OPTIONS" {
c.logger.Info(fmt.Sprintf(" Preflight aborted: %s!=OPTIONS", string(ctx.Method()))) logging.Logger().Info(fmt.Sprintf(" Preflight aborted: %s!=OPTIONS", string(ctx.Method())))
return return
} }
// Always set Vary headers // Always set Vary headers
@ -164,22 +158,22 @@ func (c *cors) handlePreflight(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Add("Vary", "Access-Control-Request-Headers") ctx.Response.Header.Add("Vary", "Access-Control-Request-Headers")
if origin == "" { if origin == "" {
c.logger.Info(" Preflight aborted: empty origin") logging.Logger().Info(" Preflight aborted: empty origin")
return return
} }
if !c.isOriginAllowed(origin) { if !c.isOriginAllowed(origin) {
c.logger.Info(fmt.Sprintf(" Preflight aborted: origin '%s' not allowed", origin)) logging.Logger().Info(fmt.Sprintf(" Preflight aborted: origin '%s' not allowed", origin))
return return
} }
reqMethod := string(ctx.Request.Header.Peek("Access-Control-Request-Method")) reqMethod := string(ctx.Request.Header.Peek("Access-Control-Request-Method"))
if !c.isMethodAllowed(reqMethod) { if !c.isMethodAllowed(reqMethod) {
c.logger.Info(fmt.Sprintf(" Preflight aborted: method '%s' not allowed", reqMethod)) logging.Logger().Info(fmt.Sprintf(" Preflight aborted: method '%s' not allowed", reqMethod))
return return
} }
reqHeaders := parseHeaderList(string(ctx.Request.Header.Peek("Access-Control-Request-Headers"))) reqHeaders := parseHeaderList(string(ctx.Request.Header.Peek("Access-Control-Request-Headers")))
if !c.areHeadersAllowed(reqHeaders) { if !c.areHeadersAllowed(reqHeaders) {
c.logger.Info(fmt.Sprintf(" Preflight aborted: headers '%v' not allowed", reqHeaders)) logging.Logger().Info(fmt.Sprintf(" Preflight aborted: headers '%v' not allowed", reqHeaders))
return return
} }
if c.allowedOriginsAll && !c.co.AllowCredentials { if c.allowedOriginsAll && !c.co.AllowCredentials {
@ -201,7 +195,7 @@ func (c *cors) handlePreflight(ctx *fasthttp.RequestCtx) {
if c.co.MaxAge > 0 { if c.co.MaxAge > 0 {
ctx.Response.Header.Set("Access-Control-Max-Age", strconv.Itoa(c.co.MaxAge)) ctx.Response.Header.Set("Access-Control-Max-Age", strconv.Itoa(c.co.MaxAge))
} }
// c.logger.Info(fmt.Sprintf(" Preflight response headers: %v", ctx.Response.Header.)) // logging.Logger().Info(fmt.Sprintf(" Preflight response headers: %v", ctx.Response.Header.))
} }
// handleActualRequest handles simple cross-origin requests, actual request or redirects // handleActualRequest handles simple cross-origin requests, actual request or redirects
@ -210,17 +204,17 @@ func (c *cors) handleActualRequest(ctx *fasthttp.RequestCtx) {
method := string(ctx.Method()) method := string(ctx.Method())
if method == "OPTIONS" { if method == "OPTIONS" {
c.logger.Info(fmt.Sprintf(" Actual request no headers added: method == %s", method)) logging.Logger().Info(fmt.Sprintf(" Actual request no headers added: method == %s", method))
return return
} }
// Always set Vary, see https://github.com/rs/cors/issues/10 // Always set Vary, see https://github.com/rs/cors/issues/10
ctx.Response.Header.Add("Vary", "Origin") ctx.Response.Header.Add("Vary", "Origin")
if origin == "" { if origin == "" {
c.logger.Info(" Actual request no headers added: missing origin") logging.Logger().Info(" Actual request no headers added: missing origin")
return return
} }
if !c.isOriginAllowed(origin) { if !c.isOriginAllowed(origin) {
c.logger.Info(fmt.Sprintf(" Actual request no headers added: origin '%s' not allowed", origin)) logging.Logger().Info(fmt.Sprintf(" Actual request no headers added: origin '%s' not allowed", origin))
return return
} }
@ -229,7 +223,7 @@ func (c *cors) handleActualRequest(ctx *fasthttp.RequestCtx) {
// spec doesn't instruct to check the allowed methods for simple cross-origin requests. // spec doesn't instruct to check the allowed methods for simple cross-origin requests.
// We think it's a nice feature to be able to have control on those methods though. // We think it's a nice feature to be able to have control on those methods though.
if !c.isMethodAllowed(method) { if !c.isMethodAllowed(method) {
c.logger.Info(fmt.Sprintf(" Actual request no headers added: method '%s' not allowed", method)) logging.Logger().Info(fmt.Sprintf(" Actual request no headers added: method '%s' not allowed", method))
return return
} }

View File

@ -1,16 +1,13 @@
package main package main
import ( import (
"context"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"git.loafle.net/commons_go/cors_fasthttp" "git.loafle.net/commons_go/cors_fasthttp"
) )
func main() { func main() {
ctx := context.Background() c := cors_fasthttp.AllowAll()
c := cors_fasthttp.AllowAll(ctx)
fasthttp.ListenAndServe(":8080", c.Handler(Handler)) fasthttp.ListenAndServe(":8080", c.Handler(Handler))
} }