package runtime
import (
"context"
"errors"
"fmt"
"net/http"
"net/textproto"
"regexp"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnescapingMode selects how ServeMux treats percent-encoded characters in
// the request path when matching routes and extracting path parameters.
// The per-mode unescaping itself is performed by Pattern.MatchAndEscape,
// which is not defined in this file.
type UnescapingMode int

const (
	// UnescapingModeLegacy matches against the decoded r.URL.Path; ServeHTTP
	// ignores r.URL.RawPath in this mode.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved presumably unescapes everything except
	// reserved characters — TODO confirm against MatchAndEscape.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash presumably unescapes everything except an
	// encoded slash — TODO confirm against MatchAndEscape.
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters makes ServeHTTP split the request path on
	// the escaped "%2F" as well as on "/" (see encodedPathSplitter).
	UnescapingModeAllCharacters

	// UnescapingModeDefault is the mode NewServeMux uses unless a
	// WithUnescapingMode option overrides it.
	UnescapingModeDefault = UnescapingModeLegacy
)
// encodedPathSplitter splits a request path on a literal "/" and on the
// escaped form "%2F". ServeHTTP uses it only in UnescapingModeAllCharacters.
// NOTE(review): the pattern is case-sensitive, so a lowercase "%2f" is not
// treated as a separator — confirm this is intended.
var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
// HandlerFunc serves an HTTP request whose path matched a registered
// pattern. pathParams holds the values ServeHTTP extracted from the path
// via Pattern.MatchAndEscape.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// Middleware wraps a HandlerFunc and returns the wrapped HandlerFunc.
// ServeMux.Handle applies the configured middlewares when a handler is
// registered.
type Middleware func(HandlerFunc) HandlerFunc
// ServeMux is a request multiplexer for grpc-gateway. It matches incoming
// HTTP request paths against registered patterns and invokes the matching
// handler. Construct it with NewServeMux.
type ServeMux struct {
	// handlers maps an HTTP method to its handlers. ServeHTTP tries them in
	// slice order and Handle prepends, so later registrations are tried first.
	handlers map[string][]handler
	// middlewares are applied by Handle to each handler at registration time.
	middlewares []Middleware
	// forwardResponseOptions are appended by WithForwardResponseOption and
	// exposed through GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
	// forwardResponseRewriter defaults to an identity function (NewServeMux).
	forwardResponseRewriter ForwardResponseRewriter
	// marshalers is built by makeMarshalerMIMERegistry; presumably consulted
	// by MarshalerForRequest — not visible in this file.
	marshalers marshalerRegistry
	// incomingHeaderMatcher, outgoingHeaderMatcher and outgoingTrailerMatcher
	// are defaulted by NewServeMux when no option sets them.
	incomingHeaderMatcher  HeaderMatcherFunc
	outgoingHeaderMatcher  HeaderMatcherFunc
	outgoingTrailerMatcher HeaderMatcherFunc
	// metadataAnnotators are appended by WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD
	// errorHandler reports request errors (ParseForm failures, malformed
	// escape sequences) from ServeHTTP and the health endpoint.
	errorHandler ErrorHandlerFunc
	// streamErrorHandler is not used in this file.
	streamErrorHandler StreamErrorHandlerFunc
	// routingErrorHandler reports 400/404/405 routing failures from ServeHTTP.
	routingErrorHandler RoutingErrorHandlerFunc
	// disablePathLengthFallback turns off isPathLengthFallback entirely.
	disablePathLengthFallback bool
	// unescapingMode controls path splitting and parameter unescaping.
	unescapingMode UnescapingMode
	// writeContentLength is set by WithWriteContentLength; its consumer is
	// outside this file.
	writeContentLength bool
}
// ServeMuxOption configures a ServeMux. NewServeMux applies options in the
// order they are given.
type ServeMuxOption func(*ServeMux)

// ForwardResponseRewriter converts a response message into another value.
// NewServeMux installs one that returns the message unchanged.
// NOTE(review): presumably the result is what gets marshaled for the
// client; the call site is not in this file.
type ForwardResponseRewriter func(ctx context.Context, response proto.Message) (any, error)
// WithForwardResponseRewriter returns a ServeMuxOption that replaces the
// mux's forward response rewriter (an identity function by default) with
// fwdResponseRewriter.
func WithForwardResponseRewriter(fwdResponseRewriter ForwardResponseRewriter) ServeMuxOption {
	installRewriter := func(mux *ServeMux) {
		mux.forwardResponseRewriter = fwdResponseRewriter
	}
	return installRewriter
}
// WithForwardResponseOption returns a ServeMuxOption that appends
// forwardResponseOption to the mux's list of forward response options.
// Repeated uses accumulate, and registration order is preserved.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
	return func(mux *ServeMux) {
		existing := mux.forwardResponseOptions
		mux.forwardResponseOptions = append(existing, forwardResponseOption)
	}
}
// WithUnescapingMode returns a ServeMuxOption that sets how the mux splits
// the request path and unescapes path parameters. Without it, NewServeMux
// uses UnescapingModeDefault.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
	setMode := func(mux *ServeMux) {
		mux.unescapingMode = mode
	}
	return setMode
}
// WithMiddlewares returns a ServeMuxOption that appends middlewares to the
// mux's middleware chain.
//
// Handle wraps a handler with the middlewares present at the moment it is
// registered. Handlers registered earlier, including ones registered by
// options placed before this one (such as WithHealthEndpointAt), are not
// wrapped.
func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
	return func(mux *ServeMux) {
		chain := mux.middlewares
		mux.middlewares = append(chain, middlewares...)
	}
}
// SetQueryParameterParser returns a ServeMuxOption that installs
// queryParameterParser as the query parameter parser.
//
// NOTE(review): the option ignores the mux it is applied to and assigns the
// package-level variable currentQueryParser. The setting is therefore
// shared process-wide rather than scoped to one ServeMux, and the last mux
// constructed with this option wins.
func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
	return func(*ServeMux) {
		currentQueryParser = queryParameterParser
	}
}
// HeaderMatcherFunc takes a header key and returns the key it should be
// forwarded under, plus whether it should be forwarded at all. A false
// result means the header is dropped (see DefaultHeaderMatcher).
type HeaderMatcherFunc func(string) (string, bool)
// DefaultHeaderMatcher is the default incoming header matcher.
//
// The key is first canonicalized with textproto.CanonicalMIMEHeaderKey.
// Permanent HTTP headers are forwarded with MetadataPrefix prepended.
// Headers that start with MetadataHeaderPrefix are forwarded with that
// prefix removed. Every other header is rejected.
func DefaultHeaderMatcher(key string) (string, bool) {
	canonical := textproto.CanonicalMIMEHeaderKey(key)

	if isPermanentHTTPHeader(canonical) {
		return MetadataPrefix + canonical, true
	}
	if strings.HasPrefix(canonical, MetadataHeaderPrefix) {
		return strings.TrimPrefix(canonical, MetadataHeaderPrefix), true
	}
	return "", false
}
func defaultOutgoingHeaderMatcher(key string ) (string , bool ) {
return fmt .Sprintf ("%s%s" , MetadataHeaderPrefix , key ), true
}
func defaultOutgoingTrailerMatcher(key string ) (string , bool ) {
return fmt .Sprintf ("%s%s" , MetadataTrailerPrefix , key ), true
}
// WithIncomingHeaderMatcher returns a ServeMuxOption that sets the matcher
// deciding which incoming HTTP headers are forwarded to the gRPC server,
// and under which keys.
//
// The check for malformed headers runs when the option is created, not
// when it is applied. A warning is logged for each malformed header that
// fn would forward.
func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
	malformed := fn.matchedMalformedHeaders()
	for i := range malformed {
		grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", malformed[i])
	}

	setMatcher := func(mux *ServeMux) {
		mux.incomingHeaderMatcher = fn
	}
	return setMatcher
}
// matchedMalformedHeaders runs fn over every entry of malformedHTTPHeaders
// and returns the output keys that fn accepts and that are themselves
// malformed.
//
// It returns nil for a nil matcher. Otherwise it returns a non-nil slice,
// which may be empty. Order follows map iteration and is therefore
// unspecified.
func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
	if fn == nil {
		return nil
	}

	matched := []string{}
	for candidate := range malformedHTTPHeaders {
		forwarded, accepted := fn(candidate)
		if !accepted || !isMalformedHTTPHeader(forwarded) {
			continue
		}
		matched = append(matched, forwarded)
	}
	return matched
}
// WithOutgoingHeaderMatcher returns a ServeMuxOption that sets the matcher
// for gRPC response headers forwarded to the HTTP client. Passing nil makes
// NewServeMux fall back to defaultOutgoingHeaderMatcher.
func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
	setMatcher := func(mux *ServeMux) {
		mux.outgoingHeaderMatcher = fn
	}
	return setMatcher
}
// WithOutgoingTrailerMatcher returns a ServeMuxOption that sets the matcher
// for gRPC response trailers forwarded to the HTTP client. Passing nil makes
// NewServeMux fall back to defaultOutgoingTrailerMatcher.
func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
	setMatcher := func(mux *ServeMux) {
		mux.outgoingTrailerMatcher = fn
	}
	return setMatcher
}
// WithMetadata returns a ServeMuxOption that appends annotator to the mux's
// metadata annotators. Repeated uses accumulate in registration order.
// NOTE(review): annotators presumably contribute gRPC metadata derived from
// the HTTP request; their call site is outside this file.
func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
	return func(mux *ServeMux) {
		annotators := mux.metadataAnnotators
		mux.metadataAnnotators = append(annotators, annotator)
	}
}
// WithErrorHandler returns a ServeMuxOption that replaces the default
// DefaultHTTPErrorHandler. ServeHTTP and the health endpoint use it to
// report request errors.
func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
	setHandler := func(mux *ServeMux) {
		mux.errorHandler = fn
	}
	return setHandler
}
// WithStreamErrorHandler returns a ServeMuxOption that replaces the default
// DefaultStreamErrorHandler. It is not invoked anywhere in this file;
// presumably it is used by streaming response forwarding elsewhere.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
	setHandler := func(mux *ServeMux) {
		mux.streamErrorHandler = fn
	}
	return setHandler
}
// WithRoutingErrorHandler returns a ServeMuxOption that replaces the
// default DefaultRoutingErrorHandler. ServeHTTP calls it with 400, 404 or
// 405 when routing fails.
func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
	setHandler := func(mux *ServeMux) {
		mux.routingErrorHandler = fn
	}
	return setHandler
}
// WithDisablePathLengthFallback returns a ServeMuxOption that turns off the
// POST -> GET path-length fallback. With it applied, isPathLengthFallback
// always reports false.
func WithDisablePathLengthFallback() ServeMuxOption {
	return func(mux *ServeMux) { mux.disablePathLengthFallback = true }
}
// WithWriteContentLength returns a ServeMuxOption that sets the mux's
// writeContentLength flag. NOTE(review): the flag is read outside this
// file; presumably it makes responses carry a Content-Length header.
func WithWriteContentLength() ServeMuxOption {
	return func(mux *ServeMux) { mux.writeContentLength = true }
}
// WithHealthEndpointAt returns a ServeMuxOption that registers a GET handler
// at endpointPath. The handler queries the gRPC health-checking service
// through healthCheckClient.
//
// The optional "service" query parameter is forwarded as
// HealthCheckRequest.Service. Results are handled as follows:
//   - SERVING: the response is encoded with the request's outbound marshaler.
//   - SERVICE_UNKNOWN: reported as codes.NotFound through the mux's error
//     handler.
//   - Any other status, including NOT_SERVING, UNKNOWN and values this code
//     does not recognize: reported as codes.Unavailable.
//
// Errors from the Check RPC are passed to the error handler unchanged.
//
// endpointPath is caller-supplied and may be an invalid path pattern. An
// option cannot return an error, so a failed registration is logged rather
// than silently ignored.
func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
	return func(s *ServeMux) {
		regErr := s.HandlePath(http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string) {
			_, outboundMarshaler := MarshalerForRequest(s, r)

			resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
				Service: r.URL.Query().Get("service"),
			})
			if err != nil {
				s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
				return
			}

			w.Header().Set("Content-Type", "application/json")

			switch resp.GetStatus() {
			case grpc_health_v1.HealthCheckResponse_SERVING:
				if err := outboundMarshaler.NewEncoder(w).Encode(resp); err != nil {
					// The response may already be partially written, so the
					// failure can only be logged.
					grpclog.Errorf("Failed to encode health check response: %v", err)
				}
			case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
				s.errorHandler(r.Context(), s, outboundMarshaler, w, r, status.Error(codes.NotFound, resp.String()))
			default:
				// NOT_SERVING, UNKNOWN and any unrecognized status all mean
				// "not healthy". Mapping them here guarantees the error
				// handler never receives a nil error.
				s.errorHandler(r.Context(), s, outboundMarshaler, w, r, status.Error(codes.Unavailable, resp.String()))
			}
		})
		if regErr != nil {
			grpclog.Errorf("Failed to register health endpoint at %q: %v", endpointPath, regErr)
		}
	}
}
// WithHealthzEndpoint is WithHealthEndpointAt with the conventional
// "/healthz" path.
func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
	const healthzPath = "/healthz"
	return WithHealthEndpointAt(healthCheckClient, healthzPath)
}
// NewServeMux returns a new ServeMux configured by opts.
//
// Options are applied in the order given. Afterwards, every pluggable
// function that is still nil, or was set to nil by an option, is replaced
// by its default:
//   - incoming headers: DefaultHeaderMatcher
//   - outgoing headers / trailers: defaultOutgoingHeaderMatcher /
//     defaultOutgoingTrailerMatcher
//   - errors: DefaultHTTPErrorHandler, DefaultStreamErrorHandler and
//     DefaultRoutingErrorHandler
//   - forward response rewriter: returns the message unchanged
//
// The unescaping mode starts as UnescapingModeDefault.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
	// identityRewriter passes each response message through unchanged.
	identityRewriter := ForwardResponseRewriter(func(_ context.Context, response proto.Message) (any, error) {
		return response, nil
	})

	serveMux := &ServeMux{
		handlers:                make(map[string][]handler),
		forwardResponseOptions:  make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
		forwardResponseRewriter: identityRewriter,
		marshalers:              makeMarshalerMIMERegistry(),
		errorHandler:            DefaultHTTPErrorHandler,
		streamErrorHandler:      DefaultStreamErrorHandler,
		routingErrorHandler:     DefaultRoutingErrorHandler,
		unescapingMode:          UnescapingModeDefault,
	}

	for _, opt := range opts {
		opt(serveMux)
	}

	// Options such as WithErrorHandler(nil) can clear a default. ServeHTTP
	// calls errorHandler and routingErrorHandler without a nil check, so
	// restore every default here, just as the header matchers always were.
	if serveMux.forwardResponseRewriter == nil {
		serveMux.forwardResponseRewriter = identityRewriter
	}
	if serveMux.errorHandler == nil {
		serveMux.errorHandler = DefaultHTTPErrorHandler
	}
	if serveMux.streamErrorHandler == nil {
		serveMux.streamErrorHandler = DefaultStreamErrorHandler
	}
	if serveMux.routingErrorHandler == nil {
		serveMux.routingErrorHandler = DefaultRoutingErrorHandler
	}
	if serveMux.incomingHeaderMatcher == nil {
		serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
	}
	if serveMux.outgoingHeaderMatcher == nil {
		serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
	}
	if serveMux.outgoingTrailerMatcher == nil {
		serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
	}
	return serveMux
}
// Handle registers h for requests with HTTP method meth whose path matches
// pat.
//
// The middlewares configured at call time are wrapped around h. The new
// entry is placed at the front of the method's list, and ServeHTTP tries
// entries in list order, so a later registration takes precedence over an
// earlier one with an overlapping pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
	wrapped := h
	if len(s.middlewares) != 0 {
		wrapped = chainMiddlewares(s.middlewares)(wrapped)
	}

	entry := handler{pat: pat, h: wrapped}
	s.handlers[meth] = append([]handler{entry}, s.handlers[meth]...)
}
// HandlePath parses pathPattern as an HTTP rule path template, compiles it
// into a Pattern, and registers h for method meth via Handle.
//
// Parse and pattern-construction failures are returned wrapped, and
// nothing is registered in that case.
func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
	parsed, parseErr := httprule.Parse(pathPattern)
	if parseErr != nil {
		return fmt.Errorf("parsing path pattern: %w", parseErr)
	}

	tmpl := parsed.Compile()
	pat, patErr := NewPattern(tmpl.Version, tmpl.OpCodes, tmpl.Pool, tmpl.Verb)
	if patErr != nil {
		return fmt.Errorf("creating new pattern: %w", patErr)
	}

	s.Handle(meth, pat, h)
	return nil
}
// ServeHTTP dispatches r to the first registered handler whose pattern
// matches the request path.
//
// A path that does not start with "/" is rejected with 400. Otherwise
// dispatch proceeds in three stages:
//  1. Handlers registered for r.Method are tried in order. For requests that
//     qualify for path-length fallback, an X-HTTP-Method-Override header
//     first replaces r.Method.
//  2. If none match, handlers registered under other methods are checked.
//     A form-encoded POST matching a GET handler is served by it (the
//     POST -> GET path-length fallback); GET is tried first so the outcome
//     does not depend on map iteration order. Any other match yields 405.
//  3. If nothing matches, the routing error handler is called with 404.
//
// A path parameter with a malformed escape sequence is reported as 400, and
// dispatch stops there because a response has already been written.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	path := r.URL.Path
	if !strings.HasPrefix(path, "/") {
		_, outboundMarshaler := MarshalerForRequest(s, r)
		s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
		return
	}

	// Outside legacy mode, match against the still-escaped path so that
	// MatchAndEscape decides how percent-encodings are unescaped.
	if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
		path = r.URL.RawPath
	}

	if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
		if err := r.ParseForm(); err != nil {
			_, outboundMarshaler := MarshalerForRequest(s, r)
			sterr := status.Error(codes.InvalidArgument, err.Error())
			s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
			return
		}
		r.Method = strings.ToUpper(override)
	}

	var pathComponents []string
	if s.unescapingMode == UnescapingModeAllCharacters {
		pathComponents = encodedPathSplitter.Split(path[1:], -1)
	} else {
		pathComponents = strings.Split(path[1:], "/")
	}

	// Stage 1: handlers registered for the request's own method.
	for _, h := range s.handlers[r.Method] {
		comps, verb, idx := h.splitPathVerb(pathComponents)
		if idx == 0 {
			// The last component is just ":verb" with no resource before it.
			_, outboundMarshaler := MarshalerForRequest(s, r)
			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
			return
		}
		pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
		if err != nil {
			if s.handleMatchError(ctx, w, r, err) {
				return
			}
			continue
		}
		s.handleHandler(h, w, r, pathParams)
		return
	}

	// Stage 2: handlers registered under other methods. They decide between
	// the POST -> GET fallback and 405.
	fallback := s.isPathLengthFallback(r)
	methods := make([]string, 0, len(s.handlers))
	if fallback {
		methods = append(methods, http.MethodGet)
	}
	for m := range s.handlers {
		if m == r.Method || (fallback && m == http.MethodGet) {
			continue
		}
		methods = append(methods, m)
	}

	for _, m := range methods {
		for _, h := range s.handlers[m] {
			comps, verb, _ := h.splitPathVerb(pathComponents)
			pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
			if err != nil {
				if s.handleMatchError(ctx, w, r, err) {
					return
				}
				continue
			}

			// X-HTTP-Method-Override is optional, so a plain form POST may
			// fall back too. Only POST -> GET is allowed, to avoid reaching
			// potentially dangerous operations such as DELETE.
			if fallback && m == http.MethodGet {
				if err := r.ParseForm(); err != nil {
					_, outboundMarshaler := MarshalerForRequest(s, r)
					sterr := status.Error(codes.InvalidArgument, err.Error())
					s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
					return
				}
				s.handleHandler(h, w, r, pathParams)
				return
			}
			_, outboundMarshaler := MarshalerForRequest(s, r)
			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
			return
		}
	}

	// Stage 3: nothing matched.
	_, outboundMarshaler := MarshalerForRequest(s, r)
	s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
}

// splitPathVerb prepares pathComponents for matching against h's pattern.
//
// If the pattern has a verb and the last component ends in ":"+verb, that
// suffix is removed from a copy of the components and returned as verb.
// idx is the index of the separating colon in the last component, or -1
// when no verb suffix matched. When idx is 0 the component is only
// ":"+verb, and the returned components are left unsplit.
// pathComponents itself is never modified.
func (h handler) splitPathVerb(pathComponents []string) (comps []string, verb string, idx int) {
	last := pathComponents[len(pathComponents)-1]
	idx = -1
	if patVerb := h.pat.Verb(); patVerb != "" && strings.HasSuffix(last, ":"+patVerb) {
		idx = len(last) - len(patVerb) - 1
	}
	comps = make([]string, len(pathComponents))
	copy(comps, pathComponents)
	if idx > 0 {
		comps[len(comps)-1], verb = last[:idx], last[idx+1:]
	}
	return comps, verb, idx
}

// handleMatchError reports a MatchAndEscape error that is a
// MalformedSequenceError to the client as 400 and returns true, meaning a
// response was written and dispatch must stop. For any other error it
// writes nothing and returns false, so the caller moves on to the next
// handler.
func (s *ServeMux) handleMatchError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) bool {
	var mse MalformedSequenceError
	if !errors.As(err, &mse) {
		return false
	}
	_, outboundMarshaler := MarshalerForRequest(s, r)
	s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
		HTTPStatus: http.StatusBadRequest,
		Err:        mse,
	})
	return true
}
// GetForwardResponseOptions returns the forward response options added with
// WithForwardResponseOption, in registration order.
//
// The result shares its backing array with the mux (it is not a copy), so
// callers should treat it as read-only.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
	registered := s.forwardResponseOptions
	return registered
}
// isPathLengthFallback reports whether r qualifies for the POST -> GET
// path-length fallback. That requires three things: the fallback is not
// disabled on s, r is a POST, and r's body is form-encoded.
//
// The Content-Type check compares only the media type, case-insensitively,
// and ignores parameters such as "; charset=utf-8". This matches the bodies
// r.ParseForm accepts; a request with a charset parameter would otherwise
// miss the fallback.
func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
	if s.disablePathLengthFallback || r.Method != http.MethodPost {
		return false
	}
	mediaType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
	return strings.EqualFold(strings.TrimSpace(mediaType), "application/x-www-form-urlencoded")
}
// handler pairs a compiled path pattern with the function that serves
// requests matching it.
type handler struct {
	// pat is matched against the split request path by ServeHTTP.
	pat Pattern
	// h is the registered function, already wrapped by middlewares in Handle.
	h HandlerFunc
}
// handleHandler invokes h with the extracted path parameters. The request's
// context is first derived with withHTTPPattern so the matched pattern
// travels with the request (NOTE(review): withHTTPPattern is defined
// outside this file; presumably it stores h.pat as a context value).
func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
	patternCtx := withHTTPPattern(r.Context(), h.pat)
	h.h(w, r.WithContext(patternCtx), pathParams)
}
// chainMiddlewares combines mws into a single Middleware. The first element
// of mws becomes the outermost wrapper, so it runs first when a request
// arrives. An empty mws yields a Middleware that returns its argument
// unchanged.
func chainMiddlewares(mws []Middleware) Middleware {
	return func(final HandlerFunc) HandlerFunc {
		wrapped := final
		// Wrap from the innermost (last) middleware outward.
		for i := len(mws) - 1; i >= 0; i-- {
			wrapped = mws[i](wrapped)
		}
		return wrapped
	}
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.