vendor/google.golang.org/protobuf/internal/impl/codec_message.go
changeset 256 6d9efbef00a9
child 260 445e01aede7e
equal deleted inserted replaced
255:4f153a23adab 256:6d9efbef00a9
       
     1 // Copyright 2019 The Go Authors. All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 package impl
       
     6 
       
     7 import (
       
     8 	"fmt"
       
     9 	"reflect"
       
    10 	"sort"
       
    11 
       
    12 	"google.golang.org/protobuf/encoding/protowire"
       
    13 	"google.golang.org/protobuf/internal/encoding/messageset"
       
    14 	"google.golang.org/protobuf/internal/order"
       
    15 	pref "google.golang.org/protobuf/reflect/protoreflect"
       
    16 	piface "google.golang.org/protobuf/runtime/protoiface"
       
    17 )
       
    18 
       
    19 // coderMessageInfo contains per-message information used by the fast-path functions.
       
    20 // This is a different type from MessageInfo to keep MessageInfo as general-purpose as
       
    21 // possible.
       
    22 type coderMessageInfo struct {
       
    23 	methods piface.Methods
       
    24 
       
    25 	orderedCoderFields []*coderFieldInfo
       
    26 	denseCoderFields   []*coderFieldInfo
       
    27 	coderFields        map[protowire.Number]*coderFieldInfo
       
    28 	sizecacheOffset    offset
       
    29 	unknownOffset      offset
       
    30 	unknownPtrKind     bool
       
    31 	extensionOffset    offset
       
    32 	needsInitCheck     bool
       
    33 	isMessageSet       bool
       
    34 	numRequiredFields  uint8
       
    35 }
       
    36 
       
    37 type coderFieldInfo struct {
       
    38 	funcs      pointerCoderFuncs // fast-path per-field functions
       
    39 	mi         *MessageInfo      // field's message
       
    40 	ft         reflect.Type
       
    41 	validation validationInfo   // information used by message validation
       
    42 	num        pref.FieldNumber // field number
       
    43 	offset     offset           // struct field offset
       
    44 	wiretag    uint64           // field tag (number + wire type)
       
    45 	tagsize    int              // size of the varint-encoded tag
       
    46 	isPointer  bool             // true if IsNil may be called on the struct field
       
    47 	isRequired bool             // true if field is required
       
    48 }
       
    49 
       
    50 func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
       
    51 	mi.sizecacheOffset = invalidOffset
       
    52 	mi.unknownOffset = invalidOffset
       
    53 	mi.extensionOffset = invalidOffset
       
    54 
       
    55 	if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
       
    56 		mi.sizecacheOffset = si.sizecacheOffset
       
    57 	}
       
    58 	if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
       
    59 		mi.unknownOffset = si.unknownOffset
       
    60 		mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
       
    61 	}
       
    62 	if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
       
    63 		mi.extensionOffset = si.extensionOffset
       
    64 	}
       
    65 
       
    66 	mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
       
    67 	fields := mi.Desc.Fields()
       
    68 	preallocFields := make([]coderFieldInfo, fields.Len())
       
    69 	for i := 0; i < fields.Len(); i++ {
       
    70 		fd := fields.Get(i)
       
    71 
       
    72 		fs := si.fieldsByNumber[fd.Number()]
       
    73 		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
       
    74 		if isOneof {
       
    75 			fs = si.oneofsByName[fd.ContainingOneof().Name()]
       
    76 		}
       
    77 		ft := fs.Type
       
    78 		var wiretag uint64
       
    79 		if !fd.IsPacked() {
       
    80 			wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
       
    81 		} else {
       
    82 			wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
       
    83 		}
       
    84 		var fieldOffset offset
       
    85 		var funcs pointerCoderFuncs
       
    86 		var childMessage *MessageInfo
       
    87 		switch {
       
    88 		case ft == nil:
       
    89 			// This never occurs for generated message types.
       
    90 			// It implies that a hand-crafted type has missing Go fields
       
    91 			// for specific protobuf message fields.
       
    92 			funcs = pointerCoderFuncs{
       
    93 				size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
       
    94 					return 0
       
    95 				},
       
    96 				marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
       
    97 					return nil, nil
       
    98 				},
       
    99 				unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
       
   100 					panic("missing Go struct field for " + string(fd.FullName()))
       
   101 				},
       
   102 				isInit: func(p pointer, f *coderFieldInfo) error {
       
   103 					panic("missing Go struct field for " + string(fd.FullName()))
       
   104 				},
       
   105 				merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
       
   106 					panic("missing Go struct field for " + string(fd.FullName()))
       
   107 				},
       
   108 			}
       
   109 		case isOneof:
       
   110 			fieldOffset = offsetOf(fs, mi.Exporter)
       
   111 		case fd.IsWeak():
       
   112 			fieldOffset = si.weakOffset
       
   113 			funcs = makeWeakMessageFieldCoder(fd)
       
   114 		default:
       
   115 			fieldOffset = offsetOf(fs, mi.Exporter)
       
   116 			childMessage, funcs = fieldCoder(fd, ft)
       
   117 		}
       
   118 		cf := &preallocFields[i]
       
   119 		*cf = coderFieldInfo{
       
   120 			num:        fd.Number(),
       
   121 			offset:     fieldOffset,
       
   122 			wiretag:    wiretag,
       
   123 			ft:         ft,
       
   124 			tagsize:    protowire.SizeVarint(wiretag),
       
   125 			funcs:      funcs,
       
   126 			mi:         childMessage,
       
   127 			validation: newFieldValidationInfo(mi, si, fd, ft),
       
   128 			isPointer:  fd.Cardinality() == pref.Repeated || fd.HasPresence(),
       
   129 			isRequired: fd.Cardinality() == pref.Required,
       
   130 		}
       
   131 		mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
       
   132 		mi.coderFields[cf.num] = cf
       
   133 	}
       
   134 	for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
       
   135 		if od := oneofs.Get(i); !od.IsSynthetic() {
       
   136 			mi.initOneofFieldCoders(od, si)
       
   137 		}
       
   138 	}
       
   139 	if messageset.IsMessageSet(mi.Desc) {
       
   140 		if !mi.extensionOffset.IsValid() {
       
   141 			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
       
   142 		}
       
   143 		if !mi.unknownOffset.IsValid() {
       
   144 			panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
       
   145 		}
       
   146 		mi.isMessageSet = true
       
   147 	}
       
   148 	sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
       
   149 		return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
       
   150 	})
       
   151 
       
   152 	var maxDense pref.FieldNumber
       
   153 	for _, cf := range mi.orderedCoderFields {
       
   154 		if cf.num >= 16 && cf.num >= 2*maxDense {
       
   155 			break
       
   156 		}
       
   157 		maxDense = cf.num
       
   158 	}
       
   159 	mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
       
   160 	for _, cf := range mi.orderedCoderFields {
       
   161 		if int(cf.num) >= len(mi.denseCoderFields) {
       
   162 			break
       
   163 		}
       
   164 		mi.denseCoderFields[cf.num] = cf
       
   165 	}
       
   166 
       
   167 	// To preserve compatibility with historic wire output, marshal oneofs last.
       
   168 	if mi.Desc.Oneofs().Len() > 0 {
       
   169 		sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
       
   170 			fi := fields.ByNumber(mi.orderedCoderFields[i].num)
       
   171 			fj := fields.ByNumber(mi.orderedCoderFields[j].num)
       
   172 			return order.LegacyFieldOrder(fi, fj)
       
   173 		})
       
   174 	}
       
   175 
       
   176 	mi.needsInitCheck = needsInitCheck(mi.Desc)
       
   177 	if mi.methods.Marshal == nil && mi.methods.Size == nil {
       
   178 		mi.methods.Flags |= piface.SupportMarshalDeterministic
       
   179 		mi.methods.Marshal = mi.marshal
       
   180 		mi.methods.Size = mi.size
       
   181 	}
       
   182 	if mi.methods.Unmarshal == nil {
       
   183 		mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown
       
   184 		mi.methods.Unmarshal = mi.unmarshal
       
   185 	}
       
   186 	if mi.methods.CheckInitialized == nil {
       
   187 		mi.methods.CheckInitialized = mi.checkInitialized
       
   188 	}
       
   189 	if mi.methods.Merge == nil {
       
   190 		mi.methods.Merge = mi.merge
       
   191 	}
       
   192 }
       
   193 
       
   194 // getUnknownBytes returns a *[]byte for the unknown fields.
       
   195 // It is the caller's responsibility to check whether the pointer is nil.
       
   196 // This function is specially designed to be inlineable.
       
   197 func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
       
   198 	if mi.unknownPtrKind {
       
   199 		return *p.Apply(mi.unknownOffset).BytesPtr()
       
   200 	} else {
       
   201 		return p.Apply(mi.unknownOffset).Bytes()
       
   202 	}
       
   203 }
       
   204 
       
   205 // mutableUnknownBytes returns a *[]byte for the unknown fields.
       
   206 // The returned pointer is guaranteed to not be nil.
       
   207 func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
       
   208 	if mi.unknownPtrKind {
       
   209 		bp := p.Apply(mi.unknownOffset).BytesPtr()
       
   210 		if *bp == nil {
       
   211 			*bp = new([]byte)
       
   212 		}
       
   213 		return *bp
       
   214 	} else {
       
   215 		return p.Apply(mi.unknownOffset).Bytes()
       
   216 	}
       
   217 }