"""Utility code for runnables."""

from __future__ import annotations

import ast
import asyncio
import inspect
import textwrap
from collections.abc import (
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Coroutine,
    Iterable,
    Mapping,
    Sequence,
)
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
    Any,
    Callable,
    NamedTuple,
    Optional,
    Protocol,
    TypeVar,
    Union,
)

from typing_extensions import TypeGuard, override

from langchain_core.runnables.schema import StreamEvent

# Re-export create-model for backwards compatibility
from langchain_core.utils.pydantic import create_model as create_model

# Type variable for runnable inputs, declared contravariant.
Input = TypeVar("Input", contravariant=True)
# Type variable for runnable outputs, declared covariant. The output type is
# expected to support concatenation/addition (e.g. str, list).
# NOTE(review): a plain ``dict`` does not implement ``+``; confirm which
# mapping type was meant in the original "str, list, dict" example.
Output = TypeVar("Output", covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await a coroutine while holding a semaphore slot.

    Args:
        semaphore: Semaphore that bounds how many gated coroutines run at once.
        coro: The coroutine to await.

    Returns:
        Whatever ``coro`` returns.
    """
    async with semaphore:
        result = await coro
    return result
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Gather coroutines, optionally capping how many run concurrently.

    Args:
        n: Maximum number of coroutines allowed to run at the same time, or
            ``None`` for no limit.
        *coros: The coroutines to run.

    Returns:
        The results of the coroutines, in the order they were passed.
    """
    if n is not None:
        # Each coroutine waits for a semaphore slot before it starts.
        limiter = asyncio.Semaphore(n)
        gated = [gated_coro(limiter, c) for c in coros]
        return await asyncio.gather(*gated)
    return await asyncio.gather(*coros)
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a ``run_manager`` parameter.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``run_manager`` appears in the callable's signature.
        False otherwise, including when no signature can be obtained.
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "run_manager" in params
def accepts_config(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a ``config`` parameter.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``config`` appears in the callable's signature.
        False otherwise, including when no signature can be obtained.
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "config" in params
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Report whether a callable declares a ``context`` parameter.

    Args:
        callable: The callable to inspect.

    Returns:
        bool: True if ``context`` appears in the callable's signature.
        False otherwise, including when no signature can be obtained.
    """
    try:
        params = signature(callable).parameters
    except ValueError:
        return False
    return "context" in params
class IsLocalDict(ast.NodeVisitor):
    """Check if a name is a local dict.

    Collects the constant string keys read from ``name`` through subscript
    access (``name["key"]``) or ``.get`` calls (``name.get("key")`` /
    ``name.get("key", default)``).
    """

    def __init__(self, name: str, keys: set[str]) -> None:
        """Initialize the visitor.

        Args:
            name: The name to check.
            keys: The keys to populate. Found keys are added to this set in
                place.
        """
        self.name = name
        self.keys = keys

    @override
    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Visit a subscript node.

        Records ``key`` for a load of the form ``<name>["key"]``.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)
        # Keep walking so accesses nested in this expression are found too,
        # e.g. the inner ``x["a"]`` in ``x["a"]["b"]`` or ``y[x["a"]]``.
        self.generic_visit(node)

    @override
    def visit_Call(self, node: ast.Call) -> Any:
        """Visit a call node.

        Records ``key`` for a call of the form ``<name>.get("key"[, default])``.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in (1, 2)
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)
        # Keep walking so accesses in the callee and the arguments are found
        # too, e.g. ``str(x["a"])`` or ``x.get("b").upper()``.
        # NOTE: name shadowing by nested lambdas/functions is not tracked.
        self.generic_visit(node)
class IsFunctionArgDict(ast.NodeVisitor):
    """Check if the first argument of a function is a dict.

    For each lambda / (async) function definition visited, collects the
    constant string keys read from its first positional argument via
    ``IsLocalDict``.
    """

    def __init__(self) -> None:
        """Create a visitor with an empty set of collected keys."""
        # Filled in by the visit_* methods; read by callers after ``visit``.
        self.keys: set[str] = set()

    @override
    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node.body)

    @override
    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)

    @override
    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Visit an async function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if not node.args.args:
            return
        input_arg_name = node.args.args[0].arg
        IsLocalDict(input_arg_name, self.keys).visit(node)
class NonLocals(ast.NodeVisitor):
    """Get nonlocal variables accessed.

    Records the bare names and dotted attribute paths (``obj.attr``) that are
    read in the visited tree, plus the names that are assigned in it.
    """

    def __init__(self) -> None:
        """Create a visitor with empty ``loads`` and ``stores`` sets."""
        # Names / dotted attribute paths read (``ast.Load`` context).
        self.loads: set[str] = set()
        # Names assigned to (``ast.Store`` context).
        self.stores: set[str] = set()

    @override
    def visit_Name(self, node: ast.Name) -> Any:
        """Visit a name node.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    @override
    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Visit an attribute node.

        Records the full dotted path of a loaded attribute chain, e.g.
        ``a.b.c``. When the chain is rooted at a call, records the callee
        instead: ``f`` for ``f().x``, ``a.b`` for ``a.b().x``.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        # NOTE: this method does not call ``generic_visit``, so sub-expressions
        # of the chain (e.g. call arguments in ``f(x).y``) are not visited.
        if isinstance(node.ctx, ast.Load):
            parent = node.value
            attr_expr = node.attr
            # Walk down ``a.b.c`` to its root, building the dotted suffix.
            while isinstance(parent, ast.Attribute):
                attr_expr = parent.attr + "." + attr_expr
                parent = parent.value
            if isinstance(parent, ast.Name):
                self.loads.add(parent.id + "." + attr_expr)
                # Report the dotted path instead of the bare base name.
                self.loads.discard(parent.id)
            elif isinstance(parent, ast.Call):
                if isinstance(parent.func, ast.Name):
                    self.loads.add(parent.func.id)
                else:
                    # The callee is itself an attribute chain, e.g. ``a.b()``;
                    # record that chain. The attributes accessed on the call's
                    # result are dropped.
                    parent = parent.func
                    attr_expr = ""
                    while isinstance(parent, ast.Attribute):
                        if attr_expr:
                            attr_expr = parent.attr + "." + attr_expr
                        else:
                            attr_expr = parent.attr
                        parent = parent.value
                    if isinstance(parent, ast.Name):
                        self.loads.add(parent.id + "." + attr_expr)
class FunctionNonLocals(ast.NodeVisitor):
    """Get the nonlocal variables accessed of a function.

    For every function, async function, or lambda visited, collects the names
    that ``NonLocals`` reports as loaded but never stored within it.
    """

    def __init__(self) -> None:
        """Create a visitor with an empty set of collected names."""
        # Filled in by the visit_* methods; read by callers after ``visit``.
        self.nonlocals: set[str] = set()

    def _collect(self, node: ast.AST) -> None:
        """Add the names read but not assigned within ``node`` to ``nonlocals``."""
        visitor = NonLocals()
        visitor.visit(node)
        self.nonlocals.update(visitor.loads - visitor.stores)

    @override
    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        self._collect(node)

    @override
    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Visit an async function definition.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        self._collect(node)

    @override
    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Visit a lambda function.

        Args:
            node: The node to visit.

        Returns:
            Any: The result of the visit.
        """
        self._collect(node)
[docs]classGetLambdaSource(ast.NodeVisitor):"""Get the source code of a lambda function."""
[docs]def__init__(self)->None:"""Initialize the visitor."""self.source:Optional[str]=Noneself.count=0
[docs]@overridedefvisit_Lambda(self,node:ast.Lambda)->Any:"""Visit a lambda function. Args: node: The node to visit. Returns: Any: The result of the visit. """self.count+=1ifhasattr(ast,"unparse"):self.source=ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]:
    """Infer which dict keys a function reads from its first argument.

    Args:
        func: The function to check.

    Returns:
        Optional[List[str]]: The sorted keys read from the first argument. None
        if no keys are found or the source cannot be retrieved or parsed.
    """
    try:
        source = textwrap.dedent(inspect.getsource(func))
        finder = IsFunctionArgDict()
        finder.visit(ast.parse(source))
    except (SyntaxError, TypeError, OSError, SystemError):
        return None
    if not finder.keys:
        return None
    return sorted(finder.keys)
def get_lambda_source(func: Callable) -> Optional[str]:
    """Get the source code of a lambda function.

    Args:
        func: a Callable that can be a lambda function.

    Returns:
        str: The unparsed lambda source when the function's source contains
        exactly one lambda. Otherwise the callable's ``__name__``, or None when
        that name is missing or is ``"<lambda>"``.
    """
    fallback = getattr(func, "__name__", None)
    if fallback == "<lambda>":
        fallback = None

    try:
        tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
        finder = GetLambdaSource()
        finder.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return fallback

    if finder.count == 1:
        return finder.source
    return fallback
@lru_cache(maxsize=256)
def get_function_nonlocals(func: Callable) -> list[Any]:
    """Get the nonlocal variables accessed by a function.

    Parses the source of ``func`` and collects the names and dotted attribute
    paths (``obj.attr``) it reads but never assigns. These are then resolved
    against the function's globals and closure variables.

    Results are memoized with ``lru_cache`` (up to 256 entries), so ``func``
    must be hashable.

    Args:
        func: The function to check.

    Returns:
        List[Any]: The nonlocal values accessed by the function. Empty if the
        source cannot be retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: list[Any] = []
        # For wrappers exposing a callable ``__wrapped__``, read the closure of
        # the wrapped function instead of the wrapper.
        closure = (
            inspect.getclosurevars(func.__wrapped__)
            if hasattr(func, "__wrapped__") and callable(func.__wrapped__)
            else inspect.getclosurevars(func)
        )
        candidates = {**closure.globals, **closure.nonlocals}
        for k, v in candidates.items():
            if k in visitor.nonlocals:
                values.append(v)
            for kk in visitor.nonlocals:
                # Only follow dotted paths rooted at exactly ``k``. A bare
                # ``startswith(k)`` would also match e.g. ``"ab.c"`` for
                # ``k == "a"`` and resolve ``.c`` against the wrong object.
                if "." in kk and kk.startswith(k + "."):
                    vv = v
                    for part in kk.split(".")[1:]:
                        if vv is None:
                            break
                        try:
                            vv = getattr(vv, part)
                        except AttributeError:
                            break
                    else:
                        # Every attribute on the path resolved.
                        values.append(vv)
    except (SyntaxError, TypeError, OSError, SystemError):
        return []
    return values
def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent all lines of text after the first line.

    Args:
        text: The text to indent.
        prefix: Used to determine the number of spaces to indent; each line
            after the first gets ``len(prefix)`` spaces.

    Returns:
        str: The indented text, with lines joined by ``"\\n"``. An empty
        ``text`` is returned unchanged.
    """
    lines = text.splitlines()
    if not lines:
        # ``"".splitlines()`` is empty, so there is no first line to keep.
        return text
    spaces = " " * len(prefix)
    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(dict[str, Any]):
    """Dictionary that can be added to another dictionary.

    Adding combines the two dicts key by key:

    * A key missing on the left, or holding ``None`` there, takes the right
      value.
    * A ``None`` on the right leaves the left value untouched.
    * Otherwise the values are combined with ``+``. If that raises
      ``TypeError``, the right value wins.
    """

    @staticmethod
    def _merged(base: Mapping[str, Any], extra: Mapping[str, Any]) -> AddableDict:
        """Return a new AddableDict: a copy of ``base`` with ``extra`` folded in."""
        result = AddableDict(base)
        for key in extra:
            if key not in result or result[key] is None:
                result[key] = extra[key]
                continue
            if extra[key] is None:
                continue
            try:
                combined = result[key] + extra[key]
            except TypeError:
                combined = extra[key]
            result[key] = combined
        return result

    def __add__(self, other: AddableDict) -> AddableDict:
        return AddableDict._merged(self, other)

    def __radd__(self, other: AddableDict) -> AddableDict:
        return AddableDict._merged(other, self)
def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """Fold an iterable of addable objects together with ``+``.

    Args:
        addables: The addable objects to combine, left to right.

    Returns:
        Optional[Addable]: The combined value, or None if ``addables`` is empty.
    """
    total: Optional[Addable] = None
    for item in addables:
        if total is None:
            total = item
        else:
            total = total + item
    return total
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """Fold an async iterable of addable objects together with ``+``.

    Args:
        addables: The addable objects to combine, in arrival order.

    Returns:
        Optional[Addable]: The combined value, or None if nothing was yielded.
    """
    total: Optional[Addable] = None
    async for item in addables:
        if total is None:
            total = item
        else:
            total = total + item
    return total
class ConfigurableField(NamedTuple):
    """A field that the user can configure.

    Parameters:
        id: Unique identifier of the field.
        name: Human-readable name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        annotation: Type annotation of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    annotation: Optional[Any] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Only ``id`` and ``annotation`` take part in the hash.
        identity = (self.id, self.annotation)
        return hash(identity)
class ConfigurableFieldSingleOption(NamedTuple):
    """A user-configurable field that selects one option, with a default.

    Parameters:
        id: Unique identifier of the field.
        options: Mapping from option key to option value.
        default: Key of the default option.
        name: Human-readable name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: str
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Hash on the id, the option keys (in mapping order) and the default.
        option_keys = tuple(self.options.keys())
        return hash((self.id, option_keys, self.default))
class ConfigurableFieldMultiOption(NamedTuple):
    """A user-configurable field that selects several options, with defaults.

    Parameters:
        id: Unique identifier of the field.
        options: Mapping from option key to option value.
        default: Keys of the options selected by default.
        name: Human-readable name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
    """

    id: str
    options: Mapping[str, Any]
    default: Sequence[str]
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Hash on the id, the option keys and the default keys, all as tuples.
        option_keys = tuple(self.options.keys())
        default_keys = tuple(self.default)
        return hash((self.id, option_keys, default_keys))
class ConfigurableFieldSpec(NamedTuple):
    """Specification of a user-configurable field.

    Parameters:
        id: Unique identifier of the field.
        annotation: Type annotation of the field.
        name: Human-readable name of the field. Defaults to None.
        description: Description of the field. Defaults to None.
        default: Default value of the field. Defaults to None.
        is_shared: Whether the field is shared. Defaults to False.
        dependencies: Dependencies of the field. Defaults to None.
    """

    # Required fields.
    id: str
    annotation: Any
    # Optional fields.
    name: Optional[str] = None
    description: Optional[str] = None
    default: Any = None
    is_shared: bool = False
    dependencies: Optional[list[str]] = None
def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> list[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs.

    Specs sharing an ``id`` are collapsed into one, provided they are all
    equal.

    Args:
        specs: The config specs.

    Returns:
        List[ConfigurableFieldSpec]: The unique config specs, ordered by id.

    Raises:
        ValueError: If the runnable sequence contains conflicting config specs.
    """
    # ``groupby`` only groups adjacent items, so sort by id first. The
    # dependencies are part of the sort key as well.
    grouped = groupby(
        sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))),
        lambda s: s.id,
    )
    unique: list[ConfigurableFieldSpec] = []
    for spec_id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        if not others or all(o == first for o in others):
            unique.append(first)
        else:
            msg = (
                "RunnableSequence contains conflicting config specs "
                f"for {spec_id}: {[first] + others}"
            )
            raise ValueError(msg)
    return unique
class_RootEventFilter:def__init__(self,*,include_names:Optional[Sequence[str]]=None,include_types:Optional[Sequence[str]]=None,include_tags:Optional[Sequence[str]]=None,exclude_names:Optional[Sequence[str]]=None,exclude_types:Optional[Sequence[str]]=None,exclude_tags:Optional[Sequence[str]]=None,)->None:"""Utility to filter the root event in the astream_events implementation. This is simply binding the arguments to the namespace to make save on a bit of typing in the astream_events implementation. """self.include_names=include_namesself.include_types=include_typesself.include_tags=include_tagsself.exclude_names=exclude_namesself.exclude_types=exclude_typesself.exclude_tags=exclude_tagsdefinclude_event(self,event:StreamEvent,root_type:str)->bool:"""Determine whether to include an event."""if(self.include_namesisNoneandself.include_typesisNoneandself.include_tagsisNone):include=Trueelse:include=Falseevent_tags=event.get("tags")or[]ifself.include_namesisnotNone:include=includeorevent["name"]inself.include_namesifself.include_typesisnotNone:include=includeorroot_typeinself.include_typesifself.include_tagsisnotNone:include=includeorany(taginself.include_tagsfortaginevent_tags)ifself.exclude_namesisnotNone:include=includeandevent["name"]notinself.exclude_namesifself.exclude_typesisnotNone:include=includeandroot_typenotinself.exclude_typesifself.exclude_tagsisnotNone:include=includeandall(tagnotinself.exclude_tagsfortaginevent_tags)returninclude
def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check if a function is an async generator.

    Also returns True for objects whose ``__call__`` is an async generator
    function.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., AsyncIterator]]: True if the function is an
        async generator, False otherwise.
    """
    if inspect.isasyncgenfunction(func):
        return True
    call = getattr(func, "__call__", None)  # noqa: B004
    return call is not None and inspect.isasyncgenfunction(call)
def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check if a function is async.

    Also returns True for objects whose ``__call__`` is a coroutine function.

    Args:
        func: The function to check.

    Returns:
        TypeGuard[Callable[..., Awaitable]]: True if the function is async,
        False otherwise.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    call = getattr(func, "__call__", None)  # noqa: B004
    return call is not None and asyncio.iscoroutinefunction(call)