package runtime
import (
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/durationpb"
field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// valuesKeyRegexp matches query keys of the form "name[key]". The first
// capture group is the text before the brackets and the second is the text
// inside them.
var valuesKeyRegexp = regexp .MustCompile (`^(.*)\[(.*)\]$` )
// currentQueryParser is the parser that PopulateQueryParameters delegates to.
// It is initialised to a DefaultQueryParser.
var currentQueryParser QueryParameterParser = &DefaultQueryParser {}
// QueryParameterParser is the interface satisfied by the parser that
// PopulateQueryParameters delegates to.
type QueryParameterParser interface {
// Parse receives the target message, the query values and a filter, and
// reports any failure as an error.
Parse (msg proto .Message , values url .Values , filter *utilities .DoubleArray ) error
}
// PopulateQueryParameters hands msg, values and filter to the package's
// current query parser and returns whatever error that parser reports.
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
	parser := currentQueryParser
	return parser.Parse(msg, values, filter)
}
// DefaultQueryParser is the QueryParameterParser installed as the package's
// current parser at initialisation. It carries no state.
type DefaultQueryParser struct {}
// Parse fills msg from the supplied query values.
//
// A key of the form "name[key]" is split by valuesKeyRegexp: "name" becomes
// the field path and "key" is prepended to that key's values. The path is
// split on "." and normalised against msg, and any path for which
// filter.HasCommonPrefix reports true is skipped. The first error returned
// while populating a field aborts the parse and is returned.
func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
	for rawKey, rawVals := range values {
		name, args := rawKey, rawVals
		if sub := valuesKeyRegexp.FindStringSubmatch(rawKey); len(sub) == 3 {
			// "name[key]": the bracketed key travels as the first value.
			name = sub[1]
			args = append([]string{sub[2]}, rawVals...)
		}

		target := msg.ProtoReflect()
		path := normalizeFieldPath(target, strings.Split(name, "."))
		if filter.HasCommonPrefix(path) {
			continue
		}

		err := populateFieldValueFromPath(target, path, args)
		if err != nil {
			return err
		}
	}
	return nil
}
// PopulateFieldFromPath sets the field of msg addressed by the
// dot-separated fieldPathString to the single textual value.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
	segments := strings.Split(fieldPathString, ".")
	single := []string{value}
	return populateFieldValueFromPath(msg.ProtoReflect(), segments, single)
}
// normalizeFieldPath rewrites every segment of fieldPath to the name reported
// by the matching field descriptor, walking down msgValue one segment at a
// time. A segment is looked up by text name first and by JSON name second.
//
// The original fieldPath is returned unchanged when a segment matches no
// field, or when a non-final segment is not a singular message field.
func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string {
	normalized := make([]string, 0, len(fieldPath))
	lastIdx := len(fieldPath) - 1

	for idx, segment := range fieldPath {
		candidates := msgValue.Descriptor().Fields()

		fd := candidates.ByTextName(segment)
		if fd == nil {
			fd = candidates.ByJSONName(segment)
		}
		if fd == nil {
			return fieldPath
		}
		normalized = append(normalized, string(fd.Name()))

		if idx == lastIdx {
			break
		}

		// Only singular message fields can be descended into.
		if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {
			return fieldPath
		}
		msgValue = msgValue.Get(fd).Message()
	}
	return normalized
}
// populateFieldValueFromPath walks fieldPath starting at msgValue and stores
// values in the field named by the final segment.
//
// It returns an error when fieldPath or values is empty, when a oneof that
// the path touches already has a different member set, when a non-final
// segment is not a singular message field, when a singular field receives
// more than one value, or when the value cannot be parsed. A segment that
// matches no field is logged and ignored (nil is returned).
func populateFieldValueFromPath(msgValue protoreflect .Message , fieldPath []string , values []string ) error {
if len (fieldPath ) < 1 {
return errors .New ("no field path" )
}
if len (values ) < 1 {
return errors .New ("no value provided" )
}
var fieldDescriptor protoreflect .FieldDescriptor
for i , fieldName := range fieldPath {
fields := msgValue .Descriptor ().Fields ()
// Look the segment up by proto field name first, then by JSON name.
fieldDescriptor = fields .ByName (protoreflect .Name (fieldName ))
if fieldDescriptor == nil {
fieldDescriptor = fields .ByJSONName (fieldName )
if fieldDescriptor == nil {
// Unknown fields are not an error: log and stop without changing anything further.
grpclog .Infof ("field not found in %q: %q" , msgValue .Descriptor ().FullName (), strings .Join (fieldPath , "." ))
return nil
}
}
// For a field in a non-synthetic oneof, refuse to overwrite a member that is
// already set. The one exception is when the set member is this same
// message-typed field, which lets several paths populate that one message.
if of := fieldDescriptor .ContainingOneof (); of != nil && !of .IsSynthetic () {
if f := msgValue .WhichOneof (of ); f != nil {
if fieldDescriptor .Message () == nil || fieldDescriptor .FullName () != f .FullName () {
return fmt .Errorf ("field already set for oneof %q" , of .FullName ().Name ())
}
}
}
// The final segment names the field to set; leave the loop with
// fieldDescriptor and msgValue pointing at it.
if i == len (fieldPath )-1 {
break
}
// Intermediate segments must be singular message fields so they can be descended into.
if fieldDescriptor .Message () == nil || fieldDescriptor .Cardinality () == protoreflect .Repeated {
return fmt .Errorf ("invalid path: %q is not a message" , fieldName )
}
// Continue with the nested message obtained through Mutable.
msgValue = msgValue .Mutable (fieldDescriptor ).Message ()
}
// Repeated and map fields accept several values; everything else takes exactly one.
switch {
case fieldDescriptor .IsList ():
return populateRepeatedField (fieldDescriptor , msgValue .Mutable (fieldDescriptor ).List (), values )
case fieldDescriptor .IsMap ():
return populateMapField (fieldDescriptor , msgValue .Mutable (fieldDescriptor ).Map (), values )
}
if len (values ) > 1 {
return fmt .Errorf ("too many values for field %q: %s" , fieldDescriptor .FullName ().Name (), strings .Join (values , ", " ))
}
return populateField (fieldDescriptor , msgValue , values [0 ])
}
// populateField parses value according to fieldDescriptor and stores the
// result in that field of msgValue. A parse failure is wrapped with the
// field's name and returned; msgValue is left untouched in that case.
func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
	parsed, parseErr := parseField(fieldDescriptor, value)
	if parseErr != nil {
		name := fieldDescriptor.FullName().Name()
		return fmt.Errorf("parsing field %q: %w", name, parseErr)
	}

	msgValue.Set(fieldDescriptor, parsed)
	return nil
}
// populateRepeatedField parses each entry of values according to
// fieldDescriptor and appends it to list, in order. It stops at the first
// entry that fails to parse and returns that error wrapped with the field's
// name; entries appended before the failure stay in list.
func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
	for idx := range values {
		elem, parseErr := parseField(fieldDescriptor, values[idx])
		if parseErr != nil {
			name := fieldDescriptor.FullName().Name()
			return fmt.Errorf("parsing list %q: %w", name, parseErr)
		}
		list.Append(elem)
	}
	return nil
}
// populateMapField stores one entry in the map field mp.
//
// values must contain exactly two elements: the textual map key followed by
// the textual map value. The key is parsed against the field's map-key
// descriptor and the value against its map-value descriptor.
//
// An error is returned when fewer or more than two elements are supplied, or
// when the key or the value fails to parse.
func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
	switch {
	case len(values) < 2:
		// Checked before any indexing so an empty slice cannot panic, and so
		// a missing key/value is not misreported as "more than one value".
		return fmt.Errorf("missing key or value for map %q", fieldDescriptor.FullName())
	case len(values) > 2:
		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
	}

	key, err := parseField(fieldDescriptor.MapKey(), values[0])
	if err != nil {
		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
	}

	value, err := parseField(fieldDescriptor.MapValue(), values[1])
	if err != nil {
		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
	}

	mp.Set(key.MapKey(), value)
	return nil
}
// parseField converts the textual value into a protoreflect.Value that
// matches the kind of fieldDescriptor.
//
// Booleans, integers and floats are parsed with strconv (integers in base 10,
// with the bit size of the target kind). Strings are used as-is, bytes fields
// go through the package's Bytes helper, and message and group fields are
// delegated to parseMessage. Enum fields are resolved through the global type
// registry, first by value name and then by value number.
//
// An error is returned when value cannot be parsed for the kind, when the
// enum type is not registered, or when the enum has no matching value. A kind
// that is not handled here panics.
func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
	switch fieldDescriptor.Kind() {
	case protoreflect.BoolKind:
		v, err := strconv.ParseBool(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfBool(v), nil
	case protoreflect.EnumKind:
		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
		if err != nil {
			if errors.Is(err, protoregistry.NotFound) {
				return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
			}
			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
		}
		enumValues := enum.Descriptor().Values()
		// Look for the enum value by name first.
		v := enumValues.ByName(protoreflect.Name(value))
		if v == nil {
			// Fall back to the numeric form. Parse with a 32-bit limit:
			// EnumNumber is an int32, so parsing wider and converting would
			// silently wrap out-of-range input (e.g. "4294967296" -> 0) onto
			// an unrelated number that may happen to be valid.
			i, err := strconv.ParseInt(value, 10, 32)
			if err != nil {
				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
			}
			if v = enumValues.ByNumber(protoreflect.EnumNumber(i)); v == nil {
				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
			}
		}
		return protoreflect.ValueOfEnum(v.Number()), nil
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		v, err := strconv.ParseInt(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfInt32(int32(v)), nil
	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		v, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfInt64(v), nil
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		v, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfUint32(uint32(v)), nil
	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		v, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfUint64(v), nil
	case protoreflect.FloatKind:
		v, err := strconv.ParseFloat(value, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfFloat32(float32(v)), nil
	case protoreflect.DoubleKind:
		v, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfFloat64(v), nil
	case protoreflect.StringKind:
		return protoreflect.ValueOfString(value), nil
	case protoreflect.BytesKind:
		v, err := Bytes(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfBytes(v), nil
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return parseMessage(fieldDescriptor.Message(), value)
	default:
		// Every kind handled above has a case; reaching this point means a
		// kind this function does not know about.
		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
	}
}
// parseMessage builds a message of the type named by msgDescriptor from its
// textual form and returns it wrapped as a message value.
//
// Handled types, selected by full name:
//   - Timestamp: parsed with the time.RFC3339Nano layout; rejected if the
//     resulting timestamp is not valid.
//   - Duration: parsed with time.ParseDuration.
//   - Double/Float/Int64/Int32/UInt64/UInt32/Bool/String/Bytes wrappers:
//     parsed like the corresponding scalar.
//   - FieldMask: value split on "," into paths.
//   - Value and Struct: decoded with protojson.Unmarshal.
//
// Any other message type yields an "unsupported message type" error. Parse
// errors from the underlying parsers are returned as-is.
func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
	var (
		parsed proto.Message
		err    error
	)

	// Each case either sets parsed on success or leaves err set; the shared
	// tail after the switch turns that into the return value.
	switch msgDescriptor.FullName() {
	case "google.protobuf.Timestamp":
		var t time.Time
		if t, err = time.Parse(time.RFC3339Nano, value); err != nil {
			break
		}
		ts := timestamppb.New(t)
		if !ts.IsValid() {
			return protoreflect.Value{}, fmt.Errorf("%s before 0001-01-01", value)
		}
		parsed = ts
	case "google.protobuf.Duration":
		var d time.Duration
		if d, err = time.ParseDuration(value); err == nil {
			parsed = durationpb.New(d)
		}
	case "google.protobuf.DoubleValue":
		var f float64
		if f, err = strconv.ParseFloat(value, 64); err == nil {
			parsed = wrapperspb.Double(f)
		}
	case "google.protobuf.FloatValue":
		var f float64
		if f, err = strconv.ParseFloat(value, 32); err == nil {
			parsed = wrapperspb.Float(float32(f))
		}
	case "google.protobuf.Int64Value":
		var n int64
		if n, err = strconv.ParseInt(value, 10, 64); err == nil {
			parsed = wrapperspb.Int64(n)
		}
	case "google.protobuf.Int32Value":
		var n int64
		if n, err = strconv.ParseInt(value, 10, 32); err == nil {
			parsed = wrapperspb.Int32(int32(n))
		}
	case "google.protobuf.UInt64Value":
		var n uint64
		if n, err = strconv.ParseUint(value, 10, 64); err == nil {
			parsed = wrapperspb.UInt64(n)
		}
	case "google.protobuf.UInt32Value":
		var n uint64
		if n, err = strconv.ParseUint(value, 10, 32); err == nil {
			parsed = wrapperspb.UInt32(uint32(n))
		}
	case "google.protobuf.BoolValue":
		var b bool
		if b, err = strconv.ParseBool(value); err == nil {
			parsed = wrapperspb.Bool(b)
		}
	case "google.protobuf.StringValue":
		parsed = wrapperspb.String(value)
	case "google.protobuf.BytesValue":
		raw, decodeErr := Bytes(value)
		if decodeErr != nil {
			return protoreflect.Value{}, decodeErr
		}
		parsed = wrapperspb.Bytes(raw)
	case "google.protobuf.FieldMask":
		parsed = &field_mask.FieldMask{Paths: strings.Split(value, ",")}
	case "google.protobuf.Value":
		decoded := new(structpb.Value)
		if err = protojson.Unmarshal([]byte(value), decoded); err == nil {
			parsed = decoded
		}
	case "google.protobuf.Struct":
		decoded := new(structpb.Struct)
		if err = protojson.Unmarshal([]byte(value), decoded); err == nil {
			parsed = decoded
		}
	default:
		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
	}

	if err != nil {
		return protoreflect.Value{}, err
	}
	return protoreflect.ValueOfMessage(parsed.ProtoReflect()), nil
}
The pages are generated with Golds v0.8.2. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.