b/161441872 Add extension validation logic for extended validator Change-Id: If26bded391c83ab86e23bcd1c3049e9bf13022f9
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java index 18400b8..2679e54 100644 --- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java +++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidator.java
@@ -1,5 +1,7 @@ package com.apigee.security.oas.extendedvalidator; +import static com.apigee.security.oas.extendedvalidator.ExtensionName.valueOfExtensionName; + import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; @@ -10,19 +12,43 @@ private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private final TraversalHelperFactory traversalHelperFactory; + private final ExtensionSchemaValidator schemaValidator; + // TODO(b/161441872) : Inject ExtensionScopeValidator @Inject - BaseExtendedValidator(TraversalHelperFactory traversalHelperFactory) { + BaseExtendedValidator( + TraversalHelperFactory traversalHelperFactory, ExtensionSchemaValidator schemaValidator) { this.traversalHelperFactory = traversalHelperFactory; + this.schemaValidator = schemaValidator; } @Override - public void validate(OpenApi3 openApiSpec) { + public ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApiSpec) { + ImmutableSet<Extension> extensions = collectExtensions(openApiSpec); + + logger.atInfo().log("Found %s extensions", extensions.size()); + + return validateSchemaAndScope(extensions); + } + + private ImmutableSet<Extension> collectExtensions(OpenApi3 openApiSpec) { TraversalHelper traversalHelper = traversalHelperFactory.create(ImmutableList.of()); traversalHelper.sendOpenApiTraversal(openApiSpec); - ImmutableSet<Extension> extensions = traversalHelper.traverse(); + return traversalHelper.traverse(); + } - // TODO(b/161441872) : Add extension validation logic - logger.atInfo().log("%s", extensions); + private ImmutableSet<ExtensionValidationMessage> validateSchemaAndScope( + ImmutableSet<Extension> extensions) { + + return + extensions.stream() + .filter(extension -> valueOfExtensionName(extension.getExtensionName()).isPresent()) + .flatMap( + extension -> + ImmutableSet.<ExtensionValidationMessage>builder() + .addAll(extension.validate(schemaValidator)).build().stream()) + .collect(ImmutableSet.toImmutableSet()); + + // TODO(b/161441872) : Add extension scope validation logic } }
diff --git a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java index dae3f6b..44cc68c 100644 --- a/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java +++ b/oas-core/src/main/java/com/apigee/security/oas/extendedvalidator/ExtendedValidator.java
@@ -1,5 +1,6 @@ package com.apigee.security.oas.extendedvalidator; +import com.google.common.collect.ImmutableSet; import org.openapi4j.parser.model.v3.OpenApi3; /** @@ -9,5 +10,5 @@ public interface ExtendedValidator { /** Validates ApiSecurityTool's extensions in passed {@link OpenApi3} object. */ - void validate(OpenApi3 openApi3); + ImmutableSet<ExtensionValidationMessage> validate(OpenApi3 openApi3); }
diff --git a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java index 46871da..0ff912a 100644 --- a/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java +++ b/oas-core/src/test/java/com/apigee/security/oas/extendedvalidator/BaseExtendedValidatorTest.java
@@ -1,10 +1,15 @@ package com.apigee.security.oas.extendedvalidator; +import static com.apigee.security.oas.extendedvalidator.ExtensionName.X_SECURITY_TYPE; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -23,7 +28,10 @@ @Rule public final MockitoRule rule = MockitoJUnit.rule().strictness(Strictness.STRICT_STUBS); @Mock private OpenApi3 openApiSpec; + @Mock private Extension supportedExtension; + @Mock private Extension unsupportedExtension; + @Mock private ExtensionSchemaValidator extensionSchemaValidator; @Mock private TraversalHelperFactory traversalHelperFactory; @Mock private TraversalHelper traversalHelper; @InjectMocks private BaseExtendedValidator baseExtendedValidator; @@ -32,12 +40,48 @@ @Before public void setup() { when(traversalHelperFactory.create(any(ImmutableList.class))).thenReturn(traversalHelper); + when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(unsupportedExtension)); } @Test public void validate_openApiSpec_callsTraversalCoordinatorTraverse() { + when(unsupportedExtension.getExtensionName()).thenReturn("x-custom"); + baseExtendedValidator.validate(openApiSpec); verify(traversalHelper).sendOpenApiTraversal(openApiSpec); } + + @Test + public void validate_supportedExtension_callsExtensionValidate() { + when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName()); + when(supportedExtension.validate(extensionSchemaValidator)).thenReturn(ImmutableSet.of()); + when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension)); + + baseExtendedValidator.validate(openApiSpec); + + verify(supportedExtension, atLeastOnce()).validate(any(ExtensionValidator.class)); + } + + @Test + public void validate_unsupportedExtension_shouldNotCallExtensionValidate() { + when(unsupportedExtension.getExtensionName()).thenReturn("x-custom"); + + baseExtendedValidator.validate(openApiSpec); + + verify(unsupportedExtension, never()).validate(any(ExtensionValidator.class)); + } + + @Test + public void validate_withExtensions_returnsExtensionValidationMessage() { + ExtensionValidationMessage message = + ExtensionValidationMessage.builder().setType("").setMessage("").setPath("").build(); + + when(supportedExtension.getExtensionName()).thenReturn(X_SECURITY_TYPE.getExtensionName()); + when(traversalHelper.traverse()).thenReturn(ImmutableSet.of(supportedExtension)); + when(supportedExtension.validate(any(ExtensionValidator.class))) + .thenReturn(ImmutableSet.of(message)); + + assertThat(baseExtendedValidator.validate(openApiSpec)).hasSize(1); + } }