Source code for langchain_google_vertexai.chains

from typing import (
    Dict,
    Optional,
    Sequence,
    Type,
    Union,
)

import google.cloud.aiplatform_v1beta1.types as gapic
from langchain_core.output_parsers import (
    BaseGenerationOutputParser,
    BaseOutputParser,
    StrOutputParser,
)
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable
from pydantic import BaseModel

from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser


def get_output_parser(
    functions: Sequence[Type[BaseModel]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
    """Get the appropriate function output parser given the user functions.

    Args:
        functions: Non-empty sequence of pydantic.BaseModel classes describing
            the functions the model may call. When more than one class is
            given, each is keyed by its ``__name__``.

    Returns:
        A PydanticFunctionsOutputParser. For a single class the parser is built
        directly from that class; for several classes it is built from a
        ``{class.__name__: class}`` mapping.

    Raises:
        ValueError: If ``functions`` is empty.
    """
    if not functions:
        # Fail with a clear message instead of an IndexError on functions[0].
        raise ValueError("Need to pass in at least one function. Received zero.")
    if len(functions) > 1:
        # NOTE(review): classes sharing the same ``__name__`` collapse into a
        # single entry here (the last one wins) -- confirm callers never pass
        # duplicates.
        pydantic_schema: Union[Dict, Type[BaseModel]] = {
            fn.__name__: fn for fn in functions
        }
    else:
        pydantic_schema = functions[0]
    output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
        PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema)
    )
    return output_parser
def _create_structured_runnable_extra_step(
    functions: Sequence[Type[BaseModel]],
    llm: Runnable,
    *,
    prompt: Optional[BasePromptTemplate] = None,
) -> Runnable:
    """Build a chain that makes an extra model call to obtain a function call.

    If ``prompt`` is given, the chain first runs ``prompt | llm`` and converts
    the reply to plain text. That text is substituted as ``{output}`` into a
    fixed "record the entities" parsing prompt. Without ``prompt``, the chain
    expects an ``output`` input variable directly. The parsing prompt is sent
    to ``llm`` bound with ``functions``, and the resulting function call is
    parsed with ``get_output_parser(functions)``.

    Args:
        functions: pydantic.BaseModel classes the model may call.
        llm: Chat model runnable supporting ``bind(functions=...)``.
        prompt: Optional prompt to run through ``llm`` before the parsing step.

    Returns:
        The composed runnable chain.
    """
    bind_kwargs: Dict = {"functions": functions}
    # Only models exposing a truthy ``_is_gemini_advanced`` are forced to call
    # one of the given functions (mode ANY). getattr with a default keeps other
    # runnables -- which lack the attribute -- on the plain functions binding
    # instead of raising AttributeError.
    if getattr(llm, "_is_gemini_advanced", False):
        allowed_names = [
            schema.model_json_schema()["title"]
            if hasattr(schema, "model_json_schema")
            else schema.schema()["title"]
            for schema in functions
        ]
        bind_kwargs["tool_config"] = {
            "function_calling_config": {
                "mode": gapic.FunctionCallingConfig.Mode.ANY,
                "allowed_function_names": allowed_names,
            }
        }
    llm_with_functions = llm.bind(**bind_kwargs)

    parsing_prompt = ChatPromptTemplate.from_template(
        "You are a world class algorithm for recording entities.\nMake calls "
        "to the relevant function to record the entities in the following "
        "input:\n{output}\nTip: Make sure to answer in the correct format."
    )
    output_parser = get_output_parser(functions)

    if prompt is not None:
        # Free-text answer first, then re-ask the model to record it as a call.
        initial_chain = (
            prompt | llm | StrOutputParser() | parsing_prompt | llm_with_functions
        )
    else:
        initial_chain = parsing_prompt | llm_with_functions
    return initial_chain | output_parser
def create_structured_runnable(
    function: Union[Type[BaseModel], Sequence[Type[BaseModel]]],
    llm: Runnable,
    *,
    prompt: Optional[BasePromptTemplate] = None,
    use_extra_step: bool = False,
) -> Runnable:
    """Create a runnable sequence that uses Google Vertex AI function calling.

    Args:
        function: Either a single pydantic.BaseModel class or a sequence of
            pydantic.BaseModel classes. For best results, the classes should
            have descriptions of their parameters.
        llm: Language model to use, assumed to support the Google Vertex
            function-calling API.
        prompt: Optional BasePromptTemplate to run before the model.
        use_extra_step: If True, add an extra model call whose only job is to
            turn the model's (prompted) output into a function call. Without
            ``prompt``, that chain expects an ``output`` input variable.

    Returns:
        A runnable sequence that passes the given functions to the model when
        run and parses the function call it returns.

    Raises:
        ValueError: If ``function`` is empty.

    Example:
        .. code-block:: python

            from typing import Optional

            from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
            from langchain_core.prompts import ChatPromptTemplate
            from pydantic import BaseModel, Field


            class RecordPerson(BaseModel):
                \"\"\"Record some identifying information about a person.\"\"\"

                name: str = Field(..., description="The person's name")
                age: int = Field(..., description="The person's age")
                fav_food: Optional[str] = Field(None, description="The person's favorite food")


            class RecordDog(BaseModel):
                \"\"\"Record some identifying information about a dog.\"\"\"

                name: str = Field(..., description="The dog's name")
                color: str = Field(..., description="The dog's color")
                fav_food: Optional[str] = Field(None, description="The dog's favorite food")


            llm = ChatVertexAI(model_name="gemini-pro")
            prompt = ChatPromptTemplate.from_template(\"\"\"
            You are a world class algorithm for recording entities.
            Make calls to the relevant function to record the entities in the following input: {input}
            Tip: Make sure to answer in the correct format\"\"\"
                                         )
            chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt)
            chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
            # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not function:
        raise ValueError("Need to pass in at least one function. Received zero.")
    # Normalize a single class into a one-element list.
    functions = function if isinstance(function, Sequence) else [function]
    if use_extra_step:
        return _create_structured_runnable_extra_step(
            functions=functions, llm=llm, prompt=prompt
        )
    output_parser = get_output_parser(functions)
    llm_with_functions = llm.bind(functions=functions)
    initial_chain = (
        llm_with_functions if prompt is None else prompt | llm_with_functions
    )
    return initial_chain | output_parser