vendor/github.com/golang/protobuf/proto/extensions.go
changeset 251 1c52a0eeb952
parent 242 2a9ec03fe5a1
child 256 6d9efbef00a9
--- a/vendor/github.com/golang/protobuf/proto/extensions.go	Wed Sep 18 19:17:42 2019 +0200
+++ b/vendor/github.com/golang/protobuf/proto/extensions.go	Sun Feb 16 18:54:01 2020 +0100
@@ -185,9 +185,25 @@
 	// extension will have only enc set. When such an extension is
 	// accessed using GetExtension (or GetExtensions) desc and value
 	// will be set.
-	desc  *ExtensionDesc
+	desc *ExtensionDesc
+
+	// value is a concrete value for the extension field. Let the type of
+	// desc.ExtensionType be the "API type" and the type of Extension.value
+	// be the "storage type". The API type and storage type are the same except:
+	//	* For scalars (except []byte), the API type uses *T,
+	//	while the storage type uses T.
+	//	* For repeated fields, the API type uses []T, while the storage type
+	//	uses *[]T.
+	//
+	// The reason for the divergence is so that the storage type more naturally
+	// matches what is expected of when retrieving the values through the
+	// protobuf reflection APIs.
+	//
+	// The value may only be populated if desc is also populated.
 	value interface{}
-	enc   []byte
+
+	// enc is the raw bytes for the extension field.
+	enc []byte
 }
 
 // SetRawExtension is for testing only.
@@ -334,7 +350,7 @@
 			// descriptors with the same field number.
 			return nil, errors.New("proto: descriptor conflict")
 		}
-		return e.value, nil
+		return extensionAsLegacyType(e.value), nil
 	}
 
 	if extension.ExtensionType == nil {
@@ -349,11 +365,11 @@
 
 	// Remember the decoded version and drop the encoded version.
 	// That way it is safe to mutate what we return.
-	e.value = v
+	e.value = extensionAsStorageType(v)
 	e.desc = extension
 	e.enc = nil
 	emap[extension.Field] = e
-	return e.value, nil
+	return extensionAsLegacyType(e.value), nil
 }
 
 // defaultExtensionValue returns the default value for extension.
@@ -488,7 +504,7 @@
 	}
 	typ := reflect.TypeOf(extension.ExtensionType)
 	if typ != reflect.TypeOf(value) {
-		return errors.New("proto: bad extension value type")
+		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", value, extension.ExtensionType)
 	}
 	// nil extension values need to be caught early, because the
 	// encoder can't distinguish an ErrNil due to a nil extension
@@ -500,7 +516,7 @@
 	}
 
 	extmap := epb.extensionsWrite()
-	extmap[extension.Field] = Extension{desc: extension, value: value}
+	extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)}
 	return nil
 }
 
@@ -541,3 +557,51 @@
 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
 	return extensionMaps[reflect.TypeOf(pb).Elem()]
 }
+
+// extensionAsLegacyType converts an value in the storage type as the API type.
+// See Extension.value.
+func extensionAsLegacyType(v interface{}) interface{} {
+	switch rv := reflect.ValueOf(v); rv.Kind() {
+	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
+		// Represent primitive types as a pointer to the value.
+		rv2 := reflect.New(rv.Type())
+		rv2.Elem().Set(rv)
+		v = rv2.Interface()
+	case reflect.Ptr:
+		// Represent slice types as the value itself.
+		switch rv.Type().Elem().Kind() {
+		case reflect.Slice:
+			if rv.IsNil() {
+				v = reflect.Zero(rv.Type().Elem()).Interface()
+			} else {
+				v = rv.Elem().Interface()
+			}
+		}
+	}
+	return v
+}
+
+// extensionAsStorageType converts an value in the API type as the storage type.
+// See Extension.value.
+func extensionAsStorageType(v interface{}) interface{} {
+	switch rv := reflect.ValueOf(v); rv.Kind() {
+	case reflect.Ptr:
+		// Represent slice types as the value itself.
+		switch rv.Type().Elem().Kind() {
+		case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
+			if rv.IsNil() {
+				v = reflect.Zero(rv.Type().Elem()).Interface()
+			} else {
+				v = rv.Elem().Interface()
+			}
+		}
+	case reflect.Slice:
+		// Represent slice types as a pointer to the value.
+		if rv.Type().Elem().Kind() != reflect.Uint8 {
+			rv2 := reflect.New(rv.Type())
+			rv2.Elem().Set(rv)
+			v = rv2.Interface()
+		}
+	}
+	return v
+}