from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import ConfigDict

from langchain_community.llms.utils import enforce_stop_tokens

INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
PROMPT_FOR_GENERATION_FORMAT = """{intro}

{instruction_key}
{instruction}

{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
)
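# For reference, PROMPT_FOR_GENERATION_FORMAT.format(instruction="Tell me a joke.")
# renders the instruction prompt that _transform_prompt injects when
# inject_instruction_format=True (layout follows the template above):
#
#     Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#     ### Instruction:
#     Tell me a joke.
#
#     ### Response: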
class MosaicML(LLM):
    """MosaicML LLM service.

    To use, you should have the environment variable ``MOSAICML_API_TOKEN``
    set with your API token, or pass it as a named parameter to the
    constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import MosaicML

            endpoint_url = (
                "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
            )
            mosaic_llm = MosaicML(
                endpoint_url=endpoint_url,
                mosaicml_api_token="my-api-key",
            )
    """

    endpoint_url: str = (
        "https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
    )
    """Endpoint URL to use."""
    inject_instruction_format: bool = False
    """Whether to inject the instruction format into the prompt."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    retry_sleep: float = 1.0
    """How long to sleep before retrying if a rate limit is encountered."""
    mosaicml_api_token: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        mosaicml_api_token = get_from_dict_or_env(
            values, "mosaicml_api_token", "MOSAICML_API_TOKEN"
        )
        values["mosaicml_api_token"] = mosaicml_api_token
        return values
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self.endpoint_url},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "mosaic"

    def _transform_prompt(self, prompt: str) -> str:
        """Transform prompt."""
        if self.inject_instruction_format:
            prompt = PROMPT_FOR_GENERATION_FORMAT.format(
                instruction=prompt,
            )
        return prompt

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        is_retry: bool = False,
        **kwargs: Any,
    ) -> str:
        """Call out to a MosaicML LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = mosaic_llm.invoke("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        prompt = self._transform_prompt(prompt)

        payload = {"inputs": [prompt]}
        payload.update(_model_kwargs)
        payload.update(kwargs)

        # HTTP headers for authorization
        headers = {
            "Authorization": f"{self.mosaicml_api_token}",
            "Content-Type": "application/json",
        }

        # send request
        try:
            response = requests.post(self.endpoint_url, headers=headers, json=payload)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        try:
            if response.status_code == 429:
                if not is_retry:
                    import time

                    time.sleep(self.retry_sleep)
                    return self._call(prompt, stop, run_manager, is_retry=True)

                raise ValueError(
                    f"Error raised by inference API: rate limit exceeded.\nResponse: "
                    f"{response.text}"
                )

            parsed_response = response.json()

            # The inference API has changed a couple of times, so we add some handling
            # to be robust to multiple response formats.
            if isinstance(parsed_response, dict):
                output_keys = ["data", "output", "outputs"]
                for key in output_keys:
                    if key in parsed_response:
                        output_item = parsed_response[key]
                        break
                else:
                    raise ValueError(
                        f"No valid key ({', '.join(output_keys)}) in response:"
                        f" {parsed_response}"
                    )
                if isinstance(output_item, list):
                    text = output_item[0]
                else:
                    text = output_item
            else:
                raise ValueError(f"Unexpected response type: {parsed_response}")

            # Older versions of the API include the input in the output response
            if text.startswith(prompt):
                text = text[len(prompt):]

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            )

        # TODO: replace when MosaicML supports custom stop tokens natively
        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
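# Minimal usage sketch, not part of the library module: assumes MOSAICML_API_TOKEN is
# set in the environment and the default mpt-7b-instruct endpoint is reachable. The
# "max_new_tokens" key is a hypothetical model parameter; model_kwargs is forwarded
# verbatim into the request payload, so use whatever keys your endpoint accepts.
if __name__ == "__main__":
    llm = MosaicML(
        inject_instruction_format=True,  # wrap the prompt in the instruction template above
        model_kwargs={"max_new_tokens": 128},  # merged into the JSON payload as-is
    )
    # stop tokens are enforced client-side via enforce_stop_tokens after the response
    print(llm.invoke("Write a short poem about the ocean.", stop=["###"]))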