proto: clean up proto API for extensions

Methods that manipulate protos with extensions will now take
proto.Message instead of the internal extendableProto interface.

A ClearExtensions method is added to clear all extensions on protos.
diff --git a/proto/extensions.go b/proto/extensions.go
index 054f4f1..0de0b42 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -92,8 +92,12 @@
 }
 
 // SetRawExtension is for testing only.
-func SetRawExtension(base extendableProto, id int32, b []byte) {
-	base.ExtensionMap()[id] = Extension{enc: b}
+func SetRawExtension(base Message, id int32, b []byte) {
+	epb, ok := base.(extendableProto)
+	if !ok {
+		return
+	}
+	epb.ExtensionMap()[id] = Extension{enc: b}
 }
 
 // isExtensionField returns true iff the given field number is in an extension range.
@@ -209,26 +213,39 @@
 }
 
 // HasExtension returns whether the given extension is present in pb.
-func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
+func HasExtension(pb Message, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
-	_, ok := pb.ExtensionMap()[extension.Field]
+	epb, ok := pb.(extendableProto)
+	if !ok {
+		return false
+	}
+	_, ok = epb.ExtensionMap()[extension.Field]
 	return ok
 }
 
 // ClearExtension removes the given extension from pb.
-func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
+func ClearExtension(pb Message, extension *ExtensionDesc) {
+	epb, ok := pb.(extendableProto)
+	if !ok {
+		return
+	}
 	// TODO: Check types, field numbers, etc.?
-	delete(pb.ExtensionMap(), extension.Field)
+	delete(epb.ExtensionMap(), extension.Field)
 }
 
 // GetExtension parses and returns the given extension of pb.
 // If the extension is not present and has no default value it returns ErrMissingExtension.
-func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
-	if err := checkExtensionTypes(pb, extension); err != nil {
+func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
+	epb, ok := pb.(extendableProto)
+	if !ok {
+		return nil, errors.New("proto: not an extendable proto")
+	}
+
+	if err := checkExtensionTypes(epb, extension); err != nil {
 		return nil, err
 	}
 
-	emap := pb.ExtensionMap()
+	emap := epb.ExtensionMap()
 	e, ok := emap[extension.Field]
 	if !ok {
 		// defaultExtensionValue returns the default value or
@@ -334,8 +351,7 @@
 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
 	epb, ok := pb.(extendableProto)
 	if !ok {
-		err = errors.New("proto: not an extendable proto")
-		return
+		return nil, errors.New("proto: not an extendable proto")
 	}
 	extensions = make([]interface{}, len(es))
 	for i, e := range es {
@@ -351,8 +367,12 @@
 }
 
 // SetExtension sets the specified extension of pb to the specified value.
-func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
-	if err := checkExtensionTypes(pb, extension); err != nil {
+func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
+	epb, ok := pb.(extendableProto)
+	if !ok {
+		return errors.New("proto: not an extendable proto")
+	}
+	if err := checkExtensionTypes(epb, extension); err != nil {
 		return err
 	}
 	typ := reflect.TypeOf(extension.ExtensionType)
@@ -368,10 +388,22 @@
 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
 	}
 
-	pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
+	epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
 	return nil
 }
 
+// ClearAllExtensions clears all extensions from pb.
+func ClearAllExtensions(pb Message) {
+	epb, ok := pb.(extendableProto)
+	if !ok {
+		return
+	}
+	m := epb.ExtensionMap()
+	for k := range m {
+		delete(m, k)
+	}
+}
+
 // A global registry of extensions.
 // The generated code will register the generated descriptors by calling RegisterExtension.
 
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index 8012210..ed6a27d 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -428,3 +428,28 @@
 		}
 	}
 }
+
+func TestClearAllExtensions(t *testing.T) {
+	// unregistered extension
+	desc := &proto.ExtensionDesc{
+		ExtendedType:  (*pb.MyMessage)(nil),
+		ExtensionType: (*bool)(nil),
+		Field:         101010100,
+		Name:          "emptyextension",
+		Tag:           "varint,0,opt",
+	}
+	m := &pb.MyMessage{}
+	if proto.HasExtension(m, desc) {
+		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
+	}
+	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
+		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
+	}
+	if !proto.HasExtension(m, desc) {
+		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
+	}
+	proto.ClearAllExtensions(m)
+	if proto.HasExtension(m, desc) {
+		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
+	}
+}