Merge "b/161441872 Add extension validation logic for extended validator"
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
index 18400b8..2679e54 100644
--- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
+++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
@@ -1,5 +1,7 @@
 package com.apigee.security.oas.extendedvalidator;
 
+import static com.apigee.security.oas.extendedvalidator.ExtensionName.valueOfExtensionName;
+
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.flogger.FluentLogger;
@@ -10,19 +12,60 @@
 
   private static final FluentLogger logger = FluentLogger.forEnclosingClass();
   private final TraversalHelperFactory traversalHelperFactory;
+  private final ExtensionSchemaValidator schemaValidator;
 
+  // TODO(b/161441872) : Inject ExtensionScopeValidator
   @Inject
-  BaseExtendedValidator(TraversalHelperFactory traversalHelperFactory) {
+  BaseExtendedValidator(
+      TraversalHelperFactory traversalHelperFactory, ExtensionSchemaValidator schemaValidator) {
     this.traversalHelperFactory = traversalHelperFactory;
+    this.schemaValidator = schemaValidator;
   }
 
+  /**
+   * Collects the extensions in {@code openApiSpec} and validates those with a recognized name.
+   *
+   * <p>Extensions whose name is not recognized by {@code valueOfExtensionName} are skipped. Only
+   * schema validation is performed so far; scope validation is still a TODO.
+   *
+   * @param openApiSpec the parsed OpenAPI 3 specification to inspect
+   * @return the messages reported by the schema validator for the recognized extensions
+   */
   @Override
-  public void validate(OpenApi3 openApiSpec) {
+  public ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApiSpec) {
+    ImmutableSet<Extension> extensions = collectExtensions(openApiSpec);
+
+    logger.atInfo().log("Found %s extensions", extensions.size());
+
+    return validateSchemaAndScope(extensions);
+  }
+
+  /** Traverses {@code openApiSpec} and returns every extension the traversal helper finds. */
+  private ImmutableSet<Extension> collectExtensions(OpenApi3 openApiSpec) {
     TraversalHelper traversalHelper = traversalHelperFactory.create(ImmutableList.of());
     traversalHelper.sendOpenApiTraversal(openApiSpec);
-    ImmutableSet<Extension> extensions = traversalHelper.traverse();
+    return traversalHelper.traverse();
+  }
 
-    // TODO(b/161441872) : Add extension validation logic
-    logger.atInfo().log("%s", extensions);
+  /** Validates each supported extension and merges the resulting messages into one set. */
+  private ImmutableSet<ExtensionValidationMessage> validateSchemaAndScope(
+      ImmutableSet<Extension> extensions) {
+    // TODO(b/161441872) : Add extension scope validation logic
+    return extensions.stream()
+        .filter(BaseExtendedValidator::isSupported)
+        .flatMap(extension -> validateSchema(extension).stream())
+        .collect(ImmutableSet.toImmutableSet());
+  }
+
+  /** Returns whether {@code valueOfExtensionName} recognizes the extension's name. */
+  private static boolean isSupported(Extension extension) {
+    return valueOfExtensionName(extension.getExtensionName()).isPresent();
+  }
+
+  /** Runs the injected schema validator against a single extension. */
+  private ImmutableSet<ExtensionValidationMessage> validateSchema(Extension extension) {
+    // ImmutableSet.copyOf accepts whatever collection validate() returns and, unlike a
+    // builder round trip, tries to avoid copying inputs that are already immutable.
+    return ImmutableSet.copyOf(extension.validate(schemaValidator));
   }
 }
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
index dae3f6b..44cc68c 100644
--- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
+++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
@@ -1,5 +1,6 @@
 package com.apigee.security.oas.extendedvalidator;
 
+import com.google.common.collect.ImmutableSet;
 import org.openapi4j.parser.model.v3.OpenApi3;
 
 /**
@@ -9,5 +10,10 @@
 public interface ExtendedValidator {
 
-  /** Validates ApiSecurityTool's extensions in passed {@link OpenApi3} object. */
-  void validate(OpenApi3 openApi3);
+  /**
+   * Validates ApiSecurityTool's extensions in the passed {@link OpenApi3} object.
+   *
+   * @param openApi3 the parsed OpenAPI 3 specification whose extensions are validated
+   * @return the {@link ExtensionValidationMessage}s produced while validating the extensions
+   */
+  ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApi3);
 }
diff --git a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
index 46871da..0ff912a 100644
--- a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
+++ b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
@@ -1,10 +1,15 @@
 package com.apigee.security.oas.extendedvalidator;
 
+import static com.apigee.security.oas.extendedvalidator.ExtensionName.X_SECURITY_TYPE;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -23,7 +28,10 @@
   @Rule public final MockitoRule rule = MockitoJUnit.rule().strictness(Strictness.STRICT_STUBS);
 
   @Mock private OpenApi3 openApiSpec;
+  @Mock private Extension supportedExtension;
+  @Mock private Extension unsupportedExtension;
 
+  @Mock private ExtensionSchemaValidator extensionSchemaValidator;
   @Mock private TraversalHelperFactory traversalHelperFactory;
   @Mock private TraversalHelper traversalHelper;
   @InjectMocks private BaseExtendedValidator baseExtendedValidator;
@@ -32,12 +40,50 @@
   @Before
   public void setup() {
     when(traversalHelperFactory.create(any(ImmutableList.class))).thenReturn(traversalHelper);
+    // Default traversal result; tests that need a supported extension override this stub.
+    when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(unsupportedExtension));
   }
 
   @Test
   public void validate_openApiSpec_callsTraversalCoordinatorTraverse() {
+    // The validator reads every extension's name to decide whether to validate it.
+    when(unsupportedExtension.getExtensionName()).thenReturn("x-custom");
+
     baseExtendedValidator.validate(openApiSpec);
 
     verify(traversalHelper).sendOpenApiTraversal(openApiSpec);
   }
+
+  @Test
+  public void validate_supportedExtension_callsExtensionValidate() {
+    when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName());
+    when(supportedExtension.validate(extensionSchemaValidator)).thenReturn(ImmutableSet.of());
+    when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension));
+
+    baseExtendedValidator.validate(openApiSpec);
+
+    verify(supportedExtension, atLeastOnce()).validate(extensionSchemaValidator);
+  }
+
+  @Test
+  public void validate_unsupportedExtension_shouldNotCallExtensionValidate() {
+    when(unsupportedExtension.getExtensionName()).thenReturn("x-custom");
+
+    assertThat(baseExtendedValidator.validate(openApiSpec)).isEmpty();
+
+    verify(unsupportedExtension, never()).validate(any(ExtensionValidator.class));
+  }
+
+  @Test
+  public void validate_withExtensions_returnsExtensionValidationMessage() {
+    ExtensionValidationMessage message =
+        ExtensionValidationMessage.builder().setType("").setMessage("").setPath("").build();
+
+    when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName());
+    when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension));
+    when(supportedExtension.validate(any(ExtensionValidator.class)))
+        .thenReturn(ImmutableSet.of(message));
+
+    assertThat(baseExtendedValidator.validate(openApiSpec)).containsExactly(message);
+  }
 }