@dataclasses.dataclass
class AviaryBackend:
    """Aviary backend.

    Attributes:
        backend_url: The URL for the Aviary backend. Always normalized to end
            with a trailing "/" when built via :meth:`from_env`.
        bearer: The bearer token for the Aviary backend ("" when no token).
    """

    backend_url: str
    bearer: str

    def __post_init__(self) -> None:
        # Pre-build the auth header reused by every request helper.
        self.header = {"Authorization": self.bearer}

    @classmethod
    def from_env(cls) -> "AviaryBackend":
        """Construct a backend from the AVIARY_URL / AVIARY_TOKEN env variables.

        Returns:
            An ``AviaryBackend`` whose URL is guaranteed to end with "/".

        Raises:
            ValueError: If AVIARY_URL is unset or empty.
        """
        aviary_url = os.getenv("AVIARY_URL")
        if not aviary_url:
            # An `assert` here would be silently stripped under `python -O`;
            # raise explicitly so misconfiguration always surfaces.
            raise ValueError("AVIARY_URL must be set")
        aviary_token = os.getenv("AVIARY_TOKEN", "")
        bearer = f"Bearer {aviary_token}" if aviary_token else ""
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        return cls(aviary_url, bearer)
def get_models() -> List[str]:
    """List available models"""
    backend = AviaryBackend.from_env()
    routes_url = f"{backend.backend_url}-/routes"
    response = requests.get(routes_url, headers=backend.header, timeout=TIMEOUT)
    try:
        routes = response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {routes_url}. Text response: {response.text}"
        ) from e
    # Only routes containing "--" name models; map "/org--name" -> "org/name".
    models = [
        route.lstrip("/").replace("--", "/") for route in routes if "--" in route
    ]
    models.sort()
    return models
def get_completions(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    version: str = "",
) -> Dict[str, Union[str, float, int]]:
    """Get completions from Aviary models."""
    backend = AviaryBackend.from_env()
    # Model names use "/" but routes encode it as "--".
    route = model.replace("/", "--")
    url = f"{backend.backend_url}{route}/{version}query"
    payload = {"prompt": prompt, "use_prompt_format": use_prompt_format}
    response = requests.post(
        url,
        headers=backend.header,
        json=payload,
        timeout=TIMEOUT,
    )
    try:
        return response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {url}. Text response: {response.text}"
        ) from e
class Aviary(LLM):
    """Aviary hosted models.

    Aviary is a backend for hosted models. You can find out more about aviary at
    http://github.com/ray-project/aviary

    To get a list of the models supported on an aviary, follow the instructions
    on the website to install the aviary CLI and then use:
    `aviary models`

    AVIARY_URL and AVIARY_TOKEN environment variables must be set.

    Attributes:
        model: The name of the model to use. Defaults to "amazon/LightGPT".
        aviary_url: The URL for the Aviary backend. Defaults to None.
        aviary_token: The bearer token for the Aviary backend. Defaults to None.
        use_prompt_format: If True, the prompt template for the model will be
            ignored. Defaults to True.
        version: API version to use for Aviary. Defaults to None.

    Example:
        .. code-block:: python

            from langchain_community.llms import Aviary
            os.environ["AVIARY_URL"] = "<URL>"
            os.environ["AVIARY_TOKEN"] = "<TOKEN>"
            light = Aviary(model='amazon/LightGPT')
            output = light('How do you make fried rice?')
    """

    model: str = "amazon/LightGPT"
    aviary_url: Optional[str] = None
    aviary_token: Optional[str] = None
    # If True the prompt template for the model will be ignored.
    use_prompt_format: bool = True
    # API version to use for Aviary
    version: Optional[str] = None

    class Config:
        # Reject unknown constructor kwargs instead of silently ignoring them.
        extra = "forbid"

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")

        # Set env variables for aviary sdk so module-level helpers
        # (get_models / get_completions) can read them via AviaryBackend.from_env.
        os.environ["AVIARY_URL"] = aviary_url
        os.environ["AVIARY_TOKEN"] = aviary_token

        try:
            aviary_models = get_models()
        except requests.exceptions.RequestException as e:
            # NOTE(review): re-raised without `from e`, so the original
            # traceback chain is only available via implicit context.
            raise ValueError(e)

        model = values.get("model")
        if model and model not in aviary_models:
            raise ValueError(f"{aviary_url} does not support model {values['model']}.")

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model,
            "aviary_url": self.aviary_url,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return f"aviary-{self.model.replace('/', '-')}"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Aviary

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # NOTE(review): incoming **kwargs are discarded here — `kwargs` is
        # rebound, so only use_prompt_format (and version, when set) are
        # forwarded to get_completions.
        kwargs = {"use_prompt_format": self.use_prompt_format}
        if self.version:
            kwargs["version"] = self.version

        output = get_completions(
            model=self.model,
            prompt=prompt,
            **kwargs,
        )

        text = cast(str, output["generated_text"])
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text