proto: Fix a Marshal race on messages with extensions. (*Buffer).enc_exts was not acquiring a necessary lock when writing lazily-decoded extensions back to the map. PiperOrigin-RevId: 139345543
diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go index 82c6162..1fc8ae8 100644 --- a/jsonpb/jsonpb.go +++ b/jsonpb/jsonpb.go
@@ -585,14 +585,7 @@ case "Any": return fmt.Errorf("unmarshaling Any not supported yet") case "Duration": - ivStr := string(inputValue) - if ivStr == "null" { - target.Field(0).SetInt(0) - target.Field(1).SetInt(0) - return nil - } - - unq, err := strconv.Unquote(ivStr) + unq, err := strconv.Unquote(string(inputValue)) if err != nil { return err } @@ -607,14 +600,7 @@ target.Field(1).SetInt(ns) return nil case "Timestamp": - ivStr := string(inputValue) - if ivStr == "null" { - target.Field(0).SetInt(0) - target.Field(1).SetInt(0) - return nil - } - - unq, err := strconv.Unquote(ivStr) + unq, err := strconv.Unquote(string(inputValue)) if err != nil { return err }
diff --git a/jsonpb/jsonpb_test.go b/jsonpb/jsonpb_test.go index e237df5..78f67c4 100644 --- a/jsonpb/jsonpb_test.go +++ b/jsonpb/jsonpb_test.go
@@ -467,11 +467,9 @@ {"camelName input", Unmarshaler{}, `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}}, {"Duration", Unmarshaler{}, `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}}, - {"null Duration", Unmarshaler{}, `{"dur":null}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 0}}}, {"Timestamp", Unmarshaler{}, `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}}, {"PreEpochTimestamp", Unmarshaler{}, `{"ts":"1969-12-31T23:59:58.999999995Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -2, Nanos: 999999995}}}, {"ZeroTimeTimestamp", Unmarshaler{}, `{"ts":"0001-01-01T00:00:00Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: -62135596800, Nanos: 0}}}, - {"null Timestamp", Unmarshaler{}, `{"ts":null}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 0, Nanos: 0}}}, {"DoubleValue", Unmarshaler{}, `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}}, {"FloatValue", Unmarshaler{}, `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}},
diff --git a/proto/encode.go b/proto/encode.go index 68b9b30..2b30f84 100644 --- a/proto/encode.go +++ b/proto/encode.go
@@ -1075,10 +1075,17 @@ func (o *Buffer) enc_exts(p *Properties, base structPointer) error { exts := structPointer_Extensions(base, p.field) - if err := encodeExtensions(exts); err != nil { + + v, mu := exts.extensionsRead() + if v == nil { + return nil + } + + mu.Lock() + defer mu.Unlock() + if err := encodeExtensionsMap(v); err != nil { return err } - v, _ := exts.extensionsRead() return o.enc_map_body(v) }
diff --git a/proto/extensions_test.go b/proto/extensions_test.go index 403d7c6..b6d9114 100644 --- a/proto/extensions_test.go +++ b/proto/extensions_test.go
@@ -40,6 +40,7 @@ "github.com/golang/protobuf/proto" pb "github.com/golang/protobuf/proto/testdata" + "golang.org/x/sync/errgroup" ) func TestGetExtensionsWithMissingExtensions(t *testing.T) { @@ -506,3 +507,30 @@ t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) } } + +func TestMarshalRace(t *testing.T) { + // unregistered extension + desc := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*bool)(nil), + Field: 101010100, + Name: "emptyextension", + Tag: "varint,0,opt", + } + + m := &pb.MyMessage{Count: proto.Int32(4)} + if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { + t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) + } + + var g errgroup.Group + for n := 3; n > 0; n-- { + g.Go(func() error { + _, err := proto.Marshal(m) + return err + }) + } + if err := g.Wait(); err != nil { + t.Fatal(err) + } +}