from __future__ import annotations

import logging
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import pre_init
from langchain_core.utils.pydantic import get_fields
from pydantic import ConfigDict

if TYPE_CHECKING:
    import gigachat
    import gigachat.models as gm

logger = logging.getLogger(__name__)


class _BaseGigaChat(Serializable):
    base_url: Optional[str] = None
    """ Base API URL """
    auth_url: Optional[str] = None
    """ Auth URL """
    credentials: Optional[str] = None
    """ Auth token """
    scope: Optional[str] = None
    """ Permission scope for the access token """

    access_token: Optional[str] = None
    """ Access token for GigaChat """
    model: Optional[str] = None
    """Model name to use."""
    user: Optional[str] = None
    """ Username for authentication """
    password: Optional[str] = None
    """ Password for authentication """

    timeout: Optional[float] = None
    """ Timeout for requests """
    verify_ssl_certs: Optional[bool] = None
    """ Check certificates for all requests """

    ca_bundle_file: Optional[str] = None
    cert_file: Optional[str] = None
    key_file: Optional[str] = None
    key_file_password: Optional[str] = None
    # Support for connecting to GigaChat through SSL certificates

    profanity: bool = True
    """ DEPRECATED: Check for profanity """
    profanity_check: Optional[bool] = None
    """ Check for profanity """
    streaming: bool = False
    """ Whether to stream the results or not. """
    temperature: Optional[float] = None
    """ What sampling temperature to use. """
    max_tokens: Optional[int] = None
    """ Maximum number of tokens to generate """
    use_api_for_tokens: bool = False
    """ Use the GigaChat API for token counting """
    verbose: bool = False
    """ Verbose logging """
    top_p: Optional[float] = None
    """ top_p value to use for nucleus sampling. Must be between 0.0 and 1.0 """
    repetition_penalty: Optional[float] = None
    """ The penalty applied to repeated tokens """
    update_interval: Optional[float] = None
    """ Minimum interval in seconds between sending tokens """

    @property
    def _llm_type(self) -> str:
        return "giga-chat-model"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {
            "credentials": "GIGACHAT_CREDENTIALS",
            "access_token": "GIGACHAT_ACCESS_TOKEN",
            "password": "GIGACHAT_PASSWORD",
            "key_file_password": "GIGACHAT_KEY_FILE_PASSWORD",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    @cached_property
    def _client(self) -> gigachat.GigaChat:
        """Returns the GigaChat API client."""
        import gigachat

        return gigachat.GigaChat(
            base_url=self.base_url,
            auth_url=self.auth_url,
            credentials=self.credentials,
            scope=self.scope,
            access_token=self.access_token,
            model=self.model,
            profanity_check=self.profanity_check,
            user=self.user,
            password=self.password,
            timeout=self.timeout,
            verify_ssl_certs=self.verify_ssl_certs,
            ca_bundle_file=self.ca_bundle_file,
            cert_file=self.cert_file,
            key_file=self.key_file,
            key_file_password=self.key_file_password,
            verbose=self.verbose,
        )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that authentication data is provided and the gigachat
        python package is installed."""
        try:
            import gigachat  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gigachat python package. "
                "Please install it with `pip install gigachat`."
            )
        fields = set(get_fields(cls).keys())
        diff = set(values.keys()) - fields
        if diff:
            logger.warning(f"Extra fields {diff} in GigaChat class")
        if "profanity" in fields and values.get("profanity") is False:
            logger.warning(
                "'profanity' field is deprecated. Use 'profanity_check' instead."
            )
            if values.get("profanity_check") is None:
                values["profanity_check"] = values.get("profanity")
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "model": self.model,
            "profanity": self.profanity_check,
            "streaming": self.streaming,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
        }

    def tokens_count(
        self, input_: List[str], model: Optional[str] = None
    ) -> List[gm.TokensCount]:
        """Count tokens for a list of strings."""
        return self._client.tokens_count(input_, model)

    async def atokens_count(
        self, input_: List[str], model: Optional[str] = None
    ) -> List[gm.TokensCount]:
        """Count tokens for a list of strings (async)."""
        return await self._client.atokens_count(input_, model)

    def get_models(self) -> gm.Models:
        """Get the available GigaChat models."""
        return self._client.get_models()

    async def aget_models(self) -> gm.Models:
        """Get the available GigaChat models (async)."""
        return await self._client.aget_models()

    def get_model(self, model: str) -> gm.Model:
        """Get info about a model."""
        return self._client.get_model(model)

    async def aget_model(self, model: str) -> gm.Model:
        """Get info about a model (async)."""
        return await self._client.aget_model(model)

    def get_num_tokens(self, text: str) -> int:
        """Count the approximate number of tokens."""
        if self.use_api_for_tokens:
            return self.tokens_count([text])[0].tokens  # type: ignore
        else:
            # Local heuristic: roughly one token per 4.6 characters.
            return round(len(text) / 4.6)
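
# A minimal sketch of the two token-counting paths above, assuming a configured
# instance (the credentials below are placeholders, not real values):
#
#     giga = GigaChat(credentials="<auth-token>")
#
#     # Default: local heuristic, roughly one token per 4.6 characters.
#     giga.get_num_tokens("Hello, world!")  # round(13 / 4.6) == 3
#
#     # With use_api_for_tokens=True the count comes from the GigaChat
#     # tokens_count endpoint and reflects the actual model tokenizer.
#     giga_api = GigaChat(credentials="<auth-token>", use_api_for_tokens=True)
#     giga_api.get_num_tokens("Hello, world!")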
class GigaChat(_BaseGigaChat, BaseLLM):
    """`GigaChat` large language models API.

    To use, pass an authorization token (``credentials``), or a login and
    password, for access to the GigaChat API.

    Example:
        .. code-block:: python

            from langchain_community.llms import GigaChat

            giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=False)
    """

    payload_role: str = "user"

    def _build_payload(self, messages: List[str]) -> Dict[str, Any]:
        payload: Dict[str, Any] = {
            "messages": [{"role": self.payload_role, "content": m} for m in messages],
        }
        if self.model:
            payload["model"] = self.model
        if self.profanity_check is not None:
            payload["profanity_check"] = self.profanity_check
        if self.temperature is not None:
            payload["temperature"] = self.temperature
        if self.top_p is not None:
            payload["top_p"] = self.top_p
        if self.max_tokens is not None:
            payload["max_tokens"] = self.max_tokens
        if self.repetition_penalty is not None:
            payload["repetition_penalty"] = self.repetition_penalty
        if self.update_interval is not None:
            payload["update_interval"] = self.update_interval

        if self.verbose:
            logger.info("Giga request: %s", payload)

        return payload

    def _create_llm_result(self, response: Any) -> LLMResult:
        generations = []
        for res in response.choices:
            finish_reason = res.finish_reason
            gen = Generation(
                text=res.message.content,
                generation_info={"finish_reason": finish_reason},
            )
            generations.append([gen])
            if finish_reason != "stop":
                logger.warning(
                    "Giga generation stopped with reason: %s",
                    finish_reason,
                )
            if self.verbose:
                logger.info("Giga response: %s", res.message.content)
        token_usage = response.usage
        llm_output = {"token_usage": token_usage, "model_name": response.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            generation: Optional[GenerationChunk] = None
            stream_iter = self._stream(
                prompts[0], stop=stop, run_manager=run_manager, **kwargs
            )
            for chunk in stream_iter:
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return LLMResult(generations=[[generation]])

        payload = self._build_payload(prompts)
        response = self._client.chat(payload)

        return self._create_llm_result(response)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            generation: Optional[GenerationChunk] = None
            stream_iter = self._astream(
                prompts[0], stop=stop, run_manager=run_manager, **kwargs
            )
            async for chunk in stream_iter:
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return LLMResult(generations=[[generation]])

        payload = self._build_payload(prompts)
        response = await self._client.achat(payload)

        return self._create_llm_result(response)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        payload = self._build_payload([prompt])

        for chunk in self._client.stream(payload):
            if chunk.choices:
                content = chunk.choices[0].delta.content
                if run_manager:
                    run_manager.on_llm_new_token(content)
                yield GenerationChunk(text=content)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        payload = self._build_payload([prompt])

        async for chunk in self._client.astream(payload):
            if chunk.choices:
                content = chunk.choices[0].delta.content
                if run_manager:
                    await run_manager.on_llm_new_token(content)
                yield GenerationChunk(text=content)

    model_config = ConfigDict(
        extra="allow",
    )
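
# A minimal usage sketch, assuming the `gigachat` package is installed and the
# placeholder credentials are replaced with real ones; it mirrors the class
# docstring example rather than introducing new behavior.
if __name__ == "__main__":
    giga = GigaChat(credentials="<auth-token>", verify_ssl_certs=False)

    # One-shot generation: BaseLLM.invoke accepts a plain string prompt.
    print(giga.invoke("Tell me a joke"))

    # Streaming: each iteration yields the next chunk of text as it arrives.
    for token in giga.stream("Tell me a joke"):
        print(token, end="", flush=True)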