proto: Fix a Marshal race on messages with extensions.
(*Buffer).enc_exts was not acquiring a necessary lock when writing
lazily-decoded extensions back to the map.
PiperOrigin-RevId: 139345543
diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go
index 82c6162..1fc8ae8 100644
--- a/jsonpb/jsonpb.go
+++ b/jsonpb/jsonpb.go
@@ -585,14 +585,7 @@
case "Any":
return fmt.Errorf("unmarshaling Any not supported yet")
case "Duration":
- ivStr := string(inputValue)
- if ivStr == "null" {
- target.Field(0).SetInt(0)
- target.Field(1).SetInt(0)
- return nil
- }
-
- unq, err := strconv.Unquote(ivStr)
+ unq, err := strconv.Unquote(string(inputValue))
if err != nil {
return err
}
@@ -607,14 +600,7 @@
target.Field(1).SetInt(ns)
return nil
case "Timestamp":
- ivStr := string(inputValue)
- if ivStr == "null" {
- target.Field(0).SetInt(0)
- target.Field(1).SetInt(0)
- return nil
- }
-
- unq, err := strconv.Unquote(ivStr)
+ unq, err := strconv.Unquote(string(inputValue))
if err != nil {
return err
}
diff --git a/jsonpb/jsonpb_test.go b/jsonpb/jsonpb_test.go
index e237df5..78f67c4 100644
--- a/jsonpb/jsonpb_test.go
+++ b/jsonpb/jsonpb_test.go
@@ -467,11 +467,9 @@
{"camelName input", Unmarshaler{}, `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
{"Duration", Unmarshaler{}, `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}},
- {"null Duration", Unmarshaler{}, `{"dur":null}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 0}}},
{"Timestamp", Unmarshaler{}, `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}},
{"PreEpochTimestamp", Unmarshaler{}, `{"ts":"1969-12-31T23:59:58.999999995Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -2, Nanos: 999999995}}},
{"ZeroTimeTimestamp", Unmarshaler{}, `{"ts":"0001-01-01T00:00:00Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -62135596800, Nanos: 0}}},
- {"null Timestamp", Unmarshaler{}, `{"ts":null}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 0, Nanos: 0}}},
{"DoubleValue", Unmarshaler{}, `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}},
{"FloatValue", Unmarshaler{}, `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}},
diff --git a/proto/encode.go b/proto/encode.go
index 68b9b30..2b30f84 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -1075,10 +1075,17 @@
func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
exts := structPointer_Extensions(base, p.field)
- if err := encodeExtensions(exts); err != nil {
+
+ v, mu := exts.extensionsRead()
+ if v == nil {
+ return nil
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ if err := encodeExtensionsMap(v); err != nil {
return err
}
- v, _ := exts.extensionsRead()
return o.enc_map_body(v)
}
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index 403d7c6..b6d9114 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -40,6 +40,7 @@
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/testdata"
+ "golang.org/x/sync/errgroup"
)
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
@@ -506,3 +507,30 @@
t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
}
}
+
+func TestMarshalRace(t *testing.T) {
+ // unregistered extension
+ desc := &proto.ExtensionDesc{
+ ExtendedType: (*pb.MyMessage)(nil),
+ ExtensionType: (*bool)(nil),
+ Field: 101010100,
+ Name: "emptyextension",
+ Tag: "varint,0,opt",
+ }
+
+ m := &pb.MyMessage{Count: proto.Int32(4)}
+ if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
+ t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
+ }
+
+ var g errgroup.Group
+ for n := 3; n > 0; n-- {
+ g.Go(func() error {
+ _, err := proto.Marshal(m)
+ return err
+ })
+ }
+ if err := g.Wait(); err != nil {
+ t.Fatal(err)
+ }
+}