Source code for langchain_google_genai.llms

from __future__ import annotations

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LangSmithParams
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import HumanMessage
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from pydantic import ConfigDict, model_validator
from typing_extensions import Self

from langchain_google_genai._common import (
    _BaseGoogleGenerativeAI,
)
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI


class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI

            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any = None  #: :meta private:

    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validates params and passes them to google-generativeai package."""
        self.client = ChatGoogleGenerativeAI(
            api_key=self.google_api_key,
            credentials=self.credentials,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_tokens=self.max_output_tokens,
            timeout=self.timeout,
            model=self.model,
            client_options=self.client_options,
            transport=self.transport,
            additional_headers=self.additional_headers,
            safety_settings=self.safety_settings,
        )
        return self

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "google_genai"
        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        return ls_params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations = []
        for prompt in prompts:
            chat_result = self.client._generate(
                [HumanMessage(content=prompt)],
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )
            generations.append(
                [
                    Generation(
                        text=g.message.content,
                        generation_info={
                            **g.generation_info,
                            **{"usage_metadata": g.message.usage_metadata},
                        },
                    )
                    for g in chat_result.generations
                ]
            )
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for stream_chunk in self.client._stream(
            [HumanMessage(content=prompt)],
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_chunk.message.content)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_gemini"
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        return self.client.get_num_tokens(text)
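

# A minimal usage sketch (not part of the module above), illustrating the
# public surface this class exposes through BaseLLM. It assumes the
# langchain-google-genai package is installed and that an API key is
# available via the GOOGLE_API_KEY environment variable; the model name is
# illustrative only.
if __name__ == "__main__":
    llm = GoogleGenerativeAI(model="gemini-pro")

    # Single-shot generation: `invoke` is inherited from BaseLLM and is
    # routed through `_generate`, which wraps the prompt in a HumanMessage
    # and delegates to the internal ChatGoogleGenerativeAI client.
    print(llm.invoke("Explain the difference between a list and a tuple."))

    # Token-by-token streaming: `stream` yields strings produced by the
    # `_stream` generator defined above.
    for chunk in llm.stream("Write a haiku about type checkers."):
        print(chunk, end="", flush=True)

    # Context-window budgeting, delegated to the underlying chat client.
    print(llm.get_num_tokens("How many tokens is this sentence?"))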