"""Pydantic models for parsing an OpenAPI spec."""
from __future__ import annotations
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# OpenAPI primitive type names mapped to a corresponding Python type.
# Within this module the mapping is consulted for key membership
# (``schema_type in PRIMITIVE_TYPES``) to decide whether a schema's declared
# type name is one of the recognised primitives. Note that "null" maps to the
# value ``None`` rather than to a type object.
PRIMITIVE_TYPES = {
    "integer": int,
    "number": float,
    "string": str,
    "boolean": bool,
    "array": List,
    "object": Dict,
    "null": None,
}
# See https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#parameterIn
# for more info.
class APIPropertyLocation(Enum):
    """Where in the request an API property is carried.

    Each member's value is the lowercase location string it is parsed from.
    """

    QUERY = "query"
    PATH = "path"
    HEADER = "header"
    COOKIE = "cookie"  # Not yet supported

    @classmethod
    def from_str(cls, location: str) -> "APIPropertyLocation":
        """Return the member whose value equals ``location``.

        Raises:
            ValueError: if ``location`` is not the value of any member; the
                message lists the enum's members.
        """
        try:
            member = cls(location)
        except ValueError:
            raise ValueError(
                f"Invalid APIPropertyLocation. Valid values are {cls.__members__}"
            )
        return member
# Request-body media types this module parses; any other media type found in
# a request body's content is skipped when building an APIRequestBody.
_SUPPORTED_MEDIA_TYPES = ("application/json",)
# Parameter locations accepted when building an APIProperty. COOKIE is left
# out on purpose (it is marked "Not yet supported" on the enum).
SUPPORTED_LOCATIONS = {
    APIPropertyLocation.HEADER,
    APIPropertyLocation.QUERY,
    APIPropertyLocation.PATH,
}
# Message template with "{location}" and "{name}" placeholders, filled in via
# str.format() at the point of use. The "Valid values" list is rendered once,
# when this module is imported; since SUPPORTED_LOCATIONS is a set, the order
# in which the values appear in the message is not guaranteed.
INVALID_LOCATION_TEMPL = (
    'Unsupported APIPropertyLocation "{location}"'
    " for parameter {name}. "
    + f"Valid values are {[loc.value for loc in SUPPORTED_LOCATIONS]}"
)
# Alias for the values allowed in the ``type`` field of the property models
# defined in this module.
SCHEMA_TYPE = Union[str, Type, tuple, None, Enum]
class APIPropertyBase(BaseModel):
    """Common fields of an API property.

    Both ``APIProperty`` (operation parameters) and
    ``APIRequestBodyProperty`` (request-body members) in this module extend
    this model.
    """

    # Naming rules for the parameter:
    #   * the name is mandatory and case-sensitive;
    #   * when "in" is "path", the name has to match a template expression
    #     of the path field in the Paths Object;
    #   * when "in" is "header" and the name is "Accept", "Content-Type" or
    #     "Authorization", the parameter definition is ignored;
    #   * in every other case the name is the parameter name used by the
    #     "in" property.
    name: str = Field(alias="name")
    """The name of the property."""

    required: bool = Field(alias="required")
    """Whether the property is required."""

    type: SCHEMA_TYPE = Field(alias="type")
    """The type of the property.

    Either a primitive type, a component/parameter type,
    or an array or 'object' (dict) of the above."""

    default: Optional[Any] = Field(default=None, alias="default")
    """The default value of the property."""

    description: Optional[str] = Field(default=None, alias="description")
    """The description of the property."""
if TYPE_CHECKING:
from openapi_pydantic import (
MediaType,
Parameter,
RequestBody,
Schema,
)
class APIProperty(APIPropertyBase):
    """A query, path, header, or cookie parameter of an API operation."""

    location: APIPropertyLocation = Field(alias="location")
    """The path/how it's being passed to the endpoint."""

    @staticmethod
    def _cast_schema_list_type(
        schema: Schema,
    ) -> Optional[Union[str, Tuple[str, ...]]]:
        """Return ``schema.type``, turning a list of types into a tuple."""
        declared = schema.type
        return tuple(declared) if isinstance(declared, list) else declared

    @staticmethod
    def _get_schema_type_for_enum(parameter: Parameter, schema: Schema) -> Enum:
        """Build an Enum class named ``<parameter name>Enum`` from ``schema.enum``.

        Each member is keyed by ``str(value)`` and carries the raw value.
        """
        enum_name = f"{parameter.name}Enum"
        members = {}
        for value in schema.enum:
            members[str(value)] = value
        return Enum(enum_name, members)

    @staticmethod
    def _get_schema_type_for_array(
        schema: Schema,
    ) -> Optional[Union[str, Tuple[str, ...]]]:
        """Describe the element type of an array schema.

        A single element type name is wrapped in a one-element tuple; a
        referenced element is identified by the last segment of its ref.

        Raises:
            ValueError: if ``schema.items`` is neither a Schema nor a Reference.
        """
        from openapi_pydantic import (
            Reference,
            Schema,
        )

        element = schema.items
        if isinstance(element, Schema):
            element_type = APIProperty._cast_schema_list_type(element)
        elif isinstance(element, Reference):
            # TODO: Add ref definitions to make this valid
            element_type = element.ref.split("/")[-1]
        else:
            raise ValueError(f"Unsupported array items: {element}")
        # TODO: recurse
        return (element_type,) if isinstance(element_type, str) else element_type

    @staticmethod
    def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
        """Work out the ``type`` value for a parameter from its schema.

        Returns None when there is no schema, the element description for
        arrays, a generated Enum for primitives with an ``enum`` list, and the
        primitive type name otherwise.

        Raises:
            NotImplementedError: for "object" schemas and for any type that is
                not a key of PRIMITIVE_TYPES.
        """
        if schema is None:
            return None
        resolved: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
        if resolved == "array":
            return APIProperty._get_schema_type_for_array(schema)
        if resolved == "object":
            # TODO: Resolve array and object types to components.
            raise NotImplementedError("Objects not yet supported")
        if resolved not in PRIMITIVE_TYPES:
            raise NotImplementedError(f"Unsupported type: {resolved}")
        if schema.enum:
            return APIProperty._get_schema_type_for_enum(parameter, schema)
        # A plain primitive: the type name is used as-is.
        return resolved

    @staticmethod
    def _validate_location(location: APIPropertyLocation, name: str) -> None:
        """Raise NotImplementedError unless ``location`` is supported."""
        if location in SUPPORTED_LOCATIONS:
            return
        raise NotImplementedError(
            INVALID_LOCATION_TEMPL.format(location=location, name=name)
        )

    @staticmethod
    def _validate_content(content: Optional[Dict[str, MediaType]]) -> None:
        """Raise ValueError when the parameter carries media content."""
        if not content:
            return
        raise ValueError(
            "API Properties with media content not supported. "
            "Media content only supported within APIRequestBodyProperty's"
        )

    @staticmethod
    def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]:
        """Return the parameter's schema, dereferencing it through ``spec``.

        Returns None when the parameter declares no schema.

        Raises:
            ValueError: if the declared schema is neither a Reference, a
                Schema, nor None.
        """
        from openapi_pydantic import (
            Reference,
            Schema,
        )

        declared = parameter.param_schema
        if isinstance(declared, Reference):
            return spec.get_referenced_schema(declared)
        if declared is None or isinstance(declared, Schema):
            return declared
        raise ValueError(f"Error dereferencing schema: {declared}")

    @staticmethod
    def is_supported_location(location: str) -> bool:
        """Return whether the provided location is supported."""
        try:
            parsed = APIPropertyLocation.from_str(location)
        except ValueError:
            # Not a known location string at all.
            return False
        return parsed in SUPPORTED_LOCATIONS

    @classmethod
    def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
        """Instantiate from an OpenAPI Parameter."""
        location = APIPropertyLocation.from_str(parameter.param_in)
        cls._validate_location(location, parameter.name)
        cls._validate_content(parameter.content)
        schema = cls._get_schema(parameter, spec)
        schema_type = cls._get_schema_type(parameter, schema)
        if schema is None:
            default_val = None
        else:
            default_val = schema.default
        return cls(
            name=parameter.name,
            location=location,
            default=default_val,
            description=parameter.description,
            required=parameter.required,
            type=schema_type,
        )
class APIRequestBodyProperty(APIPropertyBase):
    """A model for a request body property.

    Built recursively from an OpenAPI schema: object schemas contribute
    nested ``properties`` and array schemas are rendered as an
    ``Array<...>`` type string.
    """

    properties: List["APIRequestBodyProperty"] = Field(alias="properties")
    """The sub-properties of the property."""

    # This is useful for handling nested property cycles.
    # We can define separate types in that case.
    references_used: List[str] = Field(alias="references_used")
    """The references used by the property."""

    @classmethod
    def _process_object_schema(
        cls, schema: Schema, spec: OpenAPISpec, references_used: List[str]
    ) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]:
        """Convert the properties of an object schema.

        Args:
            schema: The object schema being processed.
            spec: The spec used to resolve references.
            references_used: Names of references already expanded; this list
                is appended to in place.

        Returns:
            A pair of ``schema.type`` and the converted sub-properties.

        Raises:
            ValueError: if ``schema.properties`` is None.
        """
        from openapi_pydantic import (
            Reference,
        )

        properties = []
        required_props = schema.required or []
        if schema.properties is None:
            raise ValueError(
                f"No properties found when processing object schema: {schema}"
            )
        for prop_name, prop_schema in schema.properties.items():
            if isinstance(prop_schema, Reference):
                ref_name = prop_schema.ref.split("/")[-1]
                if ref_name in references_used:
                    # This reference was expanded before; leaving the
                    # property out is what stops reference cycles from
                    # recursing forever.
                    continue
                references_used.append(ref_name)
                prop_schema = spec.get_referenced_schema(prop_schema)
            properties.append(
                cls.from_schema(
                    schema=prop_schema,
                    name=prop_name,
                    required=prop_name in required_props,
                    spec=spec,
                    references_used=references_used,
                )
            )
        return schema.type, properties

    @classmethod
    def _process_array_schema(
        cls,
        schema: Schema,
        name: str,
        spec: OpenAPISpec,
        references_used: List[str],
    ) -> str:
        """Render the type string of an array schema.

        Returns ``Array<RefName>`` for referenced items, ``Array<type>`` for
        inline item schemas, and the bare string ``"array"`` when the items
        are missing or of any other kind.
        """
        from openapi_pydantic import Reference, Schema

        items = schema.items
        if isinstance(items, Reference):
            ref_name = items.ref.split("/")[-1]
            if ref_name not in references_used:
                references_used.append(ref_name)
                # The resolved schema is not used any further (the array is
                # described by the reference name alone); the lookup is kept
                # so that any error it raises still surfaces here.
                spec.get_referenced_schema(items)
            return f"Array<{ref_name}>"
        if isinstance(items, Schema):
            array_type = cls.from_schema(
                schema=items,
                name=f"{name}Item",
                required=True,  # TODO: Add required
                spec=spec,
                references_used=references_used,
            )
            return f"Array<{array_type.type}>"
        return "array"

    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        name: str,
        required: bool,
        spec: OpenAPISpec,
        references_used: Optional[List[str]] = None,
    ) -> "APIRequestBodyProperty":
        """Recursively populate from an OpenAPI Schema.

        Args:
            schema: The schema to convert.
            name: The name to give the resulting property.
            required: Whether the property is required by its parent.
            spec: The spec used to resolve references.
            references_used: Names of references already expanded; a fresh
                list is created when omitted.

        Raises:
            ValueError: if the schema's type is not supported, including the
                case where ``type`` is given as a list of types.
        """
        if references_used is None:
            references_used = []
        schema_type = schema.type
        properties: List[APIRequestBodyProperty] = []
        if schema_type == "object" and schema.properties:
            schema_type, properties = cls._process_object_schema(
                schema, spec, references_used
            )
        elif schema_type == "array":
            schema_type = cls._process_array_schema(schema, name, spec, references_used)
        elif isinstance(schema_type, list):
            # A list of types is unhashable, so the PRIMITIVE_TYPES membership
            # test below would raise TypeError; report it as an unsupported
            # type like every other unhandled case.
            raise ValueError(f"Unsupported type: {schema_type}")
        elif schema_type in PRIMITIVE_TYPES:
            # Use the primitive type directly
            pass
        elif schema_type is None:
            # No typing specified/parsed. Will map to 'any'
            pass
        else:
            raise ValueError(f"Unsupported type: {schema_type}")

        return cls(
            name=name,
            required=required,
            type=schema_type,
            default=schema.default,
            description=schema.description,
            properties=properties,
            references_used=references_used,
        )
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel):
    """A model for a request body."""

    description: Optional[str] = Field(alias="description")
    """The description of the request body."""

    properties: List[APIRequestBodyProperty] = Field(alias="properties")
    """The properties parsed from the request body's schema."""

    # E.g., application/json - we only support JSON at the moment.
    media_type: str = Field(alias="media_type")
    """The media type of the request body."""

    @classmethod
    def _process_supported_media_type(
        cls,
        media_type_obj: MediaType,
        spec: OpenAPISpec,
    ) -> List[APIRequestBodyProperty]:
        """Process the media type of the request body.

        An object schema with properties yields one property per member; any
        other schema yields a single property named ``"body"``.

        Raises:
            ValueError: if the media type has no schema, or its schema
                reference resolves to None.
        """
        from openapi_pydantic import Reference

        references_used: List[str] = []
        schema = media_type_obj.media_type_schema
        if isinstance(schema, Reference):
            references_used.append(schema.ref.split("/")[-1])
            schema = spec.get_referenced_schema(schema)
        if schema is None:
            raise ValueError(
                f"Could not resolve schema for media type: {media_type_obj}"
            )

        api_request_body_properties = []
        required_properties = schema.required or []
        if schema.type == "object" and schema.properties:
            # NOTE(review): references_used is not forwarded to from_schema
            # here, so each top-level property tracks references on its own;
            # confirm this is intended.
            for prop_name, prop_schema in schema.properties.items():
                if isinstance(prop_schema, Reference):
                    prop_schema = spec.get_referenced_schema(prop_schema)
                api_request_body_properties.append(
                    APIRequestBodyProperty.from_schema(
                        schema=prop_schema,
                        name=prop_name,
                        required=prop_name in required_properties,
                        spec=spec,
                    )
                )
        else:
            api_request_body_properties.append(
                APIRequestBodyProperty(
                    name="body",
                    required=True,
                    type=schema.type,
                    default=schema.default,
                    description=schema.description,
                    properties=[],
                    references_used=references_used,
                )
            )
        return api_request_body_properties

    @classmethod
    def from_request_body(
        cls, request_body: RequestBody, spec: OpenAPISpec
    ) -> "APIRequestBody":
        """Instantiate from an OpenAPI RequestBody.

        Only media types listed in ``_SUPPORTED_MEDIA_TYPES`` contribute
        properties. The resulting ``media_type`` is the supported media type
        that was parsed; if none of the declared media types is supported,
        it falls back to the last declared one and ``properties`` is empty.

        Raises:
            ValueError: if the request body declares no media types at all.
        """
        properties: List[APIRequestBodyProperty] = []
        parsed_media_type: Optional[str] = None
        last_declared_media_type: Optional[str] = None
        for media_type, media_type_obj in request_body.content.items():
            last_declared_media_type = media_type
            if media_type not in _SUPPORTED_MEDIA_TYPES:
                continue
            properties.extend(cls._process_supported_media_type(media_type_obj, spec))
            # Remember which media type the properties actually came from,
            # rather than relying on whatever the loop variable ends up as.
            parsed_media_type = media_type

        if parsed_media_type is None:
            parsed_media_type = last_declared_media_type
        if parsed_media_type is None:
            raise ValueError(
                "Request body does not declare any media types: "
                f"{request_body}"
            )

        return cls(
            description=request_body.description,
            properties=properties,
            media_type=parsed_media_type,
        )
# class APIRequestBodyProperty(APIPropertyBase):
# class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
    """A single operation (path + HTTP method) of an OpenAPI spec."""

    operation_id: str = Field(alias="operation_id")
    """The unique identifier of the operation."""

    description: Optional[str] = Field(alias="description")
    """The description of the operation."""

    base_url: str = Field(alias="base_url")
    """The base URL of the operation."""

    path: str = Field(alias="path")
    """The path of the operation."""

    method: HTTPVerb = Field(alias="method")
    """The HTTP method of the operation."""

    properties: Sequence[APIProperty] = Field(alias="properties")

    # TODO: Add parse in used components to be able to specify what type of
    # referenced object it is.
    # """The properties of the operation."""
    # components: Dict[str, BaseModel] = Field(alias="components")

    request_body: Optional[APIRequestBody] = Field(alias="request_body")
    """The request body of the operation."""

    @staticmethod
    def _get_properties_from_parameters(
        parameters: List[Parameter], spec: OpenAPISpec
    ) -> List[APIProperty]:
        """Convert the operation's parameters into APIProperty objects.

        Parameters in an unsupported location are skipped with a warning
        when optional.

        Raises:
            ValueError: if a required parameter is in an unsupported location.
        """
        collected = []
        for parameter in parameters:
            if APIProperty.is_supported_location(parameter.param_in):
                collected.append(APIProperty.from_parameter(parameter, spec))
                continue
            if parameter.required:
                raise ValueError(
                    INVALID_LOCATION_TEMPL.format(
                        location=parameter.param_in, name=parameter.name
                    )
                )
            complaint = INVALID_LOCATION_TEMPL.format(
                location=parameter.param_in, name=parameter.name
            )
            logger.warning(complaint + " Ignoring optional parameter")
        return collected

    @classmethod
    def from_openapi_url(
        cls,
        spec_url: str,
        path: str,
        method: str,
    ) -> "APIOperation":
        """Create an APIOperation from an OpenAPI URL."""
        return cls.from_openapi_spec(OpenAPISpec.from_url(spec_url), path, method)

    @classmethod
    def from_openapi_spec(
        cls,
        spec: OpenAPISpec,
        path: str,
        method: str,
    ) -> "APIOperation":
        """Create an APIOperation from an OpenAPI spec.

        The description is taken from the operation's description or summary,
        then from the path item's, and defaults to an empty string.
        """
        operation = spec.get_operation(path, method)
        properties = cls._get_properties_from_parameters(
            spec.get_parameters_for_operation(operation), spec
        )
        operation_id = OpenAPISpec.get_cleaned_operation_id(operation, path, method)

        raw_body = spec.get_request_body_for_operation(operation)
        if raw_body is None:
            api_request_body = None
        else:
            api_request_body = APIRequestBody.from_request_body(raw_body, spec)

        description = operation.description or operation.summary
        if not description and spec.paths is not None:
            path_item = spec.paths[path]
            description = path_item.description or path_item.summary

        return cls(
            operation_id=operation_id,
            description=description or "",
            base_url=spec.base_url,
            path=path,
            method=method,  # type: ignore[arg-type]
            properties=properties,
            request_body=api_request_body,
        )

    @staticmethod
    def ts_type_from_python(type_: SCHEMA_TYPE) -> str:
        """Render a property type as a TypeScript type expression.

        None becomes ``any``; strings go through a small alias table and are
        otherwise returned unchanged; a tuple becomes ``Array<...>`` of its
        first element; an Enum class becomes a union of its quoted values;
        anything else is passed through ``str()``.
        """
        if type_ is None:
            # TODO: Handle Nones better. These often result when
            # parsing specs that are < v3
            return "any"
        if isinstance(type_, str):
            aliases = {
                "str": "string",
                "integer": "number",
                "float": "number",
                "date-time": "string",
            }
            return aliases.get(type_, type_)
        if isinstance(type_, tuple):
            element = APIOperation.ts_type_from_python(type_[0])
            return f"Array<{element}>"
        if isinstance(type_, type) and issubclass(type_, Enum):
            return " | ".join(f"'{member.value}'" for member in type_)
        return str(type_)

    def _format_nested_properties(
        self, properties: List[APIRequestBodyProperty], indent: int = 2
    ) -> str:
        """Render request-body properties, recursing into nested ones.

        Each property is preceded by its description comment (possibly
        empty) on its own line and indented by ``indent`` spaces.
        """
        pad = " " * indent
        rendered = []
        for prop in properties:
            ts_type = self.ts_type_from_python(prop.type)
            optional_mark = "" if prop.required else "?"
            comment = f"/* {prop.description} */" if prop.description else ""
            if prop.properties:
                inner = self._format_nested_properties(prop.properties, indent + 2)
                ts_type = f"{{\n{inner}\n{pad}}}"
            rendered.append(f"{comment}\n{pad}{prop.name}{optional_mark}: {ts_type},")
        return "\n".join(rendered)

    def to_typescript(self) -> str:
        """Get typescript string representation of the operation."""
        params = []
        if self.request_body:
            params.append(self._format_nested_properties(self.request_body.properties))
        for prop in self.properties:
            ts_type = self.ts_type_from_python(prop.type)
            optional_mark = "" if prop.required else "?"
            comment = f"/* {prop.description} */" if prop.description else ""
            params.append(f"{comment}\n\t\t{prop.name}{optional_mark}: {ts_type},")

        formatted_params = "\n".join(params).strip()
        operation_name = self.operation_id
        description_str = f"/* {self.description} */" if self.description else ""
        typescript_definition = f"""
{description_str}
type {operation_name} = (_: {{
{formatted_params}
}}) => any;
"""
        return typescript_definition.strip()

    @property
    def query_params(self) -> List[str]:
        """Names of the properties located in the query string."""
        return [
            prop.name
            for prop in self.properties
            if prop.location == APIPropertyLocation.QUERY
        ]

    @property
    def path_params(self) -> List[str]:
        """Names of the properties located in the path."""
        return [
            prop.name
            for prop in self.properties
            if prop.location == APIPropertyLocation.PATH
        ]

    @property
    def body_params(self) -> List[str]:
        """Names of the top-level request-body properties, if any."""
        body = self.request_body
        if body is None:
            return []
        return [prop.name for prop in body.properties]