"""Utilities for tests."""
from __future__ import annotations
import inspect
import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import BaseModel, root_validator
def get_pydantic_major_version() -> int:
    """Return the major version number of the installed Pydantic package.

    Returns:
        The leading integer component of ``pydantic.__version__`` (for
        example ``2`` for ``"2.7.1"``), or ``0`` when Pydantic cannot be
        imported.
    """
    try:
        import pydantic
    except ImportError:
        return 0
    major, _sep, _rest = pydantic.__version__.partition(".")
    return int(major)
# Resolved once at import time; every version-dependent branch in this module
# keys off this value.
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# Pick the v1-API FieldInfo class and the model type aliases for the installed
# major version. A result of 0 (pydantic not importable) or any version other
# than 1 or 2 falls through to the ValueError below.
if PYDANTIC_MAJOR_VERSION == 1:
    from pydantic.fields import FieldInfo as FieldInfoV1

    PydanticBaseModel = pydantic.BaseModel
    TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
    # Under 2.x the v1 API lives in the bundled ``pydantic.v1`` namespace.
    from pydantic.v1.fields import FieldInfo as FieldInfoV1  # type: ignore[assignment]

    # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
    # Presumably ``BaseModel`` (from langchain_core.pydantic_v1) is the v1-API
    # model here, so the union spans both APIs -- confirm against that module.
    PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore
    TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]]  # type: ignore
else:
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

# Generic type variable bounded by whichever model alias was chosen above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: Type) -> bool:
    """Report whether ``cls`` should be treated as a Pydantic 1.x-style model.

    With Pydantic 1.x installed this is unconditionally ``True``; ``cls`` is
    not inspected. With Pydantic 2.x it is ``True`` only when ``cls``
    subclasses the bundled ``pydantic.v1.BaseModel``. Any other installed
    version yields ``False``.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        return True
    if PYDANTIC_MAJOR_VERSION != 2:
        return False

    from pydantic.v1 import BaseModel as BaseModelV1

    return issubclass(cls, BaseModelV1)
def is_pydantic_v2_subclass(cls: Type) -> bool:
    """Check if ``cls`` is a Pydantic 2.x model class.

    Returns ``True`` only when Pydantic 2.x is installed and ``cls`` is a
    subclass of ``pydantic.BaseModel`` (the native 2.x API). Under any other
    installed major version this returns ``False``.
    """
    from pydantic import BaseModel

    return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: Type) -> bool:
    """Return whether ``cls`` is a Pydantic model class.

    A class qualifies when it derives from any of the following:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Inputs that are not classes yield ``False`` instead of raising.

    Raises:
        ValueError: If the installed Pydantic major version is neither 1 nor 2.
    """
    # issubclass() raises TypeError for non-class arguments, so screen them out.
    if not inspect.isclass(cls):
        return False

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        model_bases: tuple = (BaseModelV1Proper,)
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        model_bases = (BaseModelV2, BaseModelV1)
    else:
        raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

    return issubclass(cls, model_bases)
def is_basemodel_instance(obj: Any) -> bool:
    """Return whether ``obj`` is an instance of a Pydantic model.

    The object qualifies when it is an instance of any of the following:

    * pydantic.BaseModel in Pydantic 1.x
    * pydantic.BaseModel in Pydantic 2.x
    * pydantic.v1.BaseModel in Pydantic 2.x

    Raises:
        ValueError: If the installed Pydantic major version is neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        return isinstance(obj, BaseModelV1Proper)

    if PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        return isinstance(obj, (BaseModelV2, BaseModelV1))

    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
# NOTE: the return type is ``Any`` because the decorated object is whatever
# ``root_validator`` produces; a precise type hint has not been worked out.
def pre_init(func: Callable) -> Any:
    """Decorator to run a function before model initialization.

    The wrapped function is registered as a ``pre=True`` root validator. Before
    ``func`` is called, the incoming ``values`` dict is normalized:

    * If the model's ``Config`` declares a truthy
      ``allow_population_by_field_name``, a value supplied under a field's
      alias is moved to the field's own name.
    * Any non-required field that is missing or explicitly ``None`` is filled
      from its ``default_factory`` when one is set, otherwise from its
      ``default``.

    Args:
        func (Callable): The function to run before model initialization. It is
            called as ``func(cls, values)`` and its return value is returned
            from the validator.

    Returns:
        Any: The decorated function.
    """

    @root_validator(pre=True)
    @wraps(func)
    def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize ``values`` (aliases, defaults) and then call ``func``.

        Args:
            cls (Type[BaseModel]): The model class.
            values (Dict[str, Any]): The values to initialize the model with;
                updated in place.

        Returns:
            Dict[str, Any]: Whatever ``func(cls, values)`` returns.
        """
        # The flag lives on the class, so read it once rather than per field.
        # A missing ``Config`` or a ``Config`` without the attribute counts as
        # disabled.
        config = getattr(cls, "Config", None)
        populate_by_name = bool(
            getattr(config, "allow_population_by_field_name", False)
        )

        for name, field in cls.__fields__.items():
            # Move a value supplied under the field's alias onto the field name.
            if populate_by_name and field.alias in values:
                values[name] = values.pop(field.alias)

            # Insert default values for optional fields that are absent or None.
            if values.get(name) is None and not field.required:
                if field.default_factory is not None:
                    values[name] = field.default_factory()
                else:
                    values[name] = field.default

        # Call the decorated function
        return func(cls, values)

    return wrapper
def _create_subset_model_v1(
    name: str,
    model: Type[BaseModel],
    field_names: list,
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a pydantic (v1 API) model with only a subset of model's fields.

    Args:
        name: Name of the generated model class.
        model: Source model whose ``__fields__`` are copied. It is not modified.
        field_names: Names of the fields to keep.
        descriptions: Optional per-field description overrides, keyed by field
            name.
        fn_description: Docstring for the generated model; falls back to
            ``model.__doc__`` and then to an empty string (dedented either way).

    Returns:
        The newly created model class.

    Raises:
        KeyError: If a name in ``field_names`` is not a field of ``model``.
    """
    import copy

    from langchain_core.pydantic_v1 import create_model

    fields = {}
    for field_name in field_names:
        field = model.__fields__[field_name]
        t = (
            # this isn't perfect but should work for most functions
            field.outer_type_
            if field.required and not field.allow_none
            else Optional[field.outer_type_]
        )
        # The FieldInfo belongs to the source model; edit a shallow copy so the
        # description override (and anything pydantic may update while building
        # the new model) does not leak back into ``model``.
        field_info = copy.copy(field.field_info)
        if descriptions and field_name in descriptions:
            field_info.description = descriptions[field_name]
        fields[field_name] = (t, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn
def _create_subset_model_v2(
    name: str,
    model: Type[pydantic.BaseModel],
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[pydantic.BaseModel]:
    """Create a pydantic (v2 API) model with a subset of the model fields.

    Each kept field carries over its annotation, its description (overridable
    through ``descriptions``), its metadata, and its default. The default is
    either the static ``default`` or the ``default_factory``, so fields that
    were optional in ``model`` stay optional in the subset.

    Args:
        name: Name of the generated model class.
        model: Source model whose ``model_fields`` are read.
        field_names: Names of the fields to keep.
        descriptions: Optional per-field description overrides, keyed by field
            name.
        fn_description: Docstring for the generated model; falls back to
            ``model.__doc__`` and then to an empty string (dedented either way).

    Returns:
        The newly created model class.

    Raises:
        KeyError: If a name in ``field_names`` is not a field of ``model``.
    """
    from pydantic import create_model  # pydantic: ignore
    from pydantic.fields import FieldInfo  # pydantic: ignore

    descriptions_ = descriptions or {}
    fields = {}
    for field_name in field_names:
        field = model.model_fields[field_name]  # type: ignore
        description = descriptions_.get(field_name, field.description)
        if field.default_factory is not None:
            # A factory-backed field has no static ``default``; copying only
            # ``default`` would turn it into a required field.
            field_info = FieldInfo(
                description=description, default_factory=field.default_factory
            )
        else:
            field_info = FieldInfo(description=description, default=field.default)
        if field.metadata:
            field_info.metadata = field.metadata
        fields[field_name] = (field.annotation, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn
# Private helper that builds a subset model compatible across pydantic
# versions: pydantic 1.x and 2.x, including models built on the pydantic.v1
# namespace that ships with 2.x. No single return annotation covers every
# model class this can produce, so the hint below is approximate.
def _create_subset_model(
    name: str,
    model: TypeBaseModel,
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a subset model using the same pydantic API as ``model``.

    Pydantic 1.x, and ``pydantic.v1`` models under Pydantic 2.x, go through the
    v1 builder; native Pydantic 2.x models go through the v2 builder.

    Raises:
        NotImplementedError: If the installed Pydantic major version is
            neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION not in (1, 2):
        raise NotImplementedError(
            f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
        )

    if PYDANTIC_MAJOR_VERSION == 1:
        builder = _create_subset_model_v1
    else:
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        builder = (
            _create_subset_model_v1
            if issubclass(model, BaseModelV1)
            else _create_subset_model_v2
        )

    return builder(
        name,
        model,
        field_names,
        descriptions=descriptions,
        fn_description=fn_description,
    )
if PYDANTIC_MAJOR_VERSION == 2:
    from pydantic import BaseModel as BaseModelV2
    from pydantic.fields import FieldInfo as FieldInfoV2
    from pydantic.v1 import BaseModel as BaseModelV1

    @overload
    def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...

    @overload
    def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...

    def get_fields(
        model: Union[
            BaseModelV2,
            BaseModelV1,
            Type[BaseModelV2],
            Type[BaseModelV1],
        ],
    ) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
        """Return the field mapping of a Pydantic model class or instance.

        ``model_fields`` is consulted first, then ``__fields__``; the first
        attribute present on ``model`` is returned as-is.

        Raises:
            TypeError: If ``model`` exposes neither attribute.
        """
        for attr_name in ("model_fields", "__fields__"):
            if hasattr(model, attr_name):
                return getattr(model, attr_name)
        raise TypeError(f"Expected a Pydantic model. Got {type(model)}")

elif PYDANTIC_MAJOR_VERSION == 1:
    from pydantic import BaseModel as BaseModelV1_

    def get_fields(  # type: ignore[no-redef]
        model: Union[Type[BaseModelV1_], BaseModelV1_],
    ) -> Dict[str, FieldInfoV1]:
        """Return the ``__fields__`` mapping of a Pydantic 1.x model or instance."""
        return model.__fields__  # type: ignore

else:
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")