"""Utility functions for parsing an OpenAPI spec."""
from __future__ import annotations
import copy
import json
import logging
import re
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import requests
import yaml
from pydantic import ValidationError
logger = logging.getLogger(__name__)
[docs]
class HTTPVerb(str, Enum):
    """Enumerator of the HTTP verbs.

    Each member's value is the lower-case verb name.
    """

    GET = "get"
    PUT = "put"
    POST = "post"
    DELETE = "delete"
    OPTIONS = "options"
    HEAD = "head"
    PATCH = "patch"
    TRACE = "trace"

    @classmethod
    def from_str(cls, verb: str) -> "HTTPVerb":
        """Parse an HTTP verb.

        Matching ignores case and surrounding whitespace, so ``"GET"``,
        ``"get"`` and ``" Get "`` all resolve to ``HTTPVerb.GET``.

        Args:
            verb: The verb name to parse.

        Returns:
            The matching ``HTTPVerb`` member.

        Raises:
            ValueError: If ``verb`` does not name a known HTTP verb.
        """
        # Member values are lower-case; normalise so the conventional
        # upper-case spelling of HTTP methods is accepted as well.
        normalized = verb.strip().lower() if isinstance(verb, str) else verb
        try:
            return cls(normalized)
        except ValueError as err:
            valid_values = ", ".join(member.value for member in cls)
            raise ValueError(
                f"Invalid HTTP verb {verb!r}. Valid values are: {valid_values}"
            ) from err
if TYPE_CHECKING:
    from openapi_pydantic import (
        Components,
        Operation,
        Parameter,
        PathItem,
        Paths,
        Reference,
        RequestBody,
        Schema,
    )
try:
    from openapi_pydantic import OpenAPI
except ImportError:
    OpenAPI = object
[docs]
class OpenAPISpec(OpenAPI):
    """OpenAPI Model that removes mis-formatted parts of the spec.

    Extends ``OpenAPI`` (which this module binds to ``object`` when
    ``openapi_pydantic`` cannot be imported). It adds constructors for
    dicts, text, files and URLs, plus helpers that look up paths,
    operations, parameters, schemas and request bodies and follow
    references into the spec's components.
    """
    # Declared as a plain ``str`` so version strings other than the ones the
    # parent class accepts can still be loaded.
    openapi: str = "3.1.0"  # overriding overly restrictive type from parent class
    @property
    def _paths_strict(self) -> Paths:
        """Return the spec's paths, raising ``ValueError`` if there are none."""
        if self.paths:
            return self.paths
        raise ValueError("No paths found in spec")
    def _get_path_strict(self, path: str) -> PathItem:
        """Return the path item stored under ``path`` or raise ``ValueError``."""
        found = self._paths_strict.get(path)
        if found:
            return found
        raise ValueError(f"No path found for {path}")
    @property
    def _components_strict(self) -> Components:
        """Return the spec's components, raising ``ValueError`` when unset."""
        components = self.components
        if components is not None:
            return components
        raise ValueError("No components found in spec. ")
    @property
    def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
        """Return the components' parameters, raising ``ValueError`` when unset."""
        declared = self._components_strict.parameters
        if declared is not None:
            return declared
        raise ValueError("No parameters found in spec. ")
    @property
    def _schemas_strict(self) -> Dict[str, Schema]:
        """Return the components' schemas, raising ``ValueError`` when unset."""
        declared = self._components_strict.schemas
        if declared is not None:
            return declared
        raise ValueError("No schemas found in spec. ")
    @property
    def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
        """Return the components' request bodies, raising ``ValueError`` when unset."""
        declared = self._components_strict.requestBodies
        if declared is not None:
            return declared
        raise ValueError("No request body found in spec. ")
    def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
        """Look up the component parameter named by the last segment of ``ref``.

        The result may itself be a reference. Raises ``ValueError`` when no
        parameter is registered under that name.
        """
        # Only the final "/"-separated segment of the reference is used as key.
        name = ref.ref.rsplit("/", 1)[-1]
        available = self._parameters_strict
        if name in available:
            return available[name]
        raise ValueError(f"No parameter found for {name}")
    def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
        """Follow ``ref`` through any chained references to a concrete parameter."""
        from openapi_pydantic import Reference

        resolved = self._get_referenced_parameter(ref)
        while True:
            if not isinstance(resolved, Reference):
                return resolved
            # Still a reference: hop to whatever it points at.
            resolved = self._get_referenced_parameter(resolved)
[docs]
    def get_referenced_schema(self, ref: Reference) -> Schema:
        """Look up the component schema named by the last segment of ``ref``.

        Raises ``ValueError`` when no schema is registered under that name.
        """
        # Only the final "/"-separated segment of the reference is used as key.
        name = ref.ref.rsplit("/", 1)[-1]
        available = self._schemas_strict
        if name in available:
            return available[name]
        raise ValueError(f"No schema found for {name}")
[docs]
    def get_schema(
        self,
        schema: Union[Reference, Schema],
        depth: int = 0,
        max_depth: Optional[int] = None,
    ) -> Schema:
        """Resolve ``schema`` and, recursively, its ``properties`` and ``items``.

        A reference is looked up with ``get_referenced_schema``. Each entry
        of ``properties`` and the ``items`` value are then resolved the same
        way and written back onto the schema object, which is returned.

        Args:
            schema: A schema or a reference to one.
            depth: Current recursion depth.
            max_depth: Optional recursion limit; ``None`` disables the check.

        Raises:
            RecursionError: If ``max_depth`` is set and ``depth`` has reached it.
        """
        limit_reached = max_depth is not None and depth >= max_depth
        if limit_reached:
            raise RecursionError(
                f"Max depth of {max_depth} has been exceeded when resolving references."
            )
        from openapi_pydantic import Reference

        if isinstance(schema, Reference):
            resolved = self.get_referenced_schema(schema)
        else:
            resolved = schema
        child_depth = depth + 1
        # TODO: Resolve references on all fields of Schema ?
        # (e.g. patternProperties, etc...)
        properties = resolved.properties
        if properties is not None:
            for prop_name, prop_schema in properties.items():
                properties[prop_name] = self.get_schema(
                    prop_schema, child_depth, max_depth
                )
        if resolved.items is not None:
            resolved.items = self.get_schema(resolved.items, child_depth, max_depth)
        return resolved
    def _get_root_referenced_schema(self, ref: Reference) -> Schema:
        """Follow ``ref`` through any chained references to a concrete schema."""
        from openapi_pydantic import Reference

        resolved = self.get_referenced_schema(ref)
        while True:
            if not isinstance(resolved, Reference):
                return resolved
            # Still a reference: hop to whatever it points at.
            resolved = self.get_referenced_schema(resolved)
    def _get_referenced_request_body(
        self, ref: Reference
    ) -> Optional[Union[Reference, RequestBody]]:
        """Look up the component request body named by the last segment of ``ref``.

        The result may itself be a reference. Raises ``ValueError`` when no
        request body is registered under that name.
        """
        # Only the final "/"-separated segment of the reference is used as key.
        name = ref.ref.rsplit("/", 1)[-1]
        available = self._request_bodies_strict
        if name in available:
            return available[name]
        raise ValueError(f"No request body found for {name}")
    def _get_root_referenced_request_body(
        self, ref: Reference
    ) -> Optional[RequestBody]:
        """Follow ``ref`` through any chained references to a request body."""
        from openapi_pydantic import Reference

        resolved = self._get_referenced_request_body(ref)
        while True:
            if not isinstance(resolved, Reference):
                return resolved
            # Still a reference: hop to whatever it points at.
            resolved = self._get_referenced_request_body(resolved)
    @staticmethod
    def _alert_unsupported_spec(obj: dict) -> None:
        """Warn about, or reject, spec versions that are not fully supported.

        Only the ``"openapi"`` and ``"swagger"`` keys of ``obj`` are read.
        An OpenAPI version outside the 3.1 line, or any Swagger version,
        logs a warning. Nothing is logged for OpenAPI 3.1.x.

        Args:
            obj: The raw spec dictionary.

        Raises:
            ValueError: If neither ``"openapi"`` nor ``"swagger"`` holds a
                string version.
        """
        warning_message = (
            " This may result in degraded performance."
            + " Convert your OpenAPI spec to 3.1.* spec"
            + " for better support."
        )
        swagger_version = obj.get("swagger")
        openapi_version = obj.get("openapi")
        if isinstance(openapi_version, str):
            # The advice given is to move to "3.1.*", so every 3.1 patch
            # release (not only 3.1.0) is treated as supported.
            is_3_1 = openapi_version == "3.1" or openapi_version.startswith("3.1.")
            if not is_3_1:
                logger.warning(
                    f"Attempting to load an OpenAPI {openapi_version}"
                    f" spec. {warning_message}"
                )
        elif isinstance(swagger_version, str):
            logger.warning(
                f"Attempting to load a Swagger {swagger_version}"
                f" spec. {warning_message}"
            )
        else:
            raise ValueError(
                f"Attempting to load an unsupported spec:\n\n{obj}\n{warning_message}"
            )
[docs]
    @classmethod
    def parse_obj(cls, obj: dict) -> OpenAPISpec:
        """Parse a spec dictionary, dropping entries that fail validation.

        The spec version is checked first, then the parent class parses
        ``obj``. On a validation error, every reported error location is
        removed from a deep copy of ``obj`` and parsing is retried on the
        copy. ``obj`` itself is never modified.

        Args:
            obj: The raw spec dictionary.

        Returns:
            The parsed spec.

        Raises:
            ValidationError: If validation fails and none of the reported
                locations can be removed from the input (for example a
                missing required field), because a retry could never succeed.
            ValueError: If the version check rejects ``obj``.
        """
        try:
            cls._alert_unsupported_spec(obj)
            return super().parse_obj(obj)
        except ValidationError as e:
            # We are handling possibly misconfigured specs and
            # want to do a best-effort job to get a reasonable interface out of it.
            new_obj = copy.deepcopy(obj)
            removed_any = False
            list_removals = []
            for error in e.errors():
                keys = error["loc"]
                if not keys:
                    continue
                item = new_obj
                try:
                    for key in keys[:-1]:
                        item = item[key]
                except (KeyError, IndexError, TypeError):
                    # The parent was already removed for an earlier error, or
                    # the location holds an entry that is not a key of the
                    # input data; there is nothing to strip for this error.
                    continue
                last_key = keys[-1]
                if isinstance(item, dict):
                    if last_key in item:
                        del item[last_key]
                        removed_any = True
                elif isinstance(item, list) and isinstance(last_key, int):
                    # Deferred so indices in later error locations stay valid.
                    list_removals.append((item, last_key))
            # Delete list elements from the highest index down so that one
            # deletion does not shift the position of another.
            handled = set()
            for target, index in sorted(
                list_removals, key=lambda pair: pair[1], reverse=True
            ):
                marker = (id(target), index)
                if marker in handled or not 0 <= index < len(target):
                    continue
                handled.add(marker)
                del target[index]
                removed_any = True
            if not removed_any:
                # Retrying an unchanged object would fail identically and
                # recurse without end; surface the validation error instead.
                raise
            return cls.parse_obj(new_obj)
[docs]
    @classmethod
    def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
        """Build a spec from an already-parsed dictionary via ``parse_obj``."""
        spec = cls.parse_obj(spec_dict)
        return spec
[docs]
    @classmethod
    def from_text(cls, text: str) -> OpenAPISpec:
        """Build a spec from text, parsed as JSON or, failing that, as YAML."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Not valid JSON; fall back to the YAML parser.
            parsed = yaml.safe_load(text)
        return cls.from_spec_dict(parsed)
[docs]
    @classmethod
    def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
        """Get an OpenAPI spec from a file path.

        Args:
            path: Location of the JSON or YAML spec file.

        Returns:
            The parsed spec.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        path_ = path if isinstance(path, Path) else Path(path)
        if not path_.exists():
            raise FileNotFoundError(f"{path} does not exist")
        # Decode as UTF-8 explicitly rather than with the platform's default
        # encoding, so the same file loads identically on every system.
        with path_.open("r", encoding="utf-8") as f:
            return cls.from_text(f.read())
[docs]
    @classmethod
    def from_url(cls, url: str, timeout: Optional[float] = 30.0) -> OpenAPISpec:
        """Get an OpenAPI spec from a URL.

        Args:
            url: Address the spec is downloaded from with an HTTP GET.
            timeout: Seconds to wait for the server before giving up;
                ``None`` waits indefinitely.

        Returns:
            The parsed spec.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.Timeout: If the server does not answer within ``timeout``.
        """
        response = requests.get(url, timeout=timeout)
        # Fail on 4xx/5xx instead of trying to parse an error page as a spec.
        response.raise_for_status()
        return cls.from_text(response.text)
    @property
    def base_url(self) -> str:
        """Return the URL of the first server listed in the spec."""
        first_server = self.servers[0]
        return first_server.url
[docs]
    def get_methods_for_path(self, path: str) -> List[str]:
        """List the HTTP verbs for which ``path`` defines an operation.

        Verbs are returned as their lower-case values, in ``HTTPVerb`` order.
        """
        from openapi_pydantic import Operation

        path_item = self._get_path_strict(path)
        return [
            verb.value
            for verb in HTTPVerb
            if isinstance(getattr(path_item, verb.value, None), Operation)
        ]
[docs]
    def get_parameters_for_path(self, path: str) -> List[Parameter]:
        """Return the parameters declared on ``path``, with references resolved.

        An empty list is returned when the path item declares no parameters.
        """
        from openapi_pydantic import Reference

        declared = self._get_path_strict(path).parameters
        if not declared:
            return []
        return [
            self._get_root_referenced_parameter(entry)
            if isinstance(entry, Reference)
            else entry
            for entry in declared
        ]
[docs]
    def get_operation(self, path: str, method: str) -> Operation:
        """Return the operation defined for ``method`` on ``path``.

        ``method`` is used as the attribute name on the path item. Raises
        ``ValueError`` when that attribute is not an operation.
        """
        from openapi_pydantic import Operation

        candidate = getattr(self._get_path_strict(path), method, None)
        if isinstance(candidate, Operation):
            return candidate
        raise ValueError(f"No {method} method found for {path}")
[docs]
    def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
        """Return the parameters declared on ``operation``, with references resolved.

        An empty list is returned when the operation declares no parameters.
        """
        from openapi_pydantic import Reference

        declared = operation.parameters
        if not declared:
            return []
        return [
            self._get_root_referenced_parameter(entry)
            if isinstance(entry, Reference)
            else entry
            for entry in declared
        ]
[docs]
    def get_request_body_for_operation(
        self, operation: Operation
    ) -> Optional[RequestBody]:
        """Return the operation's request body, resolving it if it is a reference."""
        from openapi_pydantic import Reference

        body = operation.requestBody
        if not isinstance(body, Reference):
            return body
        return self._get_root_referenced_request_body(body)
[docs]
    @staticmethod
    def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
        """Return an operation id with ``-``, ``.`` and ``/`` replaced by ``_``.

        When the operation has no ``operationId``, one is built from ``path``
        (leading slashes stripped, every non-alphanumeric character turned
        into ``_``) followed by ``_`` and ``method``.
        """
        raw_id = operation.operationId
        if raw_id is None:
            # Replace all punctuation of any kind with underscore
            slug = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
            raw_id = f"{slug}_{method}"
        for separator in ("-", ".", "/"):
            raw_id = raw_id.replace(separator, "_")
        return raw_id