Source code for langchain_nvidia_ai_endpoints.chat_models
"""Chat Model Components Derived from ChatModel/NVIDIA"""from__future__importannotationsimportbase64importenumimportloggingimportosimportreimporturllib.parseimportwarningsfromtypingimport(Any,Callable,Dict,Iterator,List,Literal,Optional,Sequence,Tuple,Type,Union,)fromlangchain_core.callbacks.managerimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.exceptionsimportOutputParserExceptionfromlangchain_core.language_modelsimportBaseChatModel,LanguageModelInputfromlangchain_core.language_models.chat_modelsimportLangSmithParamsfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,)fromlangchain_core.output_parsersimport(BaseOutputParser,JsonOutputParser,PydanticOutputParser,)fromlangchain_core.outputsimport(ChatGeneration,ChatGenerationChunk,ChatResult,Generation,)fromlangchain_core.runnablesimportRunnablefromlangchain_core.toolsimportBaseToolfromlangchain_core.utils.function_callingimportconvert_to_openai_toolfromlangchain_core.utils.pydanticimportis_basemodel_subclassfrompydanticimportBaseModel,Field,PrivateAttrfromlangchain_nvidia_ai_endpoints._commonimport_NVIDIAClientfromlangchain_nvidia_ai_endpoints._staticsimportModelfromlangchain_nvidia_ai_endpoints._utilsimportconvert_message_to_dict_CallbackManager=Union[AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun]logger=logging.getLogger(__name__)def_is_url(s:str)->bool:try:result=urllib.parse.urlparse(s)returnall([result.scheme,result.netloc])exceptExceptionase:logger.debug(f"Unable to parse URL: {e}")returnFalsedef_url_to_b64_string(image_source:str)->str:try:if_is_url(image_source):returnimage_source# import sys# import io# try:# import PIL.Image# has_pillow = True# except ImportError:# has_pillow = False# def _resize_image(img_data: bytes, max_dim: int = 1024) -> str:# if not has_pillow:# print( # noqa: T201# "Pillow is required to resize images down to reasonable scale." # noqa: E501# " Please install it using `pip install pillow`."# " For now, not resizing; may cause NVIDIA API to fail."# )# return base64.b64encode(img_data).decode("utf-8")# image = PIL.Image.open(io.BytesIO(img_data))# max_dim_size = max(image.size)# aspect_ratio = max_dim / max_dim_size# new_h = int(image.size[1] * aspect_ratio)# new_w = int(image.size[0] * aspect_ratio)# resized_image = image.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS) # noqa: E501# output_buffer = io.BytesIO()# resized_image.save(output_buffer, format="JPEG")# output_buffer.seek(0)# resized_b64_string = base64.b64encode(output_buffer.read()).decode("utf-8") # noqa: E501# return resized_b64_string# b64_template = "data:image/png;base64,{b64_string}"# response = requests.get(# image_source, headers={"User-Agent": "langchain-nvidia-ai-endpoints"}# )# response.raise_for_status()# encoded = base64.b64encode(response.content).decode("utf-8")# if sys.getsizeof(encoded) > 200000:# ## (VK) Temporary fix. 
NVIDIA API has a limit of 250KB for the input.# encoded = _resize_image(response.content)# return b64_template.format(b64_string=encoded)elifimage_source.startswith("data:image"):returnimage_sourceelifos.path.exists(image_source):withopen(image_source,"rb")asf:image_data=f.read()importfiletype# type: ignorekind=filetype.guess(image_data)image_type=kind.extensionifkindelse"unknown"encoded=base64.b64encode(image_data).decode("utf-8")returnf"data:image/{image_type};base64,{encoded}"else:raiseValueError("The provided string is not a valid URL, base64, or file path.")exceptExceptionase:raiseValueError(f"Unable to process the provided image source: {e}")def_nv_vlm_adjust_input(message_dict:Dict[str,Any],model_type:str)->Dict[str,Any]:""" The NVIDIA VLM API input message.content: { "role": "user", "content": [ ..., { "type": "image_url", "image_url": "{data}" }, ... ] } where OpenAI VLM API input message.content: { "role": "user", "content": [ ..., { "type": "image_url", "image_url": { "url": "{url | data}" } }, ... ] } This function converts the OpenAI VLM API input message to NVIDIA VLM API input message, in place. In the process, it accepts a url or file and converts them to data urls. """ifcontent:=message_dict.get("content"):ifisinstance(content,list):forpartincontent:ifisinstance(part,dict)and"image_url"inpart:if(isinstance(part["image_url"],dict)and"url"inpart["image_url"]):url=_url_to_b64_string(part["image_url"]["url"])ifmodel_type=="nv-vlm":part["image_url"]=urlelse:part["image_url"]["url"]=urlreturnmessage_dictdef_nv_vlm_get_asset_ids(content:Union[str,List[Union[str,Dict[str,Any]]]],)->List[str]:""" VLM APIs accept asset IDs as input in two forms: - content = [{"image_url": {"url": "data:image/{type};asset_id,{asset_id}"}}*] - content = .*<img src="data:image/{type};asset_id,{asset_id}"/>.* This function extracts asset IDs from the message content. """defextract_asset_id(data:str)->List[str]:pattern=re.compile(r'data:image/[^;]+;asset_id,([^"\'\s]+)')returnpattern.findall(data)asset_ids=[]ifisinstance(content,str):asset_ids.extend(extract_asset_id(content))elifisinstance(content,list):forpartincontent:ifisinstance(part,str):asset_ids.extend(extract_asset_id(part))elifisinstance(part,dict)and"image_url"inpart:image_url=part["image_url"]ifisinstance(image_url,dict)and"url"inimage_url:asset_ids.extend(extract_asset_id(image_url["url"]))returnasset_idsdef_process_for_vlm(inputs:List[Dict[str,Any]],model:Optional[Model],# not optional, Optional for type alignment)->Tuple[List[Dict[str,Any]],Dict[str,str]]:""" Process inputs for NVIDIA VLM models. This function processes the input messages for NVIDIA VLM models. It extracts asset IDs from the input messages and adds them to the headers for the NVIDIA VLM API. """ifnotmodelornotmodel.model_type:returninputs,{}extra_headers={}if"vlm"inmodel.model_type:asset_ids=[]forinputininputs:if"content"ininput:asset_ids.extend(_nv_vlm_get_asset_ids(input["content"]))ifasset_ids:extra_headers["NVCF-INPUT-ASSET-REFERENCES"]=",".join(asset_ids)inputs=[_nv_vlm_adjust_input(message,model.model_type)formessageininputs]returninputs,extra_headers_DEFAULT_MODEL_NAME:str="meta/llama3-8b-instruct"
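# A minimal sketch (kept commented out so importing this module stays side-effect
# free) of what `_nv_vlm_adjust_input` above does to an OpenAI-style multimodal
# message when the target model is of type "nv-vlm". The message and the data URL
# are illustrative, not taken from a real request.
#
#     message = {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this image."},
#             {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
#         ],
#     }
#     adjusted = _nv_vlm_adjust_input(message, model_type="nv-vlm")
#     # For "nv-vlm" models the nested {"url": ...} dict is flattened in place:
#     #   adjusted["content"][1] == {"type": "image_url",
#     #                              "image_url": "data:image/png;base64,iVBORw0..."}
#     # For other VLM model types the OpenAI shape is kept and only the url value is
#     # normalized (remote URLs pass through, local file paths become data URLs).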
[docs]
class ChatNVIDIA(BaseChatModel):
    """NVIDIA chat model.

    Example:
        .. code-block:: python

            from langchain_nvidia_ai_endpoints import ChatNVIDIA

            model = ChatNVIDIA(model="meta/llama2-70b")
            response = model.invoke("Hello")
    """

    _client: _NVIDIAClient = PrivateAttr()
    base_url: Optional[str] = Field(
        default=None,
        description="Base url for model listing and invocation",
    )
    model: Optional[str] = Field(None, description="Name of the model to invoke")
    temperature: Optional[float] = Field(
        None, description="Sampling temperature in [0, 1]"
    )
    max_tokens: Optional[int] = Field(
        1024, description="Maximum # of tokens to generate"
    )
    top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
    seed: Optional[int] = Field(None, description="The seed for deterministic results")
    stop: Optional[Sequence[str]] = Field(None, description="Stop words (cased)")

    def __init__(self, **kwargs: Any):
        """
        Create a new ChatNVIDIA chat model.

        This class provides access to a NVIDIA NIM for chat. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for chat.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
                            Format for base URL is http://host:port
            temperature (float): Sampling temperature in [0, 1].
            max_tokens (int): Maximum number of tokens to generate.
            top_p (float): Top-p for distribution sampling.
            seed (int): A seed for deterministic results.
            stop (list[str]): A list of cased stop words.

        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
          environment variable.

        Base URL:
        - Connect to a self-hosted model with NVIDIA NIM using the `base_url` arg to
          link to the local host at localhost:8000:

          llm = ChatNVIDIA(
              base_url="http://localhost:8000/v1",
              model="meta-llama3-8b-instruct"
          )
        """
        super().__init__(**kwargs)
        # allow nvidia_base_url as an alternative for base_url
        base_url = kwargs.pop("nvidia_base_url", self.base_url)
        # allow nvidia_api_key as an alternative for api_key
        api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))
        self._client = _NVIDIAClient(
            **({"base_url": base_url} if base_url else {}),  # only pass if set
            mdl_name=self.model,
            default_hosted_model_name=_DEFAULT_MODEL_NAME,
            **({"api_key": api_key} if api_key else {}),  # only pass if set
            infer_path="{base_url}/chat/completions",
            # instead of self.__class__.__name__ to assist in subclassing ChatNVIDIA
            cls="ChatNVIDIA",
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.mdl_name
        # same for base_url
        self.base_url = self._client.base_url

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with ChatNVIDIA.
        """
        return self._client.get_available_models(self.__class__.__name__)
[docs]
    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with ChatNVIDIA.
        """
        return cls(**kwargs).available_models
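    # A hedged usage sketch, kept as a comment: listing models requires a reachable
    # endpoint (hosted NIM with NVIDIA_API_KEY set, or a local NIM via base_url),
    # and the model ids returned depend entirely on that endpoint.
    #
    #     from langchain_nvidia_ai_endpoints import ChatNVIDIA
    #
    #     # via the class, which constructs a throwaway instance under the hood
    #     models = ChatNVIDIA.get_available_models()
    #
    #     # or via an existing instance
    #     llm = ChatNVIDIA(model="meta/llama3-8b-instruct")
    #     models = llm.available_models
    #     print([m.id for m in models])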
    @property
    def _llm_type(self) -> str:
        """Return type of NVIDIA AI Foundation Model Interface."""
        return "chat-nvidia-ai-playground"

    def _get_ls_params(
        self,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LangSmithParams:
        """Get standard LangSmith parameters for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        return LangSmithParams(
            ls_provider="NVIDIA",
            # error: Incompatible types (expression has type "Optional[str]",
            #        TypedDict item "ls_model_name" has type "str") [typeddict-item]
            ls_model_name=self.model or "UNKNOWN",
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
            ls_max_tokens=params.get("max_tokens", self.max_tokens),
            # mypy error: Extra keys ("ls_top_p", "ls_seed")
            #             for TypedDict "LangSmithParams" [typeddict-item]
            # ls_top_p=params.get("top_p", self.top_p),
            # ls_seed=params.get("seed", self.seed),
            ls_stop=params.get("stop", self.stop),
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        inputs = [
            message
            for message in [convert_message_to_dict(message) for message in messages]
        ]
        inputs, extra_headers = _process_for_vlm(inputs, self._client.model)
        payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
        response = self._client.get_req(payload=payload, extra_headers=extra_headers)
        responses, _ = self._client.postprocess(response)
        self._set_callback_out(responses, run_manager)
        parsed_response = self._custom_postprocess(responses, streaming=False)
        # for pre 0.2 compatibility w/ ChatMessage
        # ChatMessage had a role property that was not present in AIMessage
        parsed_response.update({"role": "assistant"})
        generation = ChatGeneration(message=AIMessage(**parsed_response))
        return ChatResult(generations=[generation], llm_output=responses)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Allows streaming to model!"""
        inputs = [
            message
            for message in [convert_message_to_dict(message) for message in messages]
        ]
        inputs, extra_headers = _process_for_vlm(inputs, self._client.model)
        payload = self._get_payload(
            inputs=inputs,
            stop=stop,
            stream=True,
            stream_options={"include_usage": True},
            **kwargs,
        )
        # todo: get vlm endpoints fixed and remove this
        #       vlm endpoints do not accept standard stream_options parameter
        if (
            self._client.model
            and self._client.model.model_type
            and self._client.model.model_type == "nv-vlm"
        ):
            payload.pop("stream_options")
        for response in self._client.get_req_stream(
            payload=payload, extra_headers=extra_headers
        ):
            self._set_callback_out(response, run_manager)
            parsed_response = self._custom_postprocess(response, streaming=True)
            # for pre 0.2 compatibility w/ ChatMessageChunk
            # ChatMessageChunk had a role property that was not
            # present in AIMessageChunk
            # unfortunately, AIMessageChunk does not have an extensible property
            # parsed_response.update({"role": "assistant"})
            message = AIMessageChunk(**parsed_response)
            chunk = ChatGenerationChunk(message=message)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _set_callback_out(
        self,
        result: dict,
        run_manager: Optional[_CallbackManager],
    ) -> None:
        result.update({"model_name": self.model})
        if run_manager:
            for cb in run_manager.handlers:
                if hasattr(cb, "llm_output"):
                    cb.llm_output = result

    def _custom_postprocess(
        self, msg: dict, streaming: bool = False
    ) -> dict:  # todo: remove
        kw_left = msg.copy()
        out_dict = {
            "role": kw_left.pop("role", "assistant") or "assistant",
            "name": kw_left.pop("name", None),
            "id": kw_left.pop("id", None),
            "content": kw_left.pop("content", "") or "",
            "additional_kwargs": {},
            "response_metadata": {},
        }
        if token_usage := kw_left.pop("token_usage", None):
            out_dict["usage_metadata"] = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
        # "tool_calls" is set for invoke and stream responses
        if tool_calls := kw_left.pop("tool_calls", None):
            assert isinstance(
                tool_calls, list
            ), "invalid response from server: tool_calls must be a list"
            # todo: break this into post-processing for invoke and stream
            if not streaming:
                out_dict["additional_kwargs"]["tool_calls"] = tool_calls
            elif streaming:
                out_dict["tool_call_chunks"] = []
                for tool_call in tool_calls:
                    # todo: the nim api does not return the function index
                    #       for tool calls in stream responses. this is
                    #       an issue that needs to be resolved server-side.
                    #       the only reason we can skip this for now
                    #       is because the nim endpoint returns only full
                    #       tool calls, no deltas.
                    # assert "index" in tool_call, (
                    #     "invalid response from server: "
                    #     "tool_call must have an 'index' key"
                    # )
                    assert "function" in tool_call, (
                        "invalid response from server: "
                        "tool_call must have a 'function' key"
                    )
                    out_dict["tool_call_chunks"].append(
                        {
                            "index": tool_call.get("index", None),
                            "id": tool_call.get("id", None),
                            "name": tool_call["function"].get("name", None),
                            "args": tool_call["function"].get("arguments", None),
                        }
                    )
        # we only create the response_metadata from the last message in a stream.
        # if we do it for all messages, we'll end up with things like
        # "model_name" = "model-xyz" * # messages.
        if "finish_reason" in kw_left:
            out_dict["response_metadata"] = kw_left
        return out_dict

    ####################################################################################
    # Core client-side interfaces

    def _get_payload(
        self, inputs: Sequence[Dict], **kwargs: Any
    ) -> dict:  # todo: remove
        """Generates payload for the _NVIDIAClient API to send to service."""
        messages: List[Dict[str, Any]] = []
        for msg in inputs:
            if isinstance(msg, str):
                # (WFH) this shouldn't ever be reached but leaving this here bcs
                # it's a Chesterton's fence I'm unwilling to touch
                messages.append(dict(role="user", content=msg))
            elif isinstance(msg, dict):
                if msg.get("content", None) is None:
                    # content=None is valid for assistant messages (tool calling)
                    if not msg.get("role") == "assistant":
                        raise ValueError(f"Message {msg} has no content.")
                messages.append(msg)
            else:
                raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")

        # special handling for "stop" because it always comes in kwargs.
        # if user provided "stop" to invoke/stream, it will be non-None in kwargs.
        # note: we cannot tell if the user specified stop=None to invoke/stream
        #       because the default value of stop is None.
        # todo: remove self.stop
        assert "stop" in kwargs, '"stop" param is expected in kwargs'
        if kwargs["stop"] is None:
            kwargs.pop("stop")

        # setup default payload values
        payload: Dict[str, Any] = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "seed": self.seed,
            "stop": self.stop,
        }

        # merge incoming kwargs with attr_kwargs giving preference to
        # the incoming kwargs
        payload.update(kwargs)

        # remove keys with None values from payload
        payload = {k: v for k, v in payload.items() if v is not None}

        return {"messages": messages, **payload}
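    # A hedged illustration of the payload shape `_get_payload` produces, assuming
    # an instance constructed as ChatNVIDIA(model="meta/llama3-8b-instruct",
    # temperature=0.5). Keys whose values are None (top_p, seed, stop here) are
    # dropped before the request is sent; "stream" arrives via kwargs from
    # _generate/_stream. Shown as a comment because it exercises a private method.
    #
    #     payload = llm._get_payload(
    #         inputs=[{"role": "user", "content": "Hello"}],
    #         stop=None,
    #         stream=False,
    #     )
    #     # payload == {
    #     #     "messages": [{"role": "user", "content": "Hello"}],
    #     #     "model": "meta/llama3-8b-instruct",
    #     #     "temperature": 0.5,
    #     #     "max_tokens": 1024,
    #     #     "stream": False,
    #     # }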
[docs]
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "none", "any", "required"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """
        Bind tools to the model.

        Notes:
        - The `strict` mode is always in effect; if you need it disabled,
          please file an issue.

        Args:
            tools (list): A list of tools to bind to the model.
            tool_choice (Optional[Union[dict, str,
                                        Literal["auto", "none", "any", "required"],
                                        bool]]):
               Control tool choice.
                 "any" and "required" - force a tool call.
                 "auto" - let the model decide.
                 "none" - force no tool call.
                 string or dict - force a specific tool call.
                 bool - if True, force a tool call; if False, force no tool call.
               Defaults to passing no value.
            **kwargs: Additional keyword arguments.

        see https://python.langchain.com/v0.1/docs/modules/model_io/chat/function_calling/#request-forcing-a-tool-call
        """
        # check if the model supports tools, warn if it does not
        if self._client.model and not self._client.model.supports_tools:
            warnings.warn(
                f"Model '{self.model}' is not known to support tools. "
                "Your tool binding may fail at inference time."
            )
        if kwargs.get("strict", True) is not True:
            warnings.warn("The `strict` parameter is not necessary and is ignored.")
        tool_name = None
        if isinstance(tool_choice, bool):
            tool_choice = "required" if tool_choice else "none"
        elif isinstance(tool_choice, str):
            # LangChain documents "any" as an option, server API uses "required"
            if tool_choice == "any":
                tool_choice = "required"
            # if a string that's not "auto", "none", or "required"
            # then it must be a tool name
            if tool_choice not in ["auto", "none", "required"]:
                tool_name = tool_choice
                tool_choice = {
                    "type": "function",
                    "function": {"name": tool_choice},
                }
        elif isinstance(tool_choice, dict):
            # if a dict, it must be a tool choice dict, e.g.
            #  {"type": "function", "function": {"name": "my_tool"}}
            if "type" not in tool_choice:
                tool_choice["type"] = "function"
            if "function" not in tool_choice:
                raise ValueError("Tool choice dict must have a 'function' key")
            if "name" not in tool_choice["function"]:
                raise ValueError("Tool choice function dict must have a 'name' key")
            tool_name = tool_choice["function"]["name"]

        # check that the specified tool is in the tools list
        tool_dicts = [convert_to_openai_tool(tool) for tool in tools]
        if tool_name:
            if not any(tool["function"]["name"] == tool_name for tool in tool_dicts):
                raise ValueError(
                    f"Tool choice '{tool_name}' not found in the tools list"
                )

        return super().bind(
            tools=tool_dicts,
            tool_choice=tool_choice,
            **kwargs,
        )
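    # A hedged usage sketch for `bind_tools`, kept as a comment. It assumes a
    # tools-capable model id (e.g. "meta/llama-3.1-8b-instruct") and a valid
    # NVIDIA_API_KEY; the weather tool below is purely illustrative.
    #
    #     from langchain_core.tools import tool
    #
    #     @tool
    #     def get_weather(city: str) -> str:
    #         """Look up the current weather for a city."""
    #         return f"It is sunny in {city}."
    #
    #     llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")
    #     llm_with_tools = llm.bind_tools([get_weather], tool_choice="required")
    #     ai_msg = llm_with_tools.invoke("What's the weather in Santa Clara?")
    #     # any tool calls the model makes are surfaced on ai_msg.tool_calls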
[docs]
    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
        function_call: Optional[str] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        raise NotImplementedError("Not implemented, use `bind_tools` instead.")
    # we have an Enum extension to BaseChatModel.with_structured_output and
    # as a result need to type ignore for the schema parameter and return type.
[docs]
    def with_structured_output(  # type: ignore
        self,
        schema: Union[Dict, Type],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """
        Bind a structured output schema to the model.

        Args:
            schema (Union[Dict, Type]): The schema to bind to the model.
            include_raw (bool): Always False. Passing True raises an error.
            **kwargs: Additional keyword arguments.

        Notes:
        - `strict` mode is always in effect; if you need it disabled,
          please file an issue.
        - if you need `include_raw=True`, consider using an unstructured model and
          output formatter, or file an issue.

        The schema can be -
         0. a dictionary representing a JSON schema
         1. a Pydantic object
         2. an Enum

        0. If a dictionary is provided, the model will return a dictionary. Example:
        ```
        json_schema = {
            "title": "joke",
            "description": "Joke to tell user.",
            "type": "object",
            "properties": {
                "setup": {
                    "type": "string",
                    "description": "The setup of the joke",
                },
                "punchline": {
                    "type": "string",
                    "description": "The punchline to the joke",
                },
            },
            "required": ["setup", "punchline"],
        }

        structured_llm = llm.with_structured_output(json_schema)
        structured_llm.invoke("Tell me a joke about NVIDIA")
        # Output: {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.',
        #          'punchline': 'It took a big bite out of their main board.'}
        ```

        1. If a Pydantic schema is provided, the model will return a Pydantic object.
           Example:
        ```
        from pydantic import BaseModel, Field
        class Joke(BaseModel):
            setup: str = Field(description="The setup of the joke")
            punchline: str = Field(description="The punchline to the joke")

        structured_llm = llm.with_structured_output(Joke)
        structured_llm.invoke("Tell me a joke about NVIDIA")
        # Output: Joke(setup='Why did NVIDIA go broke? The hardware ate all the software.',
        #              punchline='It took a big bite out of their main board.')
        ```

        2. If an Enum is provided, all values must be strings, and the model will
           return an Enum object. Example:
        ```
        import enum
        class Choices(enum.Enum):
            A = "A"
            B = "B"
            C = "C"

        structured_llm = llm.with_structured_output(Choices)
        structured_llm.invoke("What is the first letter in this list? [X, Y, Z, C]")
        # Output: <Choices.C: 'C'>
        ```

        Note about streaming: Unlike other streaming responses, the streamed chunks
        will be increasingly complete. They will not be deltas. The last chunk will
        contain the complete response.

        For instance with a dictionary schema, the chunks will be:
        ```
        structured_llm = llm.with_structured_output(json_schema)
        for chunk in structured_llm.stream("Tell me a joke about NVIDIA"):
            print(chunk)

        # Output:
        # {}
        # {'setup': ''}
        # {'setup': 'Why'}
        # {'setup': 'Why did'}
        # {'setup': 'Why did N'}
        # {'setup': 'Why did NVID'}
        # ...
        # {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board'}
        # {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board.'}
        ```

        For instance with a Pydantic schema, the chunks will be:
        ```
        structured_llm = llm.with_structured_output(Joke)
        for chunk in structured_llm.stream("Tell me a joke about NVIDIA"):
            print(chunk)

        # Output:
        # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline=''
        # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It'
        # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took'
        # ...
        # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board'
        # setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board.'
        ```

        For Pydantic schema and Enum, the output will be None if the response is
        insufficient to construct the object or otherwise invalid. For instance,
        ```
        llm = ChatNVIDIA(max_tokens=1)
        structured_llm = llm.with_structured_output(Joke)
        print(structured_llm.invoke("Tell me a joke about NVIDIA"))

        # Output: None
        ```

        For more, see https://python.langchain.com/docs/how_to/structured_output/
        """  # noqa: E501
        if "method" in kwargs:
            warnings.warn(
                "The 'method' parameter is unnecessary and is ignored. "
                "The appropriate method will be chosen automatically depending "
                "on the type of schema provided."
            )

        if kwargs.get("strict", True) is not True:
            warnings.warn(
                "Structured output always follows strict validation. "
                "`strict` is ignored. Please file an issue if you "
                "need strict validation disabled."
            )

        if include_raw:
            raise NotImplementedError(
                "include_raw=True is not implemented, consider "
                "https://python.langchain.com/docs/how_to/"
                "structured_output/#prompting-and-parsing-model"
                "-outputs-directly or rely on the structured response "
                "being None when the LLM produces an incomplete response."
            )

        # check if the model supports structured output, warn if it does not
        known_good = False
        # todo: we need to store model: Model in this class
        #       instead of model: str (= Model.id)
        #       this should be: if not self.model.supports_tools: warnings.warn...
        candidates = [
            model for model in self.available_models if model.id == self.model
        ]
        if not candidates:  # user must have specified the model themselves
            known_good = False
        else:
            assert len(candidates) == 1, "Multiple models with the same id"
            known_good = candidates[0].supports_structured_output is True
        if not known_good:
            warnings.warn(
                f"Model '{self.model}' is not known to support structured output. "
                "Your output may fail at inference time."
            )

        if isinstance(schema, dict):
            output_parser: BaseOutputParser = JsonOutputParser()
            nvext_param: Dict[str, Any] = {"guided_json": schema}

        elif issubclass(schema, enum.Enum):
            # langchain's EnumOutputParser is not in langchain_core
            # and doesn't support streaming. this is a simple implementation
            # that supports streaming with our semantics of returning None
            # if no complete object can be constructed.
            class EnumOutputParser(BaseOutputParser):
                enum: Type[enum.Enum]

                def parse(self, response: str) -> Any:
                    try:
                        return self.enum(response.strip())
                    except ValueError:
                        pass
                    return None

            # guided_choice only supports string choices
            choices = [choice.value for choice in schema]
            if not all(isinstance(choice, str) for choice in choices):
                # instead of erroring out we could coerce the enum values to
                # strings, but would then need to coerce them back to their
                # original type for Enum construction.
                raise ValueError(
                    "Enum schema must only contain string choices. "
                    "Use StrEnum or ensure all member values are strings."
                )
            output_parser = EnumOutputParser(enum=schema)
            nvext_param = {"guided_choice": choices}

        elif is_basemodel_subclass(schema):
            # PydanticOutputParser does not support streaming. what we do
            # instead is ignore all inputs that are incomplete wrt the
            # underlying Pydantic schema. if the entire input is invalid,
            # we return None.
            class ForgivingPydanticOutputParser(PydanticOutputParser):
                def parse_result(
                    self, result: List[Generation], *, partial: bool = False
                ) -> Any:
                    try:
                        return super().parse_result(result, partial=partial)
                    except OutputParserException:
                        pass
                    return None

            output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
            if hasattr(schema, "model_json_schema"):
                json_schema = schema.model_json_schema()
            else:
                json_schema = schema.schema()
            nvext_param = {"guided_json": json_schema}

        else:
            raise ValueError(
                "Schema must be a Pydantic object, a dictionary "
                "representing a JSON schema, or an Enum."
            )

        return super().bind(nvext=nvext_param) | output_parser
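# A hedged sketch of what the Pydantic branch of `with_structured_output` reduces
# to: guided JSON decoding via the `nvext` request parameter plus a forgiving
# parser. Kept as a comment; the manual equivalent shown is illustrative only,
# prefer calling with_structured_output directly.
#
#     from pydantic import BaseModel, Field
#
#     class Joke(BaseModel):
#         setup: str = Field(description="The setup of the joke")
#         punchline: str = Field(description="The punchline to the joke")
#
#     llm = ChatNVIDIA(model="meta/llama3-8b-instruct")
#     structured_llm = llm.with_structured_output(Joke)
#     # roughly equivalent to:
#     #   llm.bind(nvext={"guided_json": Joke.model_json_schema()}) | <forgiving Pydantic parser>
#     print(structured_llm.invoke("Tell me a joke about NVIDIA"))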