class SparkLLM(LLM):
    """iFlyTek Spark completion model integration.

    Setup:
        To use, you should set environment variables ``IFLYTEK_SPARK_APP_ID``,
        ``IFLYTEK_SPARK_API_KEY`` and ``IFLYTEK_SPARK_API_SECRET``.

    .. code-block:: bash

        export IFLYTEK_SPARK_APP_ID="your-app-id"
        export IFLYTEK_SPARK_API_KEY="your-api-key"
        export IFLYTEK_SPARK_API_SECRET="your-api-secret"

    Key init args — completion params:
        model: Optional[str]
            Name of IFLYTEK SPARK model to use.
        temperature: Optional[float]
            Sampling temperature.
        top_k: Optional[int]
            What search sampling control to use.
        streaming: Optional[bool]
            Whether to stream the results or not.

    Key init args — client params:
        app_id: Optional[str]
            IFLYTEK SPARK APP ID. Automatically inferred from env var
            `IFLYTEK_SPARK_APP_ID` if not provided.
        api_key: Optional[str]
            IFLYTEK SPARK API KEY. If not passed in will be read from env var
            IFLYTEK_SPARK_API_KEY.
        api_secret: Optional[str]
            IFLYTEK SPARK API SECRET. If not passed in will be read from env var
            IFLYTEK_SPARK_API_SECRET.
        api_url: Optional[str]
            Base URL for API requests.
        timeout: Optional[int]
            Timeout for requests.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_community.llms import SparkLLM

            llm = SparkLLM(
                app_id="your-app-id",
                api_key="your-api_key",
                api_secret="your-api-secret",
                # model='Spark4.0 Ultra',
                # temperature=...,
                # other params...
            )

    Invoke:
        .. code-block:: python

            # English gloss of the sample prompt:
            # "In about 50 words, explain: the meaning of life lies in"
            input_text = "用50个字左右阐述,生命的意义在于"
            llm.invoke(input_text)

        .. code-block:: python

            '生命的意义在于实现自我价值,追求内心的平静与快乐,同时为他人和社会带来正面影响。'

    Stream:
        .. code-block:: python

            for chunk in llm.stream(input_text):
                print(chunk)

        .. code-block:: python

            生命 | 的意义在于 | 不断探索和 | 实现个人潜能,通过 | 学习 | 、成长和对社会 | 的贡献,追求内心的满足和幸福。

    Async:
        .. code-block:: python

            await llm.ainvoke(input_text)

            # stream:
            # async for chunk in llm.astream(input_text):
            #    print(chunk)

            # batch:
            # await llm.abatch([input_text])

        .. code-block:: python

            '生命的意义在于实现自我价值,追求内心的平静与快乐,同时为他人和社会带来正面影响。'

    """  # noqa: E501

    client: Any = None  #: :meta private:
    spark_app_id: Optional[str] = Field(default=None, alias="app_id")
    """Automatically inferred from env var `IFLYTEK_SPARK_APP_ID`
        if not provided."""
    spark_api_key: Optional[str] = Field(default=None, alias="api_key")
    """IFLYTEK SPARK API KEY. If not passed in will be read from
        env var IFLYTEK_SPARK_API_KEY."""
    spark_api_secret: Optional[str] = Field(default=None, alias="api_secret")
    """IFLYTEK SPARK API SECRET. If not passed in will be read from
        env var IFLYTEK_SPARK_API_SECRET."""
    spark_api_url: Optional[str] = Field(default=None, alias="api_url")
    """Base URL path for API requests, leave blank if not using a proxy or
        service emulator."""
    spark_llm_domain: Optional[str] = Field(default=None, alias="model")
    """Model name to use."""
    # Sent as the ``uid`` of every request issued through this instance.
    spark_user_id: str = "lc_user"
    streaming: bool = False
    """Whether to stream the results or not."""
    request_timeout: int = Field(default=30, alias="timeout")
    """Seconds to wait for each response message before timing out."""
    temperature: float = 0.5
    """What sampling temperature to use."""
    top_k: int = 4
    """What search sampling control to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for API call not explicitly specified."""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials/endpoint settings and build the websocket client.

        Each connection setting is taken from ``values`` (by field name or
        alias) or, failing that, from the matching ``IFLYTEK_SPARK_*``
        environment variable. The sampling parameters are copied into
        ``model_kwargs`` and a ``_SparkLLMClient`` is stored under
        ``values["client"]``.

        Args:
            values: The raw init values of the model.

        Returns:
            The same dict, updated in place.
        """
        values["spark_app_id"] = get_from_dict_or_env(
            values,
            ["spark_app_id", "app_id"],
            "IFLYTEK_SPARK_APP_ID",
        )
        values["spark_api_key"] = get_from_dict_or_env(
            values,
            ["spark_api_key", "api_key"],
            "IFLYTEK_SPARK_API_KEY",
        )
        values["spark_api_secret"] = get_from_dict_or_env(
            values,
            ["spark_api_secret", "api_secret"],
            "IFLYTEK_SPARK_API_SECRET",
        )
        values["spark_api_url"] = get_from_dict_or_env(
            values,
            ["spark_api_url", "api_url"],
            "IFLYTEK_SPARK_API_URL",
            "wss://spark-api.xf-yun.com/v3.5/chat",
        )
        values["spark_llm_domain"] = get_from_dict_or_env(
            values,
            ["spark_llm_domain", "model"],
            "IFLYTEK_SPARK_LLM_DOMAIN",
            "generalv3.5",
        )

        # Put extra params into model_kwargs.
        model_kwargs = values.get("model_kwargs")
        if model_kwargs is None:
            model_kwargs = {}
        values["model_kwargs"] = model_kwargs

        # Compare against None rather than using ``or``: 0 is a legitimate
        # value for both parameters and must not be replaced by the default.
        # The fallbacks mirror the field defaults declared above.
        temperature = values.get("temperature")
        top_k = values.get("top_k")
        model_kwargs["temperature"] = 0.5 if temperature is None else temperature
        model_kwargs["top_k"] = 4 if top_k is None else top_k

        values["client"] = _SparkLLMClient(
            app_id=values["spark_app_id"],
            api_key=values["spark_api_key"],
            api_secret=values["spark_api_secret"],
            api_url=values["spark_api_url"],
            spark_domain=values["spark_llm_domain"],
            model_kwargs=model_kwargs,
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "spark-llm-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling SparkLLM API."""
        normal_params = {
            "spark_llm_domain": self.spark_llm_domain,
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_k": self.top_k,
            "temperature": self.temperature,
        }
        # Entries in model_kwargs take precedence over the named fields.
        return {**normal_params, **self.model_kwargs}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to SparkLLM for a single completion of ``prompt``.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words. It is forwarded to ``_stream``
                in streaming mode but is not otherwise applied here.
            run_manager: Optional callback manager, used in streaming mode.

        Returns:
            The string generated by the llm.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )

        completion = ""
        self.client.arun(
            [{"role": "user", "content": prompt}],
            self.spark_user_id,
            self.model_kwargs,
            self.streaming,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            # Non-"data" messages (e.g. usage statistics) carry no text.
            if "data" not in content:
                continue
            completion = content["data"]["content"]
        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the completion of ``prompt`` chunk by chunk.

        Args:
            prompt: The prompt to pass into the model.
            stop: Accepted for interface compatibility; not applied here.
            run_manager: Optional callback manager notified of every new token.

        Yields:
            One ``GenerationChunk`` per content message received from the API.
        """
        # Start the request on a background thread (as the non-streaming path
        # does) so chunks can be consumed while the websocket is still
        # receiving; the blocking ``run`` would only return once the whole
        # response had already been queued.
        self.client.arun(
            [{"role": "user", "content": prompt}],
            self.spark_user_id,
            self.model_kwargs,
            True,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            if "data" not in content:
                continue
            chunk = GenerationChunk(text=content["data"]["content"])
            if run_manager:
                # The callback expects the token text, not the raw message dict.
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
class _SparkLLMClient:
    """
    Use websocket-client to call the SparkLLM interface provided by Xfyun,
    which is the iFlyTek's open platform for AI capabilities

    Responses are delivered through ``self.queue`` as dicts holding exactly one
    of the keys ``"data"``, ``"usage"``, ``"error"`` or ``"done"``; consume
    them with :meth:`subscribe`. The queue is shared by all requests made
    through one client instance.
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        api_url: Optional[str] = None,
        spark_domain: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
    ):
        """Store the credentials and request defaults.

        Args:
            app_id: Application id sent in the request header.
            api_key: API key used to build the authorization parameter.
            api_secret: API secret used to sign the request.
            api_url: Websocket endpoint; defaults to the v3.5 chat endpoint.
            spark_domain: Model domain; defaults to ``"generalv3.5"``.
            model_kwargs: Default chat parameters merged into every request.

        Raises:
            ImportError: If the ``websocket-client`` package is not installed.
        """
        try:
            import websocket

            self.websocket_client = websocket
        except ImportError as exc:
            raise ImportError(
                "Could not import websocket client python package. "
                "Please install it with `pip install websocket-client`."
            ) from exc

        self.api_url = (
            "wss://spark-api.xf-yun.com/v3.5/chat" if not api_url else api_url
        )
        self.app_id = app_id
        self.model_kwargs = model_kwargs
        self.spark_domain = spark_domain or "generalv3.5"
        self.queue: Queue[Dict] = Queue()
        # Accumulates the full answer when the caller is not streaming.
        self.blocking_message = {"content": "", "role": "assistant"}
        self.api_key = api_key
        self.api_secret = api_secret

    @staticmethod
    def _create_url(api_url: str, api_key: str, api_secret: str) -> str:
        """
        Generate a request url with an api key and an api secret.

        The returned URL is ``api_url`` with its query string replaced by the
        ``authorization``, ``date`` and ``host`` parameters.
        """
        # generate timestamp by RFC1123
        date = format_date_time(mktime(datetime.now().timetuple()))

        # urlparse
        parsed_url = urlparse(api_url)
        host = parsed_url.netloc
        path = parsed_url.path

        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"

        # sign using hmac-sha256
        signature_sha = hmac.new(
            api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(
            encoding="utf-8"
        )

        # Built by implicit concatenation: a backslash continuation inside the
        # literal would embed the next source line's indentation in the value.
        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", '
            f'signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode("utf-8")
        ).decode(encoding="utf-8")

        # generate url
        params_dict = {"authorization": authorization, "date": date, "host": host}
        encoded_params = urlencode(params_dict)
        url = urlunparse(
            (
                parsed_url.scheme,
                parsed_url.netloc,
                parsed_url.path,
                parsed_url.params,
                encoded_params,
                parsed_url.fragment,
            )
        )
        return url

    def _discard_stale_messages(self) -> None:
        """Best-effort removal of messages left in the queue by an earlier request.

        After an error or a timeout the consumer stops reading, but the
        websocket callbacks may still have queued messages (typically the
        final ``{"done": True}``). Left in place, they would be taken for the
        next request's response.
        """
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                return

    def run(
        self,
        messages: List[Dict],
        user_id: str,
        model_kwargs: Optional[dict] = None,
        streaming: bool = False,
    ) -> None:
        """Send ``messages`` over a new websocket and block until it closes.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}`` dicts).
            user_id: Value sent as the request ``uid``.
            model_kwargs: Chat parameters for this request; the client-level
                defaults are used when ``None``.
            streaming: Whether each content message is queued as it arrives
                rather than accumulated into one final message.
        """
        self._discard_stale_messages()
        self.websocket_client.enableTrace(False)
        ws = self.websocket_client.WebSocketApp(
            _SparkLLMClient._create_url(
                self.api_url,
                self.api_key,
                self.api_secret,
            ),
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        # Request state travels on the socket object so the callbacks can read it.
        ws.messages = messages  # type: ignore[attr-defined]
        ws.user_id = user_id  # type: ignore[attr-defined]
        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs  # type: ignore[attr-defined]
        ws.streaming = streaming  # type: ignore[attr-defined]
        ws.run_forever()

    def arun(
        self,
        messages: List[Dict],
        user_id: str,
        model_kwargs: Optional[dict] = None,
        streaming: bool = False,
    ) -> threading.Thread:
        """Run :meth:`run` on a new thread and return that (started) thread."""
        # Drain here, before the caller can start consuming, so a leftover
        # message cannot be read while the worker thread is still starting up.
        self._discard_stale_messages()
        ws_thread = threading.Thread(
            target=self.run,
            args=(
                messages,
                user_id,
                model_kwargs,
                streaming,
            ),
        )
        ws_thread.start()
        return ws_thread

    def on_error(self, ws: Any, error: Optional[Any]) -> None:
        """Queue the websocket error for the consumer and close the socket."""
        self.queue.put({"error": error})
        ws.close()

    def on_close(self, ws: Any, close_status_code: int, close_reason: str) -> None:
        """Log the close details and signal the consumer that the run is over."""
        logger.debug(
            {
                "log": {
                    "close_status_code": close_status_code,
                    "close_reason": close_reason,
                }
            }
        )
        self.queue.put({"done": True})

    def on_open(self, ws: Any) -> None:
        """Reset the accumulated answer and send the request payload."""
        self.blocking_message = {"content": "", "role": "assistant"}
        data = json.dumps(
            self.gen_params(
                messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs
            )
        )
        ws.send(data)

    def on_message(self, ws: Any, message: str) -> None:
        """Decode one server message and queue what the consumer needs from it.

        A non-zero header code is queued as an ``"error"``. Otherwise the text
        is queued immediately (streaming) or appended to the accumulated
        answer, and on the final message (status ``2``) the accumulated answer
        (non-streaming only) and the usage statistics are queued before the
        socket is closed.
        """
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            self.queue.put(
                {"error": f"Code: {code}, Error: {data['header']['message']}"}
            )
            ws.close()
            return

        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        if ws.streaming:
            self.queue.put({"data": choices["text"][0]})
        else:
            self.blocking_message["content"] += content
        if status == 2:
            if not ws.streaming:
                self.queue.put({"data": self.blocking_message})
            usage_data = (
                data.get("payload", {}).get("usage", {}).get("text", {})
                if data
                else {}
            )
            self.queue.put({"usage": usage_data})
            ws.close()

    def gen_params(
        self, messages: list, user_id: str, model_kwargs: Optional[dict] = None
    ) -> dict:
        """Build the request payload for ``messages``.

        Args:
            messages: Chat messages to send.
            user_id: Value for the header ``uid``.
            model_kwargs: Extra chat parameters merged next to ``domain``.

        Returns:
            The request dict with ``header``, ``parameter`` and ``payload``.
        """
        data: Dict = {
            "header": {"app_id": self.app_id, "uid": user_id},
            "parameter": {"chat": {"domain": self.spark_domain}},
            "payload": {"message": {"text": messages}},
        }

        if model_kwargs:
            data["parameter"]["chat"].update(model_kwargs)
        logger.debug("Spark Request Parameters: %s", data)
        return data

    def subscribe(self, timeout: Optional[int] = 30) -> Generator[Dict, None, None]:
        """Yield the queued ``"data"`` and ``"usage"`` messages of a run.

        Args:
            timeout: Seconds to wait for each message.

        Yields:
            Every queued message carrying a ``"usage"`` or ``"data"`` key, in
            arrival order, until a ``"done"`` (or unrecognised) message.

        Raises:
            TimeoutError: If no message arrives within ``timeout`` seconds.
            ConnectionError: If an ``"error"`` message is received.
        """
        while True:
            try:
                content = self.queue.get(timeout=timeout)
            except queue.Empty:
                raise TimeoutError(
                    f"SparkLLMClient wait LLM api response timeout {timeout} seconds"
                ) from None
            if "error" in content:
                raise ConnectionError(content["error"])
            if "usage" in content:
                yield content
                continue
            if "done" in content:
                break
            if "data" not in content:
                break
            yield content