package proto
import (
"errors"
"fmt"
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/runtime/protoimpl"
)
// ExtensionDesc describes an extension field. It is an alias of
// protoimpl.ExtensionInfo.
type ExtensionDesc = protoimpl.ExtensionInfo

// ExtensionRange is an alias of protoiface.ExtensionRangeV1.
type ExtensionRange = protoiface.ExtensionRangeV1

// Extension is an alias of protoimpl.ExtensionFieldV1.
type Extension = protoimpl.ExtensionFieldV1

// XXX_InternalExtensions is an alias of protoimpl.ExtensionFields.
type XXX_InternalExtensions = protoimpl.ExtensionFields

var (
	// ErrMissingExtension is returned by GetExtension when the requested
	// extension is neither populated nor has a default value.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is returned when a message is nil or invalid, or (for
	// most accessors in this file) declares no extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
// HasExtension reports whether the extension field described by xt is present
// in m. The field counts as present if it is a populated known field or if raw
// bytes for it appear in m's unknown fields.
//
// A nil or invalid message has no extensions. If xt does not extend m's message
// type, any populated field whose number equals xt.Field counts as a match.
// Scanning of the unknown fields stops at the first malformed entry instead of
// panicking.
func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() {
		return false
	}

	// Check whether any populated known field matches the field number.
	xtd := xt.TypeDescriptor()
	if isValidExtension(mr.Descriptor(), xtd) {
		has = mr.Has(xtd)
	} else {
		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
			has = int32(fd.Number()) == xt.Field
			return !has
		})
	}

	// Check whether any unknown field matches the field number.
	for b := mr.GetUnknown(); !has && len(b) > 0; {
		num, _, n := protowire.ConsumeField(b)
		if n < 0 {
			// ConsumeField reports malformed input with a negative error
			// code; slicing b[n:] with it would panic.
			break
		}
		has = int32(num) == xt.Field
		b = b[n:]
	}
	return has
}
// ClearExtension removes the extension field described by xt from m. It clears
// the populated known field and drops any unknown fields with the same number.
// It does nothing if m is nil or invalid.
func ClearExtension(m Message, xt *ExtensionDesc) {
	msg := MessageReflect(m)
	if msg == nil || !msg.IsValid() {
		return
	}
	desc := xt.TypeDescriptor()
	switch {
	case isValidExtension(msg.Descriptor(), desc):
		msg.Clear(desc)
	default:
		// xt does not extend this message type, so clear the first populated
		// field encountered that carries the same field number instead.
		msg.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
			if int32(fd.Number()) != xt.Field {
				return true
			}
			msg.Clear(fd)
			return false
		})
	}
	clearUnknown(msg, fieldNum(xt.Field))
}
// ClearAllExtensions removes every populated extension field from m. It also
// drops every unknown field whose number lies inside one of m's extension
// ranges. It does nothing if m is nil or invalid.
func ClearAllExtensions(m Message) {
	msg := MessageReflect(m)
	if msg == nil || !msg.IsValid() {
		return
	}
	msg.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
		if !fd.IsExtension() {
			return true
		}
		msg.Clear(fd)
		return true
	})
	ranges := msg.Descriptor().ExtensionRanges()
	clearUnknown(msg, ranges)
}
// GetExtension retrieves the value of the extension field described by xt
// from m.
//
// It returns errNotExtendable if m is nil or invalid, or if m declares no
// extension ranges. If xt has no ExtensionType (a type-incomplete descriptor),
// it returns only the raw bytes of the matching unknown fields, as a []byte.
// Otherwise, GetExtension behaves as follows:
//   - It returns an error if xt does not extend m's message type.
//   - If the extension is present only as unknown bytes, it unmarshals those
//     bytes, stores the result in m as the known field, and removes them from
//     m's unknown fields. Note that this mutates m.
//   - It returns the populated value, or else the declared default. If neither
//     exists, it returns ErrMissingExtension.
//
// Values whose reflect kind is accepted by isScalarKind are returned as a
// pointer to a copy of the value.
//
// Scanning of the unknown fields stops at the first malformed entry instead
// of panicking.
func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
		return nil, errNotExtendable
	}

	// Retrieve the unknown fields for this extension field.
	var bo protoreflect.RawFields
	for bi := mr.GetUnknown(); len(bi) > 0; {
		num, _, n := protowire.ConsumeField(bi)
		if n < 0 {
			// A negative n is an error code for malformed input. Slicing
			// with it would panic, so keep what was collected so far.
			break
		}
		if int32(num) == xt.Field {
			bo = append(bo, bi[:n]...)
		}
		bi = bi[n:]
	}

	// For type-incomplete descriptors, only the raw unknown bytes are available.
	if xt.ExtensionType == nil {
		return []byte(bo), nil
	}

	// If the extension field only exists as unknown fields, unmarshal it into
	// a scratch message and move the resulting value into m.
	xtd := xt.TypeDescriptor()
	if !isValidExtension(mr.Descriptor(), xtd) {
		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
	}
	if !mr.Has(xtd) && len(bo) > 0 {
		m2 := mr.New()
		if err := (proto.UnmarshalOptions{
			Resolver: extensionResolver{xt},
		}.Unmarshal(bo, m2.Interface())); err != nil {
			return nil, err
		}
		if m2.Has(xtd) {
			mr.Set(xtd, m2.Get(xtd))
			clearUnknown(mr, fieldNum(xt.Field))
		}
	}

	// Use the populated value, falling back to the declared default.
	var pv protoreflect.Value
	switch {
	case mr.Has(xtd):
		pv = mr.Get(xtd)
	case xtd.HasDefault():
		pv = xtd.Default()
	default:
		return nil, ErrMissingExtension
	}

	// Scalars are returned as a pointer to a copy (e.g. *int32, not int32).
	v := xt.InterfaceOf(pv)
	rv := reflect.ValueOf(v)
	if isScalarKind(rv.Kind()) {
		rv2 := reflect.New(rv.Type())
		rv2.Elem().Set(rv)
		v = rv2.Interface()
	}
	return v, nil
}
// extensionResolver is the resolver GetExtension passes to
// proto.UnmarshalOptions. It answers lookups for the single extension type xt
// itself and delegates every other lookup to protoregistry.GlobalTypes.
type extensionResolver struct{ xt protoreflect.ExtensionType }

// FindExtensionByName returns r.xt when its full name equals field. Otherwise
// it consults protoregistry.GlobalTypes.
func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
	if r.xt.TypeDescriptor().FullName() != field {
		return protoregistry.GlobalTypes.FindExtensionByName(field)
	}
	return r.xt, nil
}

// FindExtensionByNumber returns r.xt when it extends the message named message
// at number field. Otherwise it consults protoregistry.GlobalTypes.
func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
	desc := r.xt.TypeDescriptor()
	matches := desc.ContainingMessage().FullName() == message && desc.Number() == field
	if !matches {
		return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
	}
	return r.xt, nil
}
// GetExtensions looks up, via GetExtension, the value of m for each descriptor
// in xts, and returns the values in the same order as xts.
//
// Entries whose extension is missing (ErrMissingExtension) are left nil. Any
// other error stops the scan; it is returned together with the values gathered
// so far. A nil or invalid m yields errNotExtendable.
func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
	if mr := MessageReflect(m); mr == nil || !mr.IsValid() {
		return nil, errNotExtendable
	}
	values := make([]interface{}, len(xts))
	for idx := range xts {
		val, err := GetExtension(m, xts[idx])
		switch {
		case err == ErrMissingExtension:
			// Leave values[idx] as nil.
		case err != nil:
			return values, err
		default:
			values[idx] = val
		}
	}
	return values, nil
}
// SetExtension stores v as the value of the extension field described by xt in
// m. It also drops any unknown fields carrying the same field number.
//
// The dynamic type of v must equal that of xt.ExtensionType. A pointer to a
// scalar (see isScalarKind) is dereferenced before storing, and a nil pointer
// is rejected. SetExtension returns errNotExtendable if m is nil or invalid,
// or if m declares no extension ranges. It returns an error if xt does not
// extend m's message type.
func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
	msg := MessageReflect(m)
	if msg == nil || !msg.IsValid() || msg.Descriptor().ExtensionRanges().Len() == 0 {
		return errNotExtendable
	}

	val := reflect.ValueOf(v)
	if got, want := reflect.TypeOf(v), reflect.TypeOf(xt.ExtensionType); got != want {
		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
	}
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
		}
		// Scalars are stored by value, not by pointer.
		if elem := val.Elem(); isScalarKind(elem.Kind()) {
			v = elem.Interface()
		}
	}

	desc := xt.TypeDescriptor()
	if !isValidExtension(msg.Descriptor(), desc) {
		return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
	}
	msg.Set(desc, xt.ValueOf(v))
	clearUnknown(msg, fieldNum(xt.Field))
	return nil
}
// SetRawExtension replaces the extension field numbered fnum in m with the raw
// wire-format bytes b. It first clears any populated or unknown field with that
// number, then appends b to m's unknown fields. It does nothing if m is nil or
// invalid.
//
// b must consist solely of well-formed fields numbered fnum. Otherwise,
// SetRawExtension panics with a message describing the problem.
func SetRawExtension(m Message, fnum int32, b []byte) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() {
		return
	}

	// Verify that the raw field is valid.
	for b0 := b; len(b0) > 0; {
		num, _, n := protowire.ConsumeField(b0)
		if n < 0 {
			// Report the parse error explicitly. Slicing with the negative
			// error code would otherwise panic with an opaque bounds error,
			// and a failed tag parse would masquerade as a number mismatch.
			panic(fmt.Sprintf("malformed raw field for field number %d: %v", fnum, protowire.ParseError(n)))
		}
		if int32(num) != fnum {
			panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
		}
		b0 = b0[n:]
	}

	ClearExtension(m, &ExtensionDesc{Field: fnum})
	mr.SetUnknown(append(mr.GetUnknown(), b...))
}
// ExtensionDescs returns descriptors for the extensions found in m:
//   - one for each populated extension field whose type is an *ExtensionDesc;
//   - a type-incomplete descriptor (only Field set) for each unknown-field
//     number inside m's extension ranges that has no populated descriptor.
//
// The order of the result is unspecified (it follows map iteration order). The
// result is nil when no extensions are found. ExtensionDescs returns
// errNotExtendable if m is nil or invalid, or if m has no extension ranges.
// Scanning of unknown fields stops at the first malformed entry instead of
// panicking.
func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
		return nil, errNotExtendable
	}

	// Collect descriptors for populated extension fields.
	extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
	mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
		if !fd.IsExtension() {
			return true
		}
		// Checked assertion: skip rather than panic if fd is not an
		// ExtensionTypeDescriptor.
		if xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor); ok {
			if xd, ok := xtd.Type().(*ExtensionDesc); ok {
				extDescs[fd.Number()] = xd
			}
		}
		return true
	})

	// Record extension-range numbers that appear only as unknown fields. A nil
	// entry marks them for a type-incomplete descriptor below.
	extRanges := mr.Descriptor().ExtensionRanges()
	for b := mr.GetUnknown(); len(b) > 0; {
		num, _, n := protowire.ConsumeField(b)
		if n < 0 {
			// Malformed unknown data; a negative n would panic in b[n:].
			break
		}
		if extRanges.Has(num) && extDescs[num] == nil {
			extDescs[num] = nil
		}
		b = b[n:]
	}

	var xts []*ExtensionDesc
	if len(extDescs) > 0 {
		xts = make([]*ExtensionDesc, 0, len(extDescs))
	}
	for num, xt := range extDescs {
		if xt == nil {
			xt = &ExtensionDesc{Field: int32(num)}
		}
		xts = append(xts, xt)
	}
	return xts, nil
}
// isValidExtension reports whether xtd's containing message is the descriptor
// md itself, compared with ==. It also requires xtd's field number to fall
// within one of md's extension ranges.
func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
	if xtd.ContainingMessage() != md {
		return false
	}
	return md.ExtensionRanges().Has(xtd.Number())
}
// isScalarKind reports whether k is one of the reflect kinds that the extension
// accessors in this file treat as scalar: bool, int32, int64, uint32, uint64,
// float32, float64, or string.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String:
		return true
	case reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64:
		return true
	case reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
// clearUnknown removes from m's unknown fields every field whose number is
// reported by remover.Has. It calls m.SetUnknown only when the filtered bytes
// differ in length from the original, i.e. when something was removed.
//
// If the unknown bytes are malformed, the unparseable tail is kept verbatim
// rather than causing a panic.
func clearUnknown(m protoreflect.Message, remover interface {
	Has(protoreflect.FieldNumber) bool
}) {
	var bo protoreflect.RawFields
	for bi := m.GetUnknown(); len(bi) > 0; {
		num, _, n := protowire.ConsumeField(bi)
		if n < 0 {
			// A negative n is an error code; slicing with it would panic.
			// Preserve the remaining bytes untouched.
			bo = append(bo, bi...)
			break
		}
		if !remover.Has(num) {
			bo = append(bo, bi[:n]...)
		}
		bi = bi[n:]
	}
	if bi := m.GetUnknown(); len(bi) != len(bo) {
		m.SetUnknown(bo)
	}
}
// fieldNum is a single field number usable as the remover argument of
// clearUnknown: its Has method matches only that exact number.
type fieldNum protoreflect.FieldNumber

// Has reports whether num equals the field number fn.
func (fn fieldNum) Has(num protoreflect.FieldNumber) bool {
	return num == protoreflect.FieldNumber(fn)
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.