Merge "b/161441872 Add extension validation logic for extended validator"
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
index 18400b8..2679e54 100644
--- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
+++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
@@ -1,5 +1,7 @@
package com.apigee.security.oas.extendedvalidator;
+import static com.apigee.security.oas.extendedvalidator.ExtensionName.valueOfExtensionName;
+
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger;
@@ -10,19 +12,48 @@
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final TraversalHelperFactory traversalHelperFactory;
+ private final ExtensionSchemaValidator schemaValidator;
+ // TODO(b/161441872) : Inject ExtensionScopeValidator
@Inject
- BaseExtendedValidator(TraversalHelperFactory traversalHelperFactory) {
+ BaseExtendedValidator(
+ TraversalHelperFactory traversalHelperFactory, ExtensionSchemaValidator schemaValidator) {
this.traversalHelperFactory = traversalHelperFactory;
+ this.schemaValidator = schemaValidator;
}
@Override
- public void validate(OpenApi3 openApiSpec) {
+ public ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApiSpec) {
+ ImmutableSet<Extension> extensions = collectExtensions(openApiSpec);
+
+ logger.atInfo().log("Found %s extensions", extensions.size());
+
+ return validateSchemaAndScope(extensions);
+ }
+
+ /** Traverses {@code openApiSpec} and returns the extensions collected by the traversal. */
+ private ImmutableSet<Extension> collectExtensions(OpenApi3 openApiSpec) {
TraversalHelper traversalHelper = traversalHelperFactory.create(ImmutableList.of());
traversalHelper.sendOpenApiTraversal(openApiSpec);
- ImmutableSet<Extension> extensions = traversalHelper.traverse();
+ return traversalHelper.traverse();
+ }
- // TODO(b/161441872) : Add extension validation logic
- logger.atInfo().log("%s", extensions);
+ /**
+ * Validates, with {@code schemaValidator}, every extension whose name is recognized by
+ * {@code valueOfExtensionName}; extensions with unrecognized names are skipped.
+ *
+ * @return the union of the messages returned by the validated extensions
+ */
+ private ImmutableSet<ExtensionValidationMessage> validateSchemaAndScope(
+ ImmutableSet<Extension> extensions) {
+ // TODO(b/161441872) : Add extension scope validation logic
+ // One builder collects (and de-duplicates) the messages of all validated extensions.
+ ImmutableSet.Builder<ExtensionValidationMessage> messages = ImmutableSet.builder();
+ for (Extension extension : extensions) {
+ if (valueOfExtensionName(extension.getExtensionName()).isPresent()) {
+ messages.addAll(extension.validate(schemaValidator));
+ }
+ }
+ return messages.build();
}
}
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
index dae3f6b..44cc68c 100644
--- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
+++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
@@ -1,5 +1,6 @@
package com.apigee.security.oas.extendedvalidator;
+import com.google.common.collect.ImmutableSet;
import org.openapi4j.parser.model.v3.OpenApi3;
/**
@@ -9,5 +10,10 @@
public interface ExtendedValidator {
- /** Validates ApiSecurityTool's extensions in passed {@link OpenApi3} object. */
- void validate(OpenApi3 openApi3);
+ /**
+ * Validates ApiSecurityTool's extensions in the passed {@link OpenApi3} object.
+ *
+ * @param openApi3 the OpenAPI 3 specification whose extensions are validated
+ * @return the validation messages reported for the extensions; empty if none were reported
+ */
+ ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApi3);
}
diff --git a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
index 46871da..0ff912a 100644
--- a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
+++ b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
@@ -1,10 +1,15 @@
package com.apigee.security.oas.extendedvalidator;
+import static com.apigee.security.oas.extendedvalidator.ExtensionName.X_SECURITY_TYPE;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -23,7 +28,11 @@
@Rule public final MockitoRule rule = MockitoJUnit.rule().strictness(Strictness.STRICT_STUBS);
@Mock private OpenApi3 openApiSpec;
+ @Mock private Extension supportedExtension;
+ @Mock private Extension otherSupportedExtension;
+ @Mock private Extension unsupportedExtension;
+ @Mock private ExtensionSchemaValidator extensionSchemaValidator;
@Mock private TraversalHelperFactory traversalHelperFactory;
@Mock private TraversalHelper traversalHelper;
@InjectMocks private BaseExtendedValidator baseExtendedValidator;
@@ -32,12 +41,72 @@
@Before
public void setup() {
when(traversalHelperFactory.create(any(ImmutableList.class))).thenReturn(traversalHelper);
+ when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(unsupportedExtension));
}
@Test
- public void validate_openApiSpec_callsTraversalCoordinatorTraverse() {
+ public void validate_openApiSpec_sendsOpenApiTraversal() {
+ when(unsupportedExtension.getExtensionName()).thenReturn("x-custom");
+
baseExtendedValidator.validate(openApiSpec);
verify(traversalHelper).sendOpenApiTraversal(openApiSpec);
}
+
+ @Test
+ public void validate_supportedExtension_callsExtensionValidateWithSchemaValidator() {
+ when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName());
+ when(supportedExtension.validate(extensionSchemaValidator)).thenReturn(ImmutableSet.of());
+ when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension));
+
+ baseExtendedValidator.validate(openApiSpec);
+
+ verify(supportedExtension, atLeastOnce()).validate(extensionSchemaValidator);
+ }
+
+ @Test
+ public void validate_unsupportedExtension_shouldNotCallExtensionValidate() {
+ when(unsupportedExtension.getExtensionName()).thenReturn("x-custom");
+
+ baseExtendedValidator.validate(openApiSpec);
+
+ verify(unsupportedExtension, never()).validate(any(ExtensionValidator.class));
+ }
+
+ @Test
+ public void validate_withExtensions_returnsExtensionValidationMessage() {
+ ExtensionValidationMessage message =
+ ExtensionValidationMessage.builder().setType("").setMessage("").setPath("").build();
+
+ when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName());
+ when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension));
+ when(supportedExtension.validate(extensionSchemaValidator))
+ .thenReturn(ImmutableSet.of(message));
+
+ assertThat(baseExtendedValidator.validate(openApiSpec)).containsExactly(message);
+ }
+
+ @Test
+ public void validate_mixedExtensions_returnsMessagesOfSupportedExtensionsOnly() {
+ ExtensionValidationMessage first =
+ ExtensionValidationMessage.builder().setType("").setMessage("first").setPath("").build();
+ ExtensionValidationMessage second =
+ ExtensionValidationMessage.builder().setType("").setMessage("second").setPath("").build();
+
+ when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName());
+ when(otherSupportedExtension.getExtensionName())
+ .thenReturn(X_SECURITY_TYPE.getExtensionName());
+ when(unsupportedExtension.getExtensionName()).thenReturn("x-custom");
+ when(supportedExtension.validate(extensionSchemaValidator))
+ .thenReturn(ImmutableSet.of(first));
+ when(otherSupportedExtension.validate(extensionSchemaValidator))
+ .thenReturn(ImmutableSet.of(second));
+ when(traversalHelper.traverse())
+ .thenReturn(
+ ImmutableSet.of(supportedExtension, unsupportedExtension, otherSupportedExtension));
+
+ assertThat(baseExtendedValidator.validate(openApiSpec))
+ .containsExactlyInAnyOrder(first, second);
+ verify(unsupportedExtension, never()).validate(any(ExtensionValidator.class));
+ }
}