"""Utility code for runnables."""from__future__importannotationsimportastimportasyncioimportinspectimporttextwrapfromfunctoolsimportlru_cachefrominspectimportsignaturefromitertoolsimportgroupbyfromtypingimport(Any,AsyncIterable,AsyncIterator,Awaitable,Callable,Coroutine,Dict,Iterable,List,Mapping,NamedTuple,Optional,Protocol,Sequence,Set,Type,TypeVar,Union,)fromtyping_extensionsimportTypeGuardfromlangchain_core.pydantic_v1importBaseConfig,BaseModelfromlangchain_core.pydantic_v1importcreate_modelas_create_model_basefromlangchain_core.runnables.schemaimportStreamEventInput=TypeVar("Input",contravariant=True)# Output type should implement __concat__, as eg str, list, dict doOutput=TypeVar("Output",covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await a coroutine while holding one slot of a semaphore.

    The semaphore is acquired before ``coro`` is awaited and released again
    when it finishes, whether it returns normally or raises.

    Args:
        semaphore: The semaphore to acquire around the coroutine.
        coro: The coroutine to await.

    Returns:
        Whatever ``coro`` returns.
    """
    async with semaphore:
        outcome = await coro
    return outcome
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Gather coroutines, optionally capping how many run at the same time.

    Args:
        n: Maximum number of coroutines allowed to run concurrently, or
            ``None`` for no limit.
        *coros: The coroutines to run.

    Returns:
        The results of the coroutines, in the order the coroutines were given.
    """
    if n is not None:
        # Every coroutine must hold a slot of one shared semaphore while running.
        limiter = asyncio.Semaphore(n)
        coros = tuple(gated_coro(limiter, c) for c in coros)
    return await asyncio.gather(*coros)
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
    """Check if a callable accepts a run_manager argument.

    Args:
        callable: The callable to check.

    Returns:
        bool: True if the signature of ``callable`` has a parameter named
        ``run_manager``; False otherwise, including when
        ``inspect.signature`` raises ValueError (no signature available).
    """
    try:
        parameters = signature(callable).parameters
    except ValueError:
        return False
    return "run_manager" in parameters
def accepts_config(callable: Callable[..., Any]) -> bool:
    """Check if a callable accepts a config argument.

    Args:
        callable: The callable to check.

    Returns:
        bool: True if the signature of ``callable`` has a parameter named
        ``config``; False otherwise, including when ``inspect.signature``
        raises ValueError (no signature available).
    """
    try:
        sig = signature(callable)
    except ValueError:
        return False
    has_config = "config" in sig.parameters
    return has_config
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Check if a callable accepts a context argument.

    Args:
        callable: The callable to check.

    Returns:
        bool: True if the signature of ``callable`` has a parameter named
        ``context``; False otherwise, including when ``inspect.signature``
        raises ValueError (no signature available).
    """
    try:
        return "context" in signature(callable).parameters
    except ValueError:
        return False
class IsLocalDict(ast.NodeVisitor):
    """Collect the constant string keys read from a given variable.

    The visitor recognizes two access patterns on the variable called ``name``:

    * ``name["key"]``, a subscript load;
    * ``name.get("key")`` or ``name.get("key", default)``.

    Each ``"key"`` found is added to ``keys``. The whole subtree is walked, so
    accesses nested inside calls or other subscripts are found too.
    """

    def __init__(self, name: str, keys: Set[str]) -> None:
        """Initialize the visitor.

        Args:
            name: The variable name whose key accesses are collected.
            keys: The set to populate; it is mutated in place.
        """
        self.name = name
        self.keys = keys

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Record a ``name["key"]`` load, then keep walking the subtree.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)
        # Defining visit_Subscript replaces NodeVisitor's default recursion, so
        # recurse explicitly; otherwise accesses such as ``name["a"]["b"]`` or
        # ``other[name["a"]]`` would be missed.
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> Any:
        """Record a ``name.get("key"[, default])`` call, then keep walking.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in (1, 2)
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)
        # As above: without explicit recursion, accesses inside call arguments,
        # e.g. ``len(name["a"])``, would never be visited.
        self.generic_visit(node)
class IsFunctionArgDict(ast.NodeVisitor):
    """Check if the first argument of a function is a dict.

    For every function definition, async function definition or lambda that is
    visited, the first positional parameter is looked up with ``IsLocalDict``.
    The string keys read from that parameter are accumulated in ``self.keys``.
    """

    def __init__(self) -> None:
        """Initialize the visitor with an empty set of keys."""
        # The visit_* methods below add to this set, so it must exist before
        # any visit; without it they fail with AttributeError.
        self.keys: Set[str] = set()

    def _collect_keys(self, arguments: ast.arguments, target: ast.AST) -> None:
        """Record keys read from the first positional parameter inside ``target``."""
        # Positional-only parameters come before regular ones, so
        # ``lambda x, /: x["a"]`` still resolves ``x`` as the first argument.
        positional = [*arguments.posonlyargs, *arguments.args]
        if not positional:
            return
        IsLocalDict(positional[0].arg, self.keys).visit(target)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function; only its body expression is searched.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect_keys(node.args, node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition; the whole definition node is searched.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect_keys(node.args, node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Visit an async function definition; the whole node is searched.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect_keys(node.args, node)
def visit_Name(self, node: ast.Name) -> Any:
    """Record a variable name that is read or assigned.

    NOTE(review): this appears to be a method of the ``NonLocals`` visitor, whose
    class header is not visible here. It writes to ``self.loads`` and
    ``self.stores``.

    Names in a load context go to ``self.loads``, and names in a store context
    go to ``self.stores``. Other contexts (e.g. ``del``) are ignored.

    Args:
        node: The node to visit.

    Returns:
        Any: Always ``None``.
    """
    context = node.ctx
    if isinstance(context, ast.Store):
        self.stores.add(node.id)
    elif isinstance(context, ast.Load):
        self.loads.add(node.id)
def visit_Attribute(self, node: ast.Attribute) -> Any:
    """Record a dotted attribute read such as ``obj.attr.sub``.

    NOTE(review): this appears to be a method of the ``NonLocals`` visitor, whose
    class header is not visible here. It reads and writes ``self.loads`` and
    dispatches through ``self.visit``.

    For a load whose attribute chain ends in a plain name, the full dotted
    path (e.g. ``"obj.attr.sub"``) is added to ``self.loads`` and the bare
    root name is discarded from it.

    When the chain ends in any other expression (a call, a subscript, ...),
    that base expression is visited, so the names used inside it are still
    recorded. Attributes in store/delete contexts are ignored.

    Args:
        node: The node to visit.

    Returns:
        Any: Always ``None``.
    """
    if isinstance(node.ctx, ast.Load):
        parent = node.value
        attr_expr = node.attr
        while isinstance(parent, ast.Attribute):
            attr_expr = parent.attr + "." + attr_expr
            parent = parent.value
        if isinstance(parent, ast.Name):
            self.loads.add(parent.id + "." + attr_expr)
            self.loads.discard(parent.id)
        else:
            # e.g. ``factory(x).attr`` or ``items[0].attr``. This method does not
            # recurse on its own, so without visiting the base the names
            # ``factory``, ``x`` and ``items`` would never be seen.
            self.visit(parent)
class FunctionNonLocals(ast.NodeVisitor):
    """Get the nonlocal variables accessed of a function.

    For every function definition, async function definition or lambda that is
    visited, a ``NonLocals`` visitor is run over it. The names it loads but
    never stores are accumulated in ``self.nonlocals``.
    """

    def __init__(self) -> None:
        """Initialize the visitor with an empty result set."""
        # The visit_* methods below add to this set, so it must exist before
        # any visit; without it they fail with AttributeError.
        self.nonlocals: Set[str] = set()

    def _collect(self, node: ast.AST) -> None:
        """Add the names read but never assigned inside ``node``."""
        visitor = NonLocals()
        visitor.visit(node)
        self.nonlocals.update(visitor.loads - visitor.stores)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Visit an async function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self._collect(node)
class GetLambdaSource(ast.NodeVisitor):
    """Find the source text of the lambda expression(s) in a parsed tree.

    After visiting:

    * ``count`` holds the number of lambda nodes reached;
    * ``source`` holds the unparsed text of the last lambda reached.

    ``source`` is only set when ``ast.unparse`` is available; otherwise it
    stays ``None``.
    """

    def __init__(self) -> None:
        """Start with no lambda seen and no source recorded."""
        self.source: Optional[str] = None
        self.count = 0

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Count ``node`` and remember its unparsed source.

        The children of ``node`` are not visited, so a lambda nested inside
        another lambda is not counted.

        Args:
            node: The node to visit.

        Returns:
            Any: Always ``None``.
        """
        self.count += 1
        unparse = getattr(ast, "unparse", None)
        if unparse is not None:
            self.source = unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
    """Get the keys of the first argument of a function if it is a dict.

    The source of ``func`` is parsed, and the string keys read from its first
    argument are collected with ``IsFunctionArgDict``.

    Args:
        func: The function to check.

    Returns:
        Optional[List[str]]: The keys found, or None if no keys were found or
        if retrieving/parsing the source raised SyntaxError, TypeError,
        OSError or SystemError.
    """
    try:
        source = textwrap.dedent(inspect.getsource(func))
        tree = ast.parse(source)
        collector = IsFunctionArgDict()
        collector.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return None
    found = collector.keys
    return list(found) if found else None
def get_lambda_source(func: Callable) -> Optional[str]:
    """Get the source code of a lambda function.

    The fallback value is ``func.__name__``, or None when ``func`` has no
    ``__name__`` or is an anonymous ``<lambda>``.

    The source text is returned only when exactly one lambda is found in the
    parsed source of ``func``. Otherwise, or when the source cannot be
    retrieved or parsed, the fallback value is returned.

    Args:
        func: a Callable that can be a lambda function.

    Returns:
        Optional[str]: the source code of the lambda function, or the fallback
        value described above.
    """
    fallback = getattr(func, "__name__", None)
    if fallback == "<lambda>":
        fallback = None
    try:
        tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
        finder = GetLambdaSource()
        finder.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return fallback
    if finder.count != 1:
        return fallback
    return finder.source
def get_function_nonlocals(func: Callable) -> List[Any]:
    """Get the nonlocal variables accessed by a function.

    The source of ``func`` is parsed and ``FunctionNonLocals`` collects the
    names it reads without assigning. Those names are then matched against the
    globals and nonlocals reported by ``inspect.getclosurevars``.

    Args:
        func: The function to check.

    Returns:
        List[Any]: The values of the matched globals/nonlocals.

        A dotted name such as ``obj.attr.sub`` whose first segment is exactly a
        candidate name ``obj`` is also resolved. The value reached by following
        the attributes is included, unless an intermediate value is None or an
        attribute is missing.

        An empty list is returned if SyntaxError, TypeError, OSError or
        SystemError is raised along the way (e.g. source not available).
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: List[Any] = []
        closure = inspect.getclosurevars(func)
        candidates = {**closure.globals, **closure.nonlocals}
        for k, v in candidates.items():
            if k in visitor.nonlocals:
                values.append(v)
            for kk in visitor.nonlocals:
                root, dot, path = kk.partition(".")
                # Match the first dotted segment exactly. A plain prefix test
                # would let ``k == "foo"`` claim ``"foobar.baz"`` and then
                # resolve ``baz`` on the wrong object.
                if not dot or root != k:
                    continue
                vv = v
                for part in path.split("."):
                    if vv is None:
                        break
                    try:
                        vv = getattr(vv, part)
                    except AttributeError:
                        break
                else:
                    # Only reached when every attribute in the path resolved.
                    values.append(vv)
        return values
    except (SyntaxError, TypeError, OSError, SystemError):
        return []
def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent all lines of text after the first line.

    Args:
        text: The text to indent.
        prefix: Only its length is used. Every line after the first is
            indented by ``len(prefix)`` spaces.

    Returns:
        str: The indented text. Lines are split with ``str.splitlines`` and
        re-joined with newline characters, so a trailing line break in
        ``text`` is not preserved. Empty ``text`` yields an empty string.
    """
    lines = text.splitlines()
    if not lines:
        # "".splitlines() is empty, so indexing lines[0] would raise IndexError.
        return ""
    spaces = " " * len(prefix)
    return "\n".join([lines[0], *(spaces + line for line in lines[1:])])
class AddableDict(Dict[str, Any]):
    """Dictionary that can be added to another dictionary.

    Adding merges key by key into a new ``AddableDict``; neither operand is
    modified. For each key of the right-hand operand:

    * if the key is missing on the left, or its left value is None, the
      right-hand value is taken;
    * otherwise, if the right-hand value is not None, the two values are
      combined with ``+``; if that raises TypeError, the right-hand value
      replaces the left one;
    * a right-hand value of None leaves the left value untouched.
    """

    @staticmethod
    def _merge(left: Dict[str, Any], right: Dict[str, Any]) -> AddableDict:
        """Merge ``right`` into a copy of ``left`` using the rules above."""
        merged = AddableDict(left)
        for key in right:
            incoming = right[key]
            if key not in merged or merged[key] is None:
                merged[key] = incoming
            elif incoming is not None:
                try:
                    merged[key] = merged[key] + incoming
                except TypeError:
                    merged[key] = incoming
        return merged

    def __add__(self, other: AddableDict) -> AddableDict:
        return self._merge(self, other)

    def __radd__(self, other: AddableDict) -> AddableDict:
        return self._merge(other, self)
def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """Add a sequence of addable objects together.

    Items are combined left to right with ``+``. Whenever the running total is
    None, the next item replaces it instead of being added.

    Args:
        addables: The addable objects to add.

    Returns:
        Optional[Addable]: The combined result, or None for an empty input.
    """
    total = None
    for item in addables:
        total = item if total is None else total + item
    return total
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """Asynchronously add a sequence of addable objects together.

    Items are combined left to right with ``+``. Whenever the running total is
    None, the next item replaces it instead of being added.

    Args:
        addables: The addable objects to add.

    Returns:
        Optional[Addable]: The combined result, or None for an empty input.
    """
    total = None
    async for item in addables:
        total = item if total is None else total + item
    return total
class ConfigurableField(NamedTuple):
    """Field that can be configured by the user.

    Parameters:
        id: The unique identifier of the field.
        name: The name of the field. Defaults to None.
        description: The description of the field. Defaults to None.
        annotation: The annotation of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str

    name: Optional[str] = None
    description: Optional[str] = None
    annotation: Optional[Any] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Only the id and the annotation take part in the hash.
        identity = (self.id, self.annotation)
        return hash(identity)
class ConfigurableFieldSingleOption(NamedTuple):
    """Field that can be configured by the user with a default value.

    Parameters:
        id: The unique identifier of the field.
        options: The options for the field.
        default: The default value for the field.
        name: The name of the field. Defaults to None.
        description: The description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: str

    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Hash on the id, the option keys (not their values) and the default.
        option_keys = tuple(self.options.keys())
        return hash((self.id, option_keys, self.default))
class ConfigurableFieldMultiOption(NamedTuple):
    """Field that can be configured by the user with multiple default values.

    Parameters:
        id: The unique identifier of the field.
        options: The options for the field.
        default: The default values for the field.
        name: The name of the field. Defaults to None.
        description: The description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: Sequence[str]

    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Hash on the id, the option keys (not their values) and the defaults;
        # both sequences are turned into tuples so they are hashable.
        option_keys = tuple(self.options.keys())
        defaults = tuple(self.default)
        return hash((self.id, option_keys, defaults))
class ConfigurableFieldSpec(NamedTuple):
    """Specification of a field that can be configured by the user.

    Parameters:
        id: The unique identifier of the field.
        annotation: The annotation of the field.
        name: The name of the field. Defaults to None.
        description: The description of the field. Defaults to None.
        default: The default value for the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
        dependencies: The dependencies of the field. Defaults to None.
    """

    # Required fields.
    id: str
    annotation: Any

    # Optional fields.
    name: Optional[str] = None
    description: Optional[str] = None
    default: Any = None
    is_shared: bool = False
    dependencies: Optional[List[str]] = None
def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs.

    Specs are sorted by id (then by their dependencies) and grouped by id. All
    specs sharing an id must compare equal; one of them is kept.

    Args:
        specs: The config specs.

    Returns:
        List[ConfigurableFieldSpec]: The unique config specs, ordered by id.

    Raises:
        ValueError: If the runnable sequence contains conflicting config specs.
    """
    grouped = groupby(
        sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))),
        lambda s: s.id,
    )
    unique: List[ConfigurableFieldSpec] = []
    for spec_id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        # all() of an empty list is True, so a lone spec is kept as well.
        if all(o == first for o in others):
            unique.append(first)
        else:
            raise ValueError(
                "RunnableSequence contains conflicting config specs "
                f"for {spec_id}: {[first] + others}"
            )
    return unique
class _RootEventFilter:
    def __init__(
        self,
        *,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
    ) -> None:
        """Utility to filter the root event in the astream_events implementation.

        It only stores the filter arguments on the instance, which saves a bit
        of typing in the astream_events implementation. A value of None means
        the corresponding filter is not applied.
        """
        self.include_names = include_names
        self.include_types = include_types
        self.include_tags = include_tags
        self.exclude_names = exclude_names
        self.exclude_types = exclude_types
        self.exclude_tags = exclude_tags

    def include_event(self, event: StreamEvent, root_type: str) -> bool:
        """Determine whether to include an event.

        An event passes if it matches at least one configured include filter
        (or no include filter is configured at all), and matches none of the
        configured exclude filters.
        """
        tags = event.get("tags") or []

        keep = (
            self.include_names is None
            and self.include_types is None
            and self.include_tags is None
        )
        if not keep and self.include_names is not None:
            keep = event["name"] in self.include_names
        if not keep and self.include_types is not None:
            keep = root_type in self.include_types
        if not keep and self.include_tags is not None:
            keep = any(tag in self.include_tags for tag in tags)

        if keep and self.exclude_names is not None:
            keep = event["name"] not in self.exclude_names
        if keep and self.exclude_types is not None:
            keep = root_type not in self.exclude_types
        if keep and self.exclude_tags is not None:
            keep = not any(tag in self.exclude_tags for tag in tags)

        return keep


class _SchemaConfig(BaseConfig):
    # Pydantic (v1) model config: allow field types pydantic cannot validate
    # natively, and freeze model instances (pydantic v1 ``Config.frozen``).
    arbitrary_types_allowed = True
    frozen = True
def create_model(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    The cached factory ``_create_model_cached`` is tried first. If it raises
    TypeError, this is treated as "some field definition is not hashable",
    and an uncached model is built with ``_SchemaConfig`` instead.

    Args:
        __model_name: The name of the model.
        **field_definitions: The field definitions for the model.

    Returns:
        Type[BaseModel]: The created model.
    """
    try:
        model = _create_model_cached(__model_name, **field_definitions)
    except TypeError:
        # something in field definitions is not hashable (presumably the cache
        # needs hashable arguments), so build the model without caching
        model = _create_model_base(
            __model_name, __config__=_SchemaConfig, **field_definitions
        )
    return model
def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check if a function is an async generator.

    Both async generator functions and objects whose ``__call__`` is an async
    generator function are accepted.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., AsyncIterator]]: True if the function is an
        async generator, False otherwise.
    """
    if inspect.isasyncgenfunction(func):
        return True
    call = getattr(func, "__call__", None)
    if call is None:
        return False
    return inspect.isasyncgenfunction(call)
def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check if a function is async.

    Both coroutine functions and objects whose ``__call__`` is a coroutine
    function are accepted.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., Awaitable]]: True if the function is async,
        False otherwise.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    call = getattr(func, "__call__", None)
    if call is None:
        return False
    return asyncio.iscoroutinefunction(call)