Source code for langchain_community.chat_models.symblai_nebula
import json
import os
from json import JSONDecodeError
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str
from pydantic import ConfigDict, Field, SecretStr


def _convert_role(role: str) -> str:
    """Map a LangChain message type to the role name the Nebula API expects."""
    role_map = {"ai": "assistant", "human": "human", "chat": "human"}
    if role in role_map:
        return role_map[role]
    else:
        raise ValueError(f"Unknown role type: {role}")


def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Split a message list into Nebula's ``system_prompt`` and ``messages`` fields."""
    system = ""
    formatted_messages = []
    for message in messages[:-1]:
        if message.type == "system":
            if isinstance(message.content, str):
                system = message.content
            else:
                raise ValueError("System prompt must be a string")
        else:
            formatted_messages.append(
                {
                    "role": _convert_role(message.type),
                    "text": message.content,
                }
            )

    # The final message is always sent as the human turn being answered.
    text = messages[-1].content
    formatted_messages.append({"role": "human", "text": text})

    return {
        "system_prompt": system,
        "messages": formatted_messages,
    }
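# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module source): what
# _format_nebula_messages produces for a short conversation. Only the
# langchain_core message classes imported above are assumed; the message
# contents are made up for the example.
#
#     from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
#
#     _format_nebula_messages(
#         [
#             SystemMessage(content="You are a helpful assistant."),
#             HumanMessage(content="Hi!"),
#             AIMessage(content="Hello! How can I help?"),
#             HumanMessage(content="Summarize our last call."),
#         ]
#     )
#     # -> {
#     #     "system_prompt": "You are a helpful assistant.",
#     #     "messages": [
#     #         {"role": "human", "text": "Hi!"},
#     #         {"role": "assistant", "text": "Hello! How can I help?"},
#     #         {"role": "human", "text": "Summarize our last call."},
#     #     ],
#     # }
# ---------------------------------------------------------------------------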
class ChatNebula(BaseChatModel):
    """`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm

    API Reference: https://docs.symbl.ai/reference/nebula-chat

    To use, set the environment variable ``NEBULA_API_KEY``,
    or pass it as a named parameter to the constructor.
    To request an API key, visit https://platform.symbl.ai/#/login

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatNebula
            from langchain_core.messages import SystemMessage, HumanMessage

            chat = ChatNebula(max_new_tokens=1024, temperature=0.5)

            messages = [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(
                    content="Answer the following question. How can I help save the world."
                ),
            ]
            chat.invoke(messages)
    """

    max_new_tokens: int = 1024
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    nebula_api_url: str = "https://api-nebula.symbl.ai"
    """Base URL of the Nebula API."""

    nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )

    def __init__(self, **kwargs: Any) -> None:
        # Resolve the API key from the constructor argument first, then fall
        # back to the NEBULA_API_KEY environment variable.
        if "nebula_api_key" in kwargs:
            api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
        elif "NEBULA_API_KEY" in os.environ:
            api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
        else:
            api_key = None
        super().__init__(nebula_api_key=api_key, **kwargs)  # type: ignore[call-arg]

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "nebula-chat"

    @property
    def _api_key(self) -> str:
        if self.nebula_api_key:
            return self.nebula_api_key.get_secret_value()
        return ""

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to Nebula's streaming chat endpoint."""
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = {
            "ApiKey": self._api_key,
            "Content-Type": "application/json",
        }
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        response = requests.request(
            "POST", url, headers=headers, data=json_payload, stream=True
        )
        response.raise_for_status()

        for chunk_response in response.iter_lines():
            # Drop the server-sent-event "data: " prefix before decoding, and
            # skip keep-alive or empty lines that fail to parse as JSON.
            chunk_decoded = chunk_response.decode()[6:]
            try:
                chunk = json.loads(chunk_decoded)
            except JSONDecodeError:
                continue
            token = chunk["delta"]
            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=cg_chunk)
            yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        async with ClientSession() as session:
            # aiohttp streams the body via ``response.content``; unlike
            # ``requests``, ``session.post`` takes no ``stream`` keyword.
            async with session.post(
                url, data=json_payload, headers=headers
            ) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    chunk_decoded = chunk_response.decode()[6:]
                    try:
                        chunk = json.loads(chunk_decoded)
                    except JSONDecodeError:
                        continue
                    token = chunk["delta"]
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(token, chunk=cg_chunk)
                    yield cg_chunk
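    # -----------------------------------------------------------------------
    # Illustrative sketch (not part of the module source): both streaming
    # handlers above slice six characters off each line before JSON-decoding,
    # which assumes server-sent-event framing of roughly this shape:
    #
    #     data: {"delta": "Hello"}
    #
    # A hypothetical helper capturing that parse step, under the same
    # assumption about the wire format, might look like:
    #
    #     def _decode_sse_chunk(raw: bytes) -> Optional[Dict[str, Any]]:
    #         """Strip the 'data: ' prefix and JSON-decode; None otherwise."""
    #         try:
    #             return json.loads(raw.decode()[6:])
    #         except JSONDecodeError:
    #             return None
    # -----------------------------------------------------------------------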
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        response = requests.request("POST", url, headers=headers, data=json_payload)
        response.raise_for_status()
        data = response.json()

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        async with ClientSession() as session:
            async with session.post(
                url, data=json_payload, headers=headers
            ) as response:
                response.raise_for_status()
                data = await response.json()

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
            llm_output=data,
        )
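# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module source): basic one-shot and
# streaming usage through the public BaseChatModel interface, assuming
# NEBULA_API_KEY is set in the environment. The prompt text is made up.
#
#     from langchain_core.messages import HumanMessage, SystemMessage
#
#     chat = ChatNebula(max_new_tokens=256, temperature=0)
#
#     # One-shot call: routes through _generate and returns an AIMessage.
#     reply = chat.invoke(
#         [
#             SystemMessage(content="You are a helpful assistant."),
#             HumanMessage(content="Summarize our last call."),
#         ]
#     )
#
#     # Token-by-token streaming: routes through _stream.
#     for chunk in chat.stream([HumanMessage(content="Summarize our last call.")]):
#         print(chunk.content, end="", flush=True)
# ---------------------------------------------------------------------------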