package proto
import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
)
// UnmarshalOptions configures the wire-format decoder.
//
// The zero value is usable through the exported methods: a zero
// RecursionLimit is replaced with protowire.DefaultRecursionLimit, and a nil
// Resolver is replaced with protoregistry.GlobalTypes during decoding.
type UnmarshalOptions struct {
	// NOTE(review): presumably forbids unkeyed struct literals of this type
	// (as the pragma name suggests) — confirm against the pragma package.
	pragma.NoUnkeyedLiterals

	// Merge, when false, makes decoding reset the destination message first.
	// When true, the input is decoded on top of the message's current contents.
	Merge bool

	// AllowPartial, when false, makes the decoder run a required-field check
	// (checkInitialized) after decoding. The check is skipped if the
	// fast path already reported the message as initialized.
	AllowPartial bool

	// DiscardUnknown drops fields that match neither a declared field nor a
	// resolvable extension, instead of appending them to the message's
	// unknown fields. A message's fast-path Unmarshal method is used in this
	// mode only if it advertises protoiface.SupportUnmarshalDiscardUnknown.
	DiscardUnknown bool

	// Resolver looks up extension types by containing message name and field
	// number. A nil Resolver means protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit bounds message nesting depth. Zero selects
	// protowire.DefaultRecursionLimit in the exported entry points.
	// The reflection-based path decrements it once per message and fails once
	// it drops below zero. The fast path receives it as UnmarshalInput.Depth.
	RecursionLimit int

	// NoLazyDecoding sets protoiface.UnmarshalNoLazyDecoding on the input
	// passed to a message's fast-path Unmarshal method.
	NoLazyDecoding bool
}
// Unmarshal decodes the wire-format bytes b into m.
//
// It uses default UnmarshalOptions with RecursionLimit set to
// protowire.DefaultRecursionLimit. Merge is false, so m is reset before
// decoding. AllowPartial is false, so missing required fields are an error.
func Unmarshal(b []byte, m Message) error {
	opts := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}
	_, err := opts.unmarshal(b, m.ProtoReflect())
	return err
}
// Unmarshal decodes the wire-format bytes b into m according to o.
//
// A zero RecursionLimit is treated as protowire.DefaultRecursionLimit.
// Because o is received by value, the caller's options are never modified.
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
	opts := o
	if opts.RecursionLimit == 0 {
		opts.RecursionLimit = protowire.DefaultRecursionLimit
	}
	if _, err := opts.unmarshal(b, m.ProtoReflect()); err != nil {
		return err
	}
	return nil
}
// UnmarshalState decodes in.Buf into in.Message according to o and returns the
// decoder's output (including its flags) alongside any error.
//
// Only in.Buf and in.Message are read. The remaining fields of in (such as
// Flags, Resolver and Depth) are not consulted; o supplies those settings
// instead. A zero RecursionLimit is treated as protowire.DefaultRecursionLimit.
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
	limit := o.RecursionLimit
	if limit == 0 {
		limit = protowire.DefaultRecursionLimit
	}
	o.RecursionLimit = limit
	return o.unmarshal(in.Buf, in.Message)
}
// unmarshal is the shared implementation behind Unmarshal,
// UnmarshalOptions.Unmarshal and UnmarshalState.
//
// It first fills in the default resolver (protoregistry.GlobalTypes) and resets
// m unless o.Merge is set. It then decodes b in one of two ways:
//   - the fast path: the Unmarshal method returned by protoMethods(m);
//   - the slow path: unmarshalMessageSlow.
//
// The fast path is not used when DiscardUnknown is requested but the methods
// do not advertise protoiface.SupportUnmarshalDiscardUnknown.
//
// Nested decoding always merges and allows partial messages. If the caller did
// not set AllowPartial, required fields are checked once, here, unless the fast
// path already reported protoiface.UnmarshalInitialized.
//
// On the fast path, RecursionLimit is passed through as Depth. On the slow
// path it is decremented here, and going below zero is an error.
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}
	if !o.Merge {
		Reset(m.Interface())
	}

	// Capture the caller's intent before the options are rewritten for the
	// nested decoding below.
	checkRequired := !o.AllowPartial
	o.Merge, o.AllowPartial = true, true

	methods := protoMethods(m)
	useFast := methods != nil && methods.Unmarshal != nil
	if useFast && o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0 {
		// The fast path cannot honour DiscardUnknown; use reflection instead.
		useFast = false
	}

	if !useFast {
		o.RecursionLimit--
		if o.RecursionLimit < 0 {
			return out, errors.New("exceeded max recursion depth")
		}
		err = o.unmarshalMessageSlow(b, m)
	} else {
		in := protoiface.UnmarshalInput{
			Message:  m,
			Buf:      b,
			Resolver: o.Resolver,
			Depth:    o.RecursionLimit,
		}
		if o.DiscardUnknown {
			in.Flags |= protoiface.UnmarshalDiscardUnknown
		}
		if checkRequired {
			// Let the fast path know the caller wants required fields checked.
			in.Flags |= protoiface.UnmarshalCheckRequired
		}
		if o.NoLazyDecoding {
			in.Flags |= protoiface.UnmarshalNoLazyDecoding
		}
		out, err = methods.Unmarshal(in)
	}

	switch {
	case err != nil:
		return out, err
	case !checkRequired, out.Flags&protoiface.UnmarshalInitialized != 0:
		return out, nil
	}
	return out, checkInitialized(m)
}
// unmarshalMessage decodes b into the message m and keeps only the error,
// discarding the UnmarshalOutput. The field and map-entry decoders use it for
// message- and group-typed values.
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
	if _, err := o.unmarshal(b, m); err != nil {
		return err
	}
	return nil
}
// unmarshalMessageSlow decodes b into m one record at a time using
// protoreflect. It is used when no usable fast-path Unmarshal method exists.
//
// Message-set descriptors are delegated to unmarshalMessageSet. For any other
// message, each record's field number is looked up among the declared fields.
// If it is not declared but falls in an extension range, it is resolved
// through o.Resolver.
//
// Some records are skipped:
//   - records that match no field;
//   - records that a field decoder rejects with errUnknown.
//
// Unless o.DiscardUnknown is set, the raw bytes (tag and value) of each
// skipped record are appended to m's unknown fields. Malformed input yields
// errDecode.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	if messageset.IsMessageSet(md) {
		return o.unmarshalMessageSet(b, m)
	}
	fields := md.Fields()
	for len(b) > 0 {
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 || num > protowire.MaxValidNumber {
			return errDecode
		}

		// Resolve the field: declared fields first, then extensions.
		fd := fields.ByNumber(num)
		if fd == nil && md.ExtensionRanges().Has(num) {
			xt, rerr := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			if rerr != nil && rerr != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, rerr)
			}
			if xt != nil {
				fd = xt.TypeDescriptor()
			}
		}

		payload := b[tagLen:]
		var (
			valLen int
			err    error
		)
		switch {
		case fd == nil:
			err = errUnknown
		case fd.IsList():
			valLen, err = o.unmarshalList(payload, wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(payload, wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(payload, wtyp, m, fd)
		}

		switch {
		case err == errUnknown:
			// Skip the value; its length comes from the wire type alone.
			valLen = protowire.ConsumeFieldValue(num, wtyp, payload)
			if valLen < 0 {
				return errDecode
			}
			if !o.DiscardUnknown {
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		case err != nil:
			return err
		}
		b = b[tagLen+valLen:]
	}
	return nil
}
// unmarshalSingular decodes one non-repeated value for fd from b and stores it
// in m. It returns the number of bytes of b consumed.
//
// Message and group values are decoded into m.Mutable(fd)'s message through
// unmarshalMessage. Every other kind replaces the current value via m.Set.
// If unmarshalScalar fails, its error is returned with n == 0.
func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
	var v protoreflect.Value
	if v, n, err = o.unmarshalScalar(b, wtyp, fd); err != nil {
		return 0, err
	}
	if kind := fd.Kind(); kind != protoreflect.GroupKind && kind != protoreflect.MessageKind {
		m.Set(fd, v)
		return n, nil
	}
	sub := m.Mutable(fd).Message()
	if err := o.unmarshalMessage(v.Bytes(), sub); err != nil {
		return n, err
	}
	return n, nil
}
// unmarshalMap decodes one length-delimited map entry from b and stores the
// resulting key/value pair in mapv.
//
// The entry payload is read as a small message. Its key and value live at
// genid.MapEntry_Key_field_number and genid.MapEntry_Value_field_number.
// Records with any other field number are skipped, and so are key/value
// records for which unmarshalScalar returns errUnknown.
//
// Defaults for absent parts of the entry:
//   - a missing key uses keyField.Default();
//   - a missing scalar value uses valField.Default();
//   - a missing message/group value is stored as the fresh value from
//     mapv.NewValue().
//
// Repeated value records of message/group kind are all decoded into that same
// value.
//
// It returns errUnknown if wtyp is not protowire.BytesType and errDecode on
// malformed input. Otherwise n is the number of bytes of b consumed by the
// entry; the caller has already stripped the tag.
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
	if wtyp != protowire.BytesType {
		return 0, errUnknown
	}
	// From here on, b is just the entry payload. The named result n holds the
	// total size consumed from the caller's buffer and is returned at the end.
	b, n = protowire.ConsumeBytes(b)
	if n < 0 {
		return 0, errDecode
	}
	var (
		keyField = fd.MapKey()
		valField = fd.MapValue()
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	// Message-typed values are decoded in place, so allocate the target up front.
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		val = mapv.NewValue()
	}
	for len(b) > 0 {
		// NOTE: wtyp and n are redeclared here. They shadow the parameter and
		// the named result for the rest of the loop body, so the outer n (the
		// entry size) is untouched.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return 0, errDecode
		}
		if num > protowire.MaxValidNumber {
			return 0, errDecode
		}
		b = b[n:]
		// Assume the record is unknown unless a case below decodes it.
		err = errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
			if err != nil {
				// Exits only the switch; err is handled after it.
				break
			}
			haveKey = true
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
			if err != nil {
				// Exits only the switch; err is handled after it.
				break
			}
			switch valField.Kind() {
			case protoreflect.GroupKind, protoreflect.MessageKind:
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
					return 0, err
				}
			default:
				val = v
			}
			haveVal = true
		}
		if err == errUnknown {
			// Skip the record; its length is derived from the wire type alone.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return 0, errDecode
			}
		} else if err != nil {
			return 0, err
		}
		b = b[n:]
	}
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		switch valField.Kind() {
		case protoreflect.GroupKind, protoreflect.MessageKind:
			// Keep the value allocated by mapv.NewValue() above.
		default:
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}
// Sentinel errors shared by the decoding helpers.
//
// errUnknown is an internal signal meaning "skip this record as an unknown
// field". The decoders intercept it and do not return it; its message marks
// it as a bug should it ever escape.
//
// errDecode reports malformed wire-format input.
var (
	errUnknown = errors.New("BUG: internal error (unknown)")
	errDecode  = errors.New("cannot parse invalid wire-format data")
)
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.