"""Utilities for tests."""from__future__importannotationsimportinspectimporttextwrapfromfunctoolsimportwrapsfromtypingimportAny,Callable,Dict,List,Optional,Type,TypeVar,Union,overloadimportpydantic# pydantic: ignorefromlangchain_core.pydantic_v1importBaseModel,root_validator
def get_pydantic_major_version() -> int:
    """Return the major version number of the installed Pydantic package.

    Returns:
        int: The leading dot-separated component of ``pydantic.__version__``
        converted to ``int`` (e.g. ``2`` for ``"2.7.1"``), or ``0`` when
        Pydantic cannot be imported.
    """
    try:
        import pydantic

        version_string = pydantic.__version__
    except ImportError:
        return 0
    major, _, _ = version_string.partition(".")
    return int(major)
# Major version of the installed pydantic, computed once at import time.
# get_pydantic_major_version() returns 0 when pydantic cannot be imported,
# which lands in the ValueError branch below.
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()

if PYDANTIC_MAJOR_VERSION == 1:
    from pydantic.fields import FieldInfo as FieldInfoV1

    # ``BaseModel`` here is the class imported from langchain_core.pydantic_v1;
    # presumably it is pydantic.BaseModel itself on 1.x -- confirm there.
    PydanticBaseModel = pydantic.BaseModel
    TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
    # On 2.x the v1-style FieldInfo lives in the pydantic.v1 compat namespace.
    from pydantic.v1.fields import FieldInfo as FieldInfoV1  # type: ignore[assignment]

    # Accept both the langchain_core.pydantic_v1 BaseModel (presumably
    # pydantic.v1.BaseModel on 2.x -- confirm) and the native pydantic.BaseModel.
    # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
    PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore

    TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]]  # type: ignore
else:
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")


# Type variable bound to whichever PydanticBaseModel alias was selected above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: Type) -> bool:
    """Return whether ``cls`` should be treated as a Pydantic v1-style model.

    * Pydantic 1.x installed: always ``True`` (``cls`` is not inspected).
    * Pydantic 2.x installed: ``True`` only if ``cls`` subclasses
      ``pydantic.v1.BaseModel``.
    * Any other detected version: ``False``.

    Args:
        cls: The class to check.

    Returns:
        bool: Whether ``cls`` is considered a v1-style model class.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        return True
    if PYDANTIC_MAJOR_VERSION != 2:
        return False

    from pydantic.v1 import BaseModel as BaseModelV1

    return issubclass(cls, BaseModelV1)
def is_pydantic_v2_subclass(cls: Type) -> bool:
    """Check if ``cls`` is a subclass of the native Pydantic 2.x ``BaseModel``.

    The result is always ``False`` unless the installed Pydantic major version
    is 2. Classes built on the ``pydantic.v1`` compatibility namespace are not
    native 2.x models, so they also yield ``False``.

    Args:
        cls: The class to check.

    Returns:
        bool: ``True`` if Pydantic 2.x is installed and ``cls`` subclasses
        ``pydantic.BaseModel``.
    """
    if PYDANTIC_MAJOR_VERSION != 2:
        return False

    # Imported lazily: it is only needed once we know we are on pydantic 2.x.
    from pydantic import BaseModel

    return issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: Type) -> bool:
    """Check whether ``cls`` is a Pydantic ``BaseModel`` subclass.

    Recognized base classes:

    * ``pydantic.BaseModel`` in Pydantic 1.x
    * ``pydantic.BaseModel`` in Pydantic 2.x
    * ``pydantic.v1.BaseModel`` in Pydantic 2.x

    Args:
        cls: The object to check; non-class values simply yield ``False``.

    Returns:
        bool: ``True`` if ``cls`` is a class deriving from one of the bases above.

    Raises:
        ValueError: If the detected Pydantic major version is neither 1 nor 2.
    """
    # issubclass() raises TypeError for non-class arguments, so reject them first.
    if not inspect.isclass(cls):
        return False

    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        model_bases = (BaseModelV1Proper,)
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        model_bases = (BaseModelV2, BaseModelV1)
    else:
        raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

    # A tuple checks each base in order, stopping at the first match.
    return issubclass(cls, model_bases)
def is_basemodel_instance(obj: Any) -> bool:
    """Check whether ``obj`` is an instance of a Pydantic ``BaseModel``.

    Recognized base classes:

    * ``pydantic.BaseModel`` in Pydantic 1.x
    * ``pydantic.BaseModel`` in Pydantic 2.x
    * ``pydantic.v1.BaseModel`` in Pydantic 2.x

    Args:
        obj: The object to check.

    Returns:
        bool: ``True`` if ``obj`` is an instance of one of the bases above.

    Raises:
        ValueError: If the detected Pydantic major version is neither 1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        from pydantic import BaseModel as BaseModelV1Proper  # pydantic: ignore

        model_bases = (BaseModelV1Proper,)
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic import BaseModel as BaseModelV2  # pydantic: ignore
        from pydantic.v1 import BaseModel as BaseModelV1  # pydantic: ignore

        model_bases = (BaseModelV2, BaseModelV1)
    else:
        raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

    # A tuple checks each base in order, stopping at the first match.
    return isinstance(obj, model_bases)
# TODO: How to type hint this?
def pre_init(func: Callable) -> Any:
    """Wrap ``func`` in a ``pre=True`` root validator that first fills in defaults.

    Before ``func`` runs, each declared field that is not required and is
    either missing from the input or given as ``None`` is populated. The value
    comes from ``default_factory()`` when a factory is set, and from
    ``default`` otherwise.

    Args:
        func (Callable): Called as ``func(cls, values)`` after defaults are
            filled in; its return value is returned by the validator.

    Returns:
        Any: The root validator wrapping ``func``.
    """

    @root_validator(pre=True)
    @wraps(func)
    def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in defaults for unset optional fields, then delegate to ``func``.

        Args:
            cls (Type[BaseModel]): The model class being validated.
            values (Dict[str, Any]): The raw input values; mutated in place.

        Returns:
            Dict[str, Any]: Whatever ``func(cls, values)`` returns.
        """
        # If the model's Config enables allow_population_by_field_name, a value
        # supplied under a field's alias is moved to the field's own name.
        config = getattr(cls, "Config", None)
        move_aliases = getattr(config, "allow_population_by_field_name", False)

        for field_name, field in cls.__fields__.items():
            if move_aliases and field.alias in values:
                values[field_name] = values.pop(field.alias)

            # A missing key and an explicit None are both treated as "unset".
            if values.get(field_name) is None and not field.required:
                factory = field.default_factory
                values[field_name] = (
                    factory() if factory is not None else field.default
                )

        return func(cls, values)

    return wrapper
def _create_subset_model_v1(
    name: str,
    model: Type[BaseModel],
    field_names: list,
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a pydantic (v1 API) model with only a subset of ``model``'s fields.

    Args:
        name: Name of the generated model class.
        model: Source model whose fields are copied.
        field_names: Names of the fields of ``model`` to keep.
        descriptions: Optional mapping of field name to a replacement description.
        fn_description: Docstring for the new model. Falls back to
            ``model.__doc__`` and then to ``""``; the result is dedented.

    Returns:
        The newly created model class.

    Raises:
        KeyError: If a name in ``field_names`` is not a field of ``model``.
    """
    import copy

    from langchain_core.pydantic_v1 import create_model

    fields = {}
    for field_name in field_names:
        field = model.__fields__[field_name]
        t = (
            # this isn't perfect but should work for most functions
            field.outer_type_
            if field.required and not field.allow_none
            else Optional[field.outer_type_]
        )
        # Work on a copy so that overriding the description does not mutate
        # the FieldInfo still attached to the source model.
        field_info = copy.copy(field.field_info)
        if descriptions and field_name in descriptions:
            field_info.description = descriptions[field_name]
        fields[field_name] = (t, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn


def _create_subset_model_v2(
    name: str,
    model: Type[pydantic.BaseModel],
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[pydantic.BaseModel]:
    """Create a native pydantic 2.x model with a subset of ``model``'s fields.

    For each kept field, the annotation, description (optionally overridden),
    default or default factory, and metadata are carried over.

    Args:
        name: Name of the generated model class.
        model: Source pydantic 2.x model whose fields are copied.
        field_names: Names of the fields of ``model`` to keep.
        descriptions: Optional mapping of field name to a replacement description.
        fn_description: Docstring for the new model. Falls back to
            ``model.__doc__`` and then to ``""``; the result is dedented.

    Returns:
        The newly created model class.

    Raises:
        KeyError: If a name in ``field_names`` is not a field of ``model``.
    """
    from pydantic import create_model  # pydantic: ignore
    from pydantic.fields import FieldInfo  # pydantic: ignore

    descriptions_ = descriptions or {}
    fields = {}
    for field_name in field_names:
        field = model.model_fields[field_name]  # type: ignore
        description = descriptions_.get(field_name, field.description)
        if field.default_factory is not None:
            # pydantic forbids setting both, so ``field.default`` is undefined
            # here; copying only ``default`` would make the field required.
            field_info = FieldInfo(
                description=description, default_factory=field.default_factory
            )
        else:
            field_info = FieldInfo(description=description, default=field.default)
        if field.metadata:
            field_info.metadata = field.metadata
        fields[field_name] = (field.annotation, field_info)

    rtn = create_model(name, **fields)  # type: ignore
    rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
    return rtn


# Private functionality to create a subset model that's compatible across
# different versions of pydantic.
# Handles pydantic versions 1.x and 2.x.
# This includes the ``pydantic.v1`` namespace bundled with pydantic 2.x.
# However, can't find a way to type hint this.
def _create_subset_model(
    name: str,
    model: TypeBaseModel,
    field_names: List[str],
    *,
    descriptions: Optional[dict] = None,
    fn_description: Optional[str] = None,
) -> Type[BaseModel]:
    """Create a subset model with the same pydantic API flavor as ``model``.

    The v1 builder handles v1-style models: every model on pydantic 1.x, and
    ``pydantic.v1`` models on 2.x. The v2 builder handles native 2.x models.

    Raises:
        NotImplementedError: If the detected pydantic major version is neither
            1 nor 2.
    """
    if PYDANTIC_MAJOR_VERSION == 1:
        builder = _create_subset_model_v1
    elif PYDANTIC_MAJOR_VERSION == 2:
        from pydantic.v1 import BaseModel as LegacyBaseModel  # pydantic: ignore

        builder = (
            _create_subset_model_v1
            if issubclass(model, LegacyBaseModel)
            else _create_subset_model_v2
        )
    else:
        raise NotImplementedError(
            f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
        )

    return builder(
        name,
        model,
        field_names,
        descriptions=descriptions,
        fn_description=fn_description,
    )


if PYDANTIC_MAJOR_VERSION == 2:
    from pydantic import BaseModel as BaseModelV2
    from pydantic.fields import FieldInfo as FieldInfoV2
    from pydantic.v1 import BaseModel as BaseModelV1

    @overload
    def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...

    @overload
    def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...

    @overload
    def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...

    def get_fields(
        model: Union[
            BaseModelV2,
            BaseModelV1,
            Type[BaseModelV2],
            Type[BaseModelV1],
        ],
    ) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
        """Return the name-to-field mapping of a Pydantic model class or instance.

        ``model_fields`` is preferred (native 2.x models), with ``__fields__``
        as the fallback (``pydantic.v1`` models).

        Raises:
            TypeError: If ``model`` has neither attribute.
        """
        for attr_name in ("model_fields", "__fields__"):
            if hasattr(model, attr_name):
                return getattr(model, attr_name)  # type: ignore
        raise TypeError(f"Expected a Pydantic model. Got {type(model)}")

elif PYDANTIC_MAJOR_VERSION == 1:
    from pydantic import BaseModel as BaseModelV1_

    def get_fields(  # type: ignore[no-redef]
        model: Union[Type[BaseModelV1_], BaseModelV1_],
    ) -> Dict[str, FieldInfoV1]:
        """Return the ``__fields__`` mapping of a Pydantic 1.x model class or instance."""
        # NOTE(review): on 1.x the values of ``__fields__`` appear to be
        # ModelField objects rather than FieldInfo; the return annotation may
        # be imprecise -- confirm.
        return model.__fields__  # type: ignore