from typing import Any, Dict, List, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)

from langchain_community.llms.utils import enforce_stop_tokens

MOONSHOT_SERVICE_URL_BASE = "https://api.moonshot.cn/v1"


class _MoonshotClient(BaseModel):
    """An API client that talks to the Moonshot server."""

    api_key: SecretStr
    """The API key to use for authentication."""
    base_url: str = MOONSHOT_SERVICE_URL_BASE
    # Root URL of the HTTP API; "/chat/completions" is appended to it.
    timeout: Optional[float] = None
    # Seconds to wait for the server, forwarded to ``requests.post``.
    # ``None`` (the default) keeps the previous behaviour of waiting
    # indefinitely; set a value so a stalled server cannot hang the caller.

    def completion(self, request: Any) -> Any:
        """POST a chat-completions request and return the first choice's text.

        Args:
            request: JSON-serialisable request body, sent as-is to
                ``{base_url}/chat/completions``.

        Returns:
            ``choices[0]["message"]["content"]`` from the JSON response.

        Raises:
            ValueError: If the response is not ``ok`` (HTTP status >= 400).
                The status code and response body are included in the message.
        """
        headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
        response = requests.post(
            f"{self.base_url}/chat/completions",
            headers=headers,
            json=request,
            timeout=self.timeout,
        )
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()["choices"][0]["message"]["content"]
class MoonshotCommon(BaseModel):
    """Common parameters for Moonshot LLMs."""

    client: Any
    base_url: str = MOONSHOT_SERVICE_URL_BASE
    moonshot_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
    model_name: str = Field(default="moonshot-v1-8k", alias="model")
    """Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
    max_tokens: int = 1024
    """Maximum number of tokens to generate."""
    temperature: float = 0.3
    """Temperature parameter (higher values make the model more creative)."""

    # populate_by_name lets callers use either the field name or its alias
    # (e.g. ``model_name=`` or ``model=``). protected_namespaces=() silences
    # pydantic's warning about fields that start with "model_".
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    @property
    def lc_secrets(self) -> dict:
        """A map of constructor argument names to secret ids.

        For example,
            {"moonshot_api_key": "MOONSHOT_API_KEY"}
        """
        return {"moonshot_api_key": "MOONSHOT_API_KEY"}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Moonshot chat completions API."""
        return {
            "model": self.model_name,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Return a fresh copy of the request parameters for one call.

        ``_default_params`` already contains ``"model"``, so no extra merge is
        needed. A new dict is returned so callers may add keys to it freely.
        """
        return dict(self._default_params)

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra parameters.

        Override the superclass method, prevent the model parameter from being
        overridden. Intentionally a no-op: ``values`` are returned unchanged.
        """
        return values

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and construct the HTTP client.

        The key is read from ``values["moonshot_api_key"]`` or, if absent, the
        ``MOONSHOT_API_KEY`` environment variable (via ``get_from_dict_or_env``,
        which presumably raises when neither is set). It is stored as a secret
        string. A ``_MoonshotClient`` bound to that key and to ``base_url``
        (or the default Moonshot URL) is then stored in ``values["client"]``,
        replacing any client passed in.
        """
        values["moonshot_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "moonshot_api_key", "MOONSHOT_API_KEY")
        )
        values["client"] = _MoonshotClient(
            api_key=values["moonshot_api_key"],
            base_url=values.get("base_url", MOONSHOT_SERVICE_URL_BASE),
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "moonshot"
class Moonshot(MoonshotCommon, LLM):
    """Moonshot large language models.

    To use, you should have the environment variable ``MOONSHOT_API_KEY`` set with your
    API key.

    Referenced from https://platform.moonshot.cn/docs

    Example:
        .. code-block:: python

            from langchain_community.llms.moonshot import Moonshot

            moonshot = Moonshot(model="moonshot-v1-8k")
    """

    model_config = ConfigDict(populate_by_name=True)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` as one user message and return the generated text.

        Args:
            prompt: Text sent as the single ``"user"`` chat message.
            stop: Optional stop sequences, applied to the returned text.
            run_manager: Unused here; accepted for the ``LLM`` interface.
            **kwargs: Extra request fields; they override same-named
                entries in the payload.

        Returns:
            The completion text, truncated at ``stop`` when it is given.
        """
        # Key order and override rules match building the dict step by step:
        # the invocation params first, then "messages", then caller kwargs.
        payload = {
            **self._invocation_params,
            "messages": [{"role": "user", "content": prompt}],
            **kwargs,
        }
        generated = self.client.completion(payload)

        if stop is None:
            return generated

        # ``stop`` is not part of the request payload, so it is enforced
        # client-side on the returned text.
        return enforce_stop_tokens(generated, stop)