"""Utilities for tests."""
from __future__ import annotations
import inspect
import textwrap
import warnings
from contextlib import nullcontext
from functools import lru_cache, wraps
from types import GenericAlias
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
cast,
overload,
)
import pydantic
from pydantic import (
BaseModel,
ConfigDict,
PydanticDeprecationWarning,
RootModel,
root_validator,
)
from pydantic import (
create_model as _create_model_base,
)
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
JsonSchemaValue,
)
from pydantic_core import core_schema
[docs]
def get_pydantic_major_version() -> int:
    """Return the major version number of the installed Pydantic package.

    Returns:
        The first dot-separated component of ``pydantic.__version__`` as an
        integer, or ``0`` when importing Pydantic raises ``ImportError``.
    """
    try:
        import pydantic

        major_component, _, _ = pydantic.__version__.partition(".")
        return int(major_component)
    except ImportError:
        return 0
# Resolved once at import time; the version-dependent branches in this module
# key off this constant.
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()

# Only Pydantic 1.x and 2.x are handled; anything else fails at import time.
if PYDANTIC_MAJOR_VERSION not in (1, 2):
    msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
    raise ValueError(msg)

if PYDANTIC_MAJOR_VERSION == 1:
    from pydantic.fields import FieldInfo as FieldInfoV1

    PydanticBaseModel = pydantic.BaseModel
    TypeBaseModel = type[BaseModel]
else:
    from pydantic.v1.fields import FieldInfo as FieldInfoV1  # type: ignore[assignment]

    # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
    PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore
    TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]]  # type: ignore

# Type variable bound to whichever base-model alias was selected above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
[docs]
def is_pydantic_v1_subclass(cls: type) -> bool:
    """Check whether ``cls`` should be treated as a Pydantic 1.x-style model.

    Args:
        cls: The class to inspect.

    Returns:
        ``True`` unconditionally when Pydantic 1.x is installed (``cls`` is not
        inspected in that case). Under Pydantic 2.x, ``True`` only when ``cls``
        subclasses ``pydantic.v1.BaseModel``. ``False`` otherwise.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        return True
    if PYDANTIC_MAJOR_VERSION != 2:
        return False

    from pydantic.v1 import BaseModel as BaseModelV1

    return issubclass(cls, BaseModelV1)
[docs]
def is_pydantic_v2_subclass(cls: type) -> bool:
    """Check whether ``cls`` is a Pydantic 2.x ``BaseModel`` subclass.

    Args:
        cls: The class to inspect.

    Returns:
        ``True`` only when the installed Pydantic major version is 2 and
        ``cls`` subclasses ``pydantic.BaseModel``.
    """
    if PYDANTIC_MAJOR_VERSION != 2:
        return False

    from pydantic import BaseModel

    return issubclass(cls, BaseModel)
[docs]
def is_basemodel_subclass(cls: type) -> bool:
    """Report whether ``cls`` is a subclass of a supported Pydantic BaseModel.

    The accepted base classes are:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Args:
        cls: The object to inspect. Non-classes and ``GenericAlias`` instances
            yield ``False``.

    Returns:
        ``True`` when ``cls`` subclasses one of the base classes above.

    Raises:
        ValueError: If ``cls`` is a class and the installed Pydantic major
            version is neither 1 nor 2.
    """
    # Guard first: issubclass() only accepts classes, and GenericAlias
    # instances are excluded as well.
    if isinstance(cls, GenericAlias) or not inspect.isclass(cls):
        return False

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper

        return issubclass(cls, BaseModelV1Proper)

    if PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2
        from pydantic.v1 import BaseModel as BaseModelV1

        return issubclass(cls, (BaseModelV2, BaseModelV1))

    msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
    raise ValueError(msg)
[docs]
def is_basemodel_instance(obj: Any) -> bool:
    """Report whether ``obj`` is an instance of a supported Pydantic BaseModel.

    The accepted base classes are:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Args:
        obj: The object to inspect.

    Returns:
        ``True`` when ``obj`` is an instance of one of the base classes above.

    Raises:
        ValueError: If the installed Pydantic major version is neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper

        return isinstance(obj, BaseModelV1Proper)

    if PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2
        from pydantic.v1 import BaseModel as BaseModelV1

        return isinstance(obj, (BaseModelV2, BaseModelV1))

    msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
    raise ValueError(msg)
# How to type hint this?
[docs]
def pre_init(func: Callable) -> Any:
    """Wrap ``func`` as a ``pre=True`` root validator that fills in defaults first.

    Before ``func`` runs, the incoming values are normalised per model field:
    values supplied under a field's alias are moved to the field name when
    population by field name is enabled, and non-required fields that are
    missing or ``None`` receive their default.

    Args:
        func (Callable): Invoked as ``func(cls, values)`` once the values have
            been normalised.

    Returns:
        Any: The root-validator-wrapped function.
    """
    with warnings.catch_warnings():
        # Keep Pydantic deprecation warnings quiet while declaring the validator.
        warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)

        @root_validator(pre=True)
        @wraps(func)
        def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
            """Normalise ``values`` against the model fields, then call ``func``.

            Args:
                cls (Type[BaseModel]): The model class being validated.
                values (Dict[str, Any]): Raw input values; modified in place.

            Returns:
                Dict[str, Any]: Whatever ``func(cls, values)`` returns.
            """
            fields = cls.model_fields

            # Population by field name can be switched on through either a
            # ``Config`` class attribute or the ``model_config`` mapping.
            legacy_by_name = (
                hasattr(cls, "Config")
                and hasattr(cls.Config, "allow_population_by_field_name")
                and cls.Config.allow_population_by_field_name
            )
            config_by_name = hasattr(cls, "model_config") and cls.model_config.get(
                "populate_by_name"
            )
            move_alias_to_name = bool(legacy_by_name) or bool(config_by_name)

            for name, field_info in fields.items():
                alias = field_info.alias
                if move_alias_to_name and alias in values:
                    values[name] = values.pop(alias)

                absent = name not in values or values[name] is None
                if absent and not field_info.is_required():
                    factory = field_info.default_factory
                    values[name] = (
                        factory() if factory is not None else field_info.default
                    )

            # Hand the normalised values to the decorated function.
            return func(cls, values)

    return wrapper
class _IgnoreUnserializable(GenerateJsonSchema):
    """JSON schema generator that ignores types it cannot represent.

    See:
    https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
    """

    def handle_invalid_for_json_schema(
        self, schema: core_schema.CoreSchema, error_info: str
    ) -> JsonSchemaValue:
        """Return an empty JSON schema for a core schema reported as invalid.

        Args:
            schema: The core schema that could not be converted (unused).
            error_info: Description of the problem (unused).

        Returns:
            An empty dict.
        """
        empty_schema: JsonSchemaValue = {}
        return empty_schema
def _create_subset_model_v1(
    name: str,
    model: type[BaseModel],
    field_names: list,
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> type[BaseModel]:
    """Create a pydantic (v1 API) model with only a subset of model's fields.

    Args:
        name: Name of the model to create.
        model: Source model; its ``__fields__`` mapping is read.
        field_names: Names of the fields to carry over.
        descriptions: Optional mapping of field name to a description that
            overrides the one on the source field.
        fn_description: Optional docstring for the new model. Falls back to the
            source model's docstring, then to an empty string.

    Returns:
        The newly created model class.

    Raises:
        NotImplementedError: If the installed Pydantic major version is
            neither 1 nor 2.
        KeyError: If a name in ``field_names`` is not a field of ``model``.
    """
    import copy

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import create_model
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic.v1 import create_model  # type: ignore
    else:
        msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
        raise NotImplementedError(msg)

    fields = {}

    for field_name in field_names:
        # Using pydantic v1 so can access __fields__ as a dict.
        field = model.__fields__[field_name]  # type: ignore
        t = (
            # this isn't perfect but should work for most functions
            field.outer_type_
            if field.required and not field.allow_none
            else Optional[field.outer_type_]
        )
        field_info = field.field_info
        if descriptions and field_name in descriptions:
            # Override on a shallow copy: the FieldInfo object belongs to the
            # source model, and assigning to it directly would change the
            # source model's field description as a side effect.
            field_info = copy.copy(field_info)
            field_info.description = descriptions[field_name]
        fields[field_name] = (t, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn
def _create_subset_model_v2(
    name: str,
    model: type[pydantic.BaseModel],
    field_names: list[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> type[pydantic.BaseModel]:
    """Build a pydantic v2 model that exposes only some fields of ``model``.

    Each selected field is rebuilt as a new ``FieldInfo`` carrying the source
    field's default, its description (or the override from ``descriptions``)
    and, when present, its metadata.

    Args:
        name: Name of the model to create.
        model: Source model; its ``model_fields`` mapping is read.
        field_names: Names of the fields to carry over.
        descriptions: Optional mapping of field name to a replacement
            description.
        fn_description: Optional docstring for the new model. Falls back to the
            source model's docstring, then to an empty string.

    Returns:
        The newly created model class.
    """
    from pydantic import ConfigDict, create_model
    from pydantic.fields import FieldInfo

    description_overrides = descriptions or {}
    subset_fields = {}
    for field_name in field_names:
        source_field = model.model_fields[field_name]  # type: ignore
        new_info = FieldInfo(
            description=description_overrides.get(
                field_name, source_field.description
            ),
            default=source_field.default,
        )
        if source_field.metadata:
            new_info.metadata = source_field.metadata
        subset_fields[field_name] = (source_field.annotation, new_info)

    subset_model = create_model(  # type: ignore
        name, **subset_fields, __config__=ConfigDict(arbitrary_types_allowed=True)
    )

    # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
    # This is done to preserve __annotations__ when working with pydantic 2.x
    # and using the Annotated type with TypedDict.
    # Comment out the following assignment, to trigger the relevant test case.
    subset_model.__annotations__ = {
        attr: annotation
        for attr, annotation in model.__annotations__.items()
        if attr in field_names
    }
    subset_model.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return subset_model
# Private functionality to create a subset model that's compatible across
# different versions of pydantic.
# Handles pydantic versions 1.x and 2.x, including v1 of pydantic in 2.x.
# However, can't find a way to type hint this.
def _create_subset_model(
    name: str,
    model: TypeBaseModel,
    field_names: list[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> type[BaseModel]:
    """Create a subset model with the pydantic API flavour of the input model.

    Under Pydantic 1.x the v1 builder is always used. Under Pydantic 2.x the
    v1 builder is used for ``pydantic.v1.BaseModel`` subclasses and the v2
    builder for everything else.

    Args:
        name: Name of the model to create.
        model: Source model class.
        field_names: Names of the fields to carry over.
        descriptions: Optional mapping of field name to description override.
        fn_description: Optional docstring for the new model.

    Returns:
        The newly created model class.

    Raises:
        NotImplementedError: If the installed Pydantic major version is
            neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        build_subset = _create_subset_model_v1
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic.v1 import BaseModel as BaseModelV1

        build_subset = (
            _create_subset_model_v1
            if issubclass(model, BaseModelV1)
            else _create_subset_model_v2
        )
    else:
        msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
        raise NotImplementedError(msg)

    return build_subset(
        name,
        model,
        field_names,
        descriptions=descriptions,
        fn_description=fn_description,
    )
if PYDANTIC_MAJOR_VERSION == 2:
    from pydantic import BaseModel as BaseModelV2
    from pydantic.fields import FieldInfo as FieldInfoV2
    from pydantic.v1 import BaseModel as BaseModelV1

    @overload
    def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...

    @overload
    def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...

    def get_fields(
        model: Union[
            BaseModelV2,
            BaseModelV1,
            type[BaseModelV2],
            type[BaseModelV1],
        ],
    ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
        """Return the field mapping of a Pydantic model class or instance.

        Args:
            model: A model class or instance.

        Returns:
            ``model.model_fields`` when that attribute exists, otherwise
            ``model.__fields__``.

        Raises:
            TypeError: If ``model`` has neither attribute.
        """
        # Probe the ``model_fields`` attribute first, then ``__fields__``.
        for attribute in ("model_fields", "__fields__"):
            if hasattr(model, attribute):
                return getattr(model, attribute)  # type: ignore
        msg = f"Expected a Pydantic model. Got {type(model)}"
        raise TypeError(msg)

elif PYDANTIC_MAJOR_VERSION == 1:
    from pydantic import BaseModel as BaseModelV1_

    def get_fields(  # type: ignore[no-redef]
        model: Union[type[BaseModelV1_], BaseModelV1_],
    ) -> dict[str, FieldInfoV1]:
        """Return the ``__fields__`` mapping of a Pydantic model class or instance."""
        field_mapping = model.__fields__  # type: ignore
        return field_mapping

else:
    msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}"
    raise ValueError(msg)
# Model configuration passed as ``__config__`` when this module creates models
# from field definitions.
_SchemaConfig = ConfigDict(
    arbitrary_types_allowed=True,
    frozen=True,
    protected_namespaces=(),
)

# Sentinel meaning "no default supplied"; compared by identity.
NO_DEFAULT = object()
def _create_root_model(
    name: str,
    type_: Any,
    module_name: Optional[str] = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Create a ``RootModel`` subclass called ``name`` whose root has type ``type_``.

    The generated class overrides ``schema`` and ``model_json_schema`` so that
    the resulting schema's ``"title"`` is always ``name``.

    Args:
        name: Class name, also used as the schema title.
        type_: Annotation for the ``root`` field.
        module_name: Value for the class ``__module__``. Defaults to
            ``"langchain_core.runnables.utils"`` when falsy.
        default_: Default for ``root``. Omitted when it is ``NO_DEFAULT``.

    Returns:
        The newly created model class.
    """

    def schema(
        cls: type[BaseModel],
        by_alias: bool = True,
        ref_template: str = DEFAULT_REF_TEMPLATE,
    ) -> dict[str, Any]:
        # Complains about schema not being defined in superclass
        generated = super(cls, cls).schema(  # type: ignore[misc]
            by_alias=by_alias, ref_template=ref_template
        )
        generated["title"] = name
        return generated

    def model_json_schema(
        cls: type[BaseModel],
        by_alias: bool = True,
        ref_template: str = DEFAULT_REF_TEMPLATE,
        schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: JsonSchemaMode = "validation",
    ) -> dict[str, Any]:
        # Complains about model_json_schema not being defined in superclass
        generated = super(cls, cls).model_json_schema(  # type: ignore[misc]
            by_alias=by_alias,
            ref_template=ref_template,
            schema_generator=schema_generator,
            mode=mode,
        )
        generated["title"] = name
        return generated

    class_namespace = {
        "__annotations__": {"root": type_},
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "schema": classmethod(schema),
        "model_json_schema": classmethod(model_json_schema),
        "__module__": module_name or "langchain_core.runnables.utils",
    }
    if default_ is not NO_DEFAULT:
        class_namespace["root"] = default_

    with warnings.catch_warnings():
        # A TypeError raised by the checks counts as "not a v1 model".
        try:
            root_is_v1_model = (
                isinstance(type_, type)
                and not isinstance(type_, GenericAlias)
                and issubclass(type_, BaseModelV1)
            )
        except TypeError:
            root_is_v1_model = False

        if root_is_v1_model:
            # Only silence Pydantic deprecation warnings for v1 root types.
            warnings.filterwarnings(
                action="ignore", category=PydanticDeprecationWarning
            )

        # Class creation stays inside the block so the filter above applies.
        custom_root_type = type(name, (RootModel,), class_namespace)

    return cast(type[BaseModel], custom_root_type)
@lru_cache(maxsize=256)
def _create_root_model_cached(
    model_name: str,
    type_: Any,
    *,
    module_name: Optional[str] = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Memoised front end for ``_create_root_model``.

    Results are cached (up to 256 entries) on the call arguments, so every
    argument must be hashable.

    Args:
        model_name: Name of the root model class.
        type_: Annotation for the ``root`` field.
        module_name: Optional module name for the created class.
        default_: Default for ``root``, or ``NO_DEFAULT``.

    Returns:
        The created (or cached) model class.
    """
    return _create_root_model(
        model_name,
        type_,
        module_name=module_name,
        default_=default_,
    )
@lru_cache(maxsize=256)
def _create_model_cached(
    __model_name: str,
    **field_definitions: Any,
) -> type[BaseModel]:
    """Memoised model factory for hashable field definitions.

    Field names are passed through ``_remap_field_definitions`` before the
    model is created with ``_SchemaConfig`` as its config.

    Args:
        __model_name: Name of the model class.
        **field_definitions: Field definitions; all values must be hashable
            for the cache lookup to succeed.

    Returns:
        The created (or cached) model class.
    """
    remapped_fields = _remap_field_definitions(field_definitions)
    return _create_model_base(
        __model_name,
        __config__=_SchemaConfig,
        **remapped_fields,
    )
[docs]
def create_model(
    __model_name: str,
    __module_name: Optional[str] = None,
    **field_definitions: Any,
) -> type[BaseModel]:
    """Create a pydantic model from keyword field definitions.

    Please use create_model_v2 instead of this function.

    A ``__root__`` entry in ``field_definitions`` is removed and forwarded to
    ``create_model_v2`` as its ``root`` argument.

    Args:
        __model_name: The name of the model.
        __module_name: The name of the module where the model is defined.
            This is used by Pydantic to resolve any forward references.
        **field_definitions: The field definitions for the model.

    Returns:
        Type[BaseModel]: The created model.
    """
    root_kwargs = (
        {"root": field_definitions.pop("__root__")}
        if "__root__" in field_definitions
        else {}
    )
    return create_model_v2(
        __model_name,
        module_name=__module_name,
        field_definitions=field_definitions,
        **root_kwargs,
    )
# Reserved names are every public (non-underscore) attribute that BaseModel
# itself exposes. Deriving the set from dir(BaseModel) keeps it in step with
# the BaseModel class in use.
# For reference, the reserved names are:
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
# "model_computed_fields", "model_config", "model_construct", "model_copy",
# "model_dump", "model_dump_json", "model_extra", "model_fields",
# "model_fields_set", "model_json_schema", "model_parametrized_name",
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
# "model_validate_strings"
_RESERVED_NAMES = {
    attribute for attribute in dir(BaseModel) if not attribute.startswith("_")
}
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
    """Rename fields whose names would collide with pydantic internals.

    A field whose name starts with ``_`` or appears in ``_RESERVED_NAMES`` is
    stored under ``private_<name>``, with the original name kept as its alias
    and serialization alias. All other fields are passed through unchanged.

    Args:
        field_definitions: Mapping of field name to definition. Definitions of
            fields that need remapping must be ``(type, default)`` pairs.

    Returns:
        A new mapping containing the (possibly renamed) definitions.

    Raises:
        NotImplementedError: If a field that needs remapping is defined with a
            pydantic ``FieldInfo`` instance.
    """
    from pydantic import Field
    from pydantic.fields import FieldInfo

    result: dict[str, Any] = {}
    for field_name, definition in field_definitions.items():
        collides = field_name.startswith("_") or field_name in _RESERVED_NAMES
        if not collides:
            result[field_name] = definition
            continue

        # Let's add a prefix to avoid colliding with internal pydantic fields
        if isinstance(definition, FieldInfo):
            msg = (
                f"Remapping for fields starting with '_' or fields with a name "
                f"matching a reserved name {_RESERVED_NAMES} is not supported if "
                f" the field is a pydantic Field instance. Got {field_name}."
            )
            raise NotImplementedError(msg)

        annotation, default_value = definition
        display_title = field_name.lstrip("_").replace("_", " ").title()
        result[f"private_{field_name}"] = (
            annotation,
            Field(
                default=default_value,
                alias=field_name,
                serialization_alias=field_name,
                title=display_title,
            ),
        )
    return result
[docs]
def create_model_v2(
    model_name: str,
    *,
    module_name: Optional[str] = None,
    field_definitions: Optional[dict[str, Any]] = None,
    root: Optional[Any] = None,
) -> type[BaseModel]:
    """Create a pydantic model from field definitions or a root type.

    Attention:
        Please do not use outside of langchain packages. This API
        is subject to change at any time.

    Args:
        model_name: The name of the model.
        module_name: The name of the module where the model is defined.
            This is used by Pydantic to resolve any forward references.
        field_definitions: The field definitions for the model.
        root: Type for a root model (RootModel). May also be a
            ``(type, default)`` tuple.

    Returns:
        Type[BaseModel]: The created model.

    Raises:
        NotImplementedError: If both ``root`` and field definitions are given.
    """
    field_definitions = cast(dict[str, Any], field_definitions or {})  # type: ignore[no-redef]

    if root:
        if field_definitions:
            msg = (
                "When specifying __root__ no other "
                f"fields should be provided. Got {field_definitions}"
            )
            raise NotImplementedError(msg)

        root_kwargs = (
            {"type_": root[0], "default_": root[1]}
            if isinstance(root, tuple)
            else {"type_": root}
        )
        try:
            return _create_root_model_cached(
                model_name, module_name=module_name, **root_kwargs
            )
        except TypeError:
            # something in the arguments into _create_root_model_cached is not hashable
            return _create_root_model(
                model_name,
                module_name=module_name,
                **root_kwargs,
            )

    # No root, just field definitions.
    # Warnings are silenced when any field name starts with "model", which also
    # covers non-reserved names (e.g., model_id or model_name).
    silence_warnings = any(
        field_name.startswith("model") for field_name in field_definitions
    )
    warning_scope = warnings.catch_warnings() if silence_warnings else nullcontext()  # type: ignore[attr-defined]

    with warning_scope:
        if silence_warnings:
            warnings.filterwarnings(action="ignore")
        try:
            return _create_model_cached(model_name, **field_definitions)
        except TypeError:
            # something in field definitions is not hashable
            return _create_model_base(
                model_name,
                __config__=_SchemaConfig,
                **_remap_field_definitions(field_definitions),
            )