import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

import aiohttp
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import ConfigDict

from langchain_community.utilities.requests import Requests

DEFAULT_MODEL_ID = "google/flan-t5-xl"


class DeepInfra(LLM):
    """DeepInfra models.

    To use, you should have the environment variable ``DEEPINFRA_API_TOKEN``
    set with your API token, or pass it as a named parameter to the
    constructor.

    Only supports `text-generation` and `text2text-generation` for now.

    Example:
        .. code-block:: python

            from langchain_community.llms import DeepInfra
            di = DeepInfra(model_id="google/flan-t5-xl",
                           deepinfra_api_token="my-api-key")
    """

    model_id: str = DEFAULT_MODEL_ID
    model_kwargs: Optional[Dict] = None
    deepinfra_api_token: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )
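    # ``model_kwargs`` are merged into every request body (see ``_body``
    # below), so model-specific inference parameters can be set once at
    # construction time. A minimal sketch (the parameter names here are
    # DeepInfra-side and illustrative, not validated by this class):
    #
    #     llm = DeepInfra(model_kwargs={"temperature": 0.7, "max_new_tokens": 250})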
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API token exists in the environment."""
        deepinfra_api_token = get_from_dict_or_env(
            values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
        )
        values["deepinfra_api_token"] = deepinfra_api_token
        return values
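    # Token resolution, as implemented by ``validate_environment`` above: an
    # explicit constructor argument takes precedence, otherwise the value is
    # read from the ``DEEPINFRA_API_TOKEN`` environment variable.
    #
    #     llm = DeepInfra(deepinfra_api_token="my-api-key")  # explicit token
    #     llm = DeepInfra()  # falls back to DEEPINFRA_API_TOKEN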
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_id": self.model_id},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "deepinfra"

    def _url(self) -> str:
        return f"https://api.deepinfra.com/v1/inference/{self.model_id}"

    def _headers(self) -> Dict:
        return {
            "Authorization": f"bearer {self.deepinfra_api_token}",
            "Content-Type": "application/json",
        }

    def _body(self, prompt: str, kwargs: Any) -> Dict:
        model_kwargs = self.model_kwargs or {}
        model_kwargs = {**model_kwargs, **kwargs}

        return {
            "input": prompt,
            **model_kwargs,
        }

    def _handle_status(self, code: int, text: Any) -> None:
        if code >= 500:
            raise Exception(f"DeepInfra Server: Error {text}")
        elif code == 401:
            raise Exception("DeepInfra Server: Unauthorized")
        elif code == 403:
            raise Exception("DeepInfra Server: Forbidden")
        elif code == 404:
            raise Exception(f"DeepInfra Server: Model not found {self.model_id}")
        elif code == 429:
            raise Exception("DeepInfra Server: Rate limit exceeded")
        elif code >= 400:
            raise ValueError(f"DeepInfra received an invalid payload: {text}")
        elif code != 200:
            raise Exception(
                f"DeepInfra returned an unexpected response with status "
                f"{code}: {text}"
            )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to DeepInfra's inference API endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = di("Tell me a joke.")
        """
        request = Requests(headers=self._headers())
        response = request.post(url=self._url(), data=self._body(prompt, kwargs))

        self._handle_status(response.status_code, response.text)
        data = response.json()

        return data["results"][0]["generated_text"]

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        request = Requests(headers=self._headers())
        async with request.apost(
            url=self._url(), data=self._body(prompt, kwargs)
        ) as response:
            # aiohttp's ``text`` is a coroutine; await it so error messages
            # contain the body rather than a bound-method repr.
            self._handle_status(response.status, await response.text())
            data = await response.json()
            return data["results"][0]["generated_text"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        request = Requests(headers=self._headers())
        response = request.post(
            url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
        )
        response_text = response.text
        self._handle_body_errors(response_text)
        self._handle_status(response.status_code, response.text)
        for line in _parse_stream(response.iter_lines()):
            chunk = _handle_sse_line(line)
            if chunk:
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        request = Requests(headers=self._headers())
        async with request.apost(
            url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
        ) as response:
            response_text = await response.text()
            self._handle_body_errors(response_text)
            self._handle_status(response.status, response_text)
            async for line in _parse_stream_async(response.content):
                chunk = _handle_sse_line(line)
                if chunk:
                    if run_manager:
                        await run_manager.on_llm_new_token(chunk.text)
                    yield chunk

    def _handle_body_errors(self, body: str) -> None:
        """Raise if the response body carries an error payload.

        Example error response:
        data: {"error_type": "validation_error",
               "error_message": "ConnectionError: ..."}
        """
        if "error" in body:
            try:
                # Remove the "data:" prefix if present.
                if body.startswith("data:"):
                    body = body[len("data:") :]
                error_data = json.loads(body)
                error_message = error_data.get("error_message", "Unknown error")
                raise Exception(f"DeepInfra Server Error: {error_message}")
            except json.JSONDecodeError:
                raise Exception(f"DeepInfra Server: {body}")
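# The module-level helpers below implement minimal server-sent-events (SSE)
# parsing for the streaming endpoints. Illustrative behavior (the payload
# shape mirrors what ``_handle_sse_line`` reads; the exact fields DeepInfra
# emits may differ):
#
#     >>> _parse_stream_helper(b'data: {"token": {"text": "Hi"}}')
#     '{"token": {"text": "Hi"}}'
#     >>> _parse_stream_helper(b"data: [DONE]") is None
#     True
#     >>> _handle_sse_line('{"token": {"text": "Hi"}}').text
#     'Hi'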
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    for line in rbody:
        _line = _parse_stream_helper(line)
        if _line is not None:
            yield _line


async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
    async for line in rbody:
        _line = _parse_stream_helper(line)
        if _line is not None:
            yield _line


def _parse_stream_helper(line: bytes) -> Optional[str]:
    if line and line.startswith(b"data:"):
        if line.startswith(b"data: "):
            # An SSE event may be valid when it contains whitespace.
            line = line[len(b"data: ") :]
        else:
            line = line[len(b"data:") :]
        if line.strip() == b"[DONE]":
            # Returning here will cause a GeneratorExit exception in urllib3,
            # which will close the http connection with a TCP reset.
            return None
        else:
            return line.decode("utf-8")
    return None


def _handle_sse_line(line: str) -> Optional[GenerationChunk]:
    try:
        obj = json.loads(line)
        return GenerationChunk(
            text=obj.get("token", {}).get("text"),
        )
    except Exception:
        return None
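
if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the library API: it assumes
    # DEEPINFRA_API_TOKEN is set in the environment, network access, and that
    # the default model accepts the illustrative parameters below.
    llm = DeepInfra(
        model_id=DEFAULT_MODEL_ID,
        model_kwargs={"temperature": 0.7, "max_new_tokens": 64},  # illustrative
    )
    # Blocking call through the standard Runnable interface.
    print(llm.invoke("Tell me a joke."))
    # Streaming variant; chunks arrive via the SSE helpers above.
    for chunk in llm.stream("Count from one to five:"):
        print(chunk, end="", flush=True)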