# Source code for langchain_community.chat_models.symblai_nebula

import json
import os
from json import JSONDecodeError
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str


def _convert_role(role: str) -> str:
    map = {"ai": "assistant", "human": "human", "chat": "human"}
    if role in map:
        return map[role]
    else:
        raise ValueError(f"Unknown role type: {role}")


def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    system = ""
    formatted_messages = []
    for message in messages[:-1]:
        if message.type == "system":
            if isinstance(message.content, str):
                system = message.content
            else:
                raise ValueError("System prompt must be a string")
        else:
            formatted_messages.append(
                {
                    "role": _convert_role(message.type),
                    "text": message.content,
                }
            )

    text = messages[-1].content
    formatted_messages.append({"role": "human", "text": text})
    return {"system_prompt": system, "messages": formatted_messages}


class ChatNebula(BaseChatModel):
    """`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm

    API Reference: https://docs.symbl.ai/reference/nebula-chat

    To use, set the environment variable ``NEBULA_API_KEY``,
    or pass it as a named parameter to the constructor.
    To request an API key, visit https://platform.symbl.ai/#/login

    Note:
        The ``stop`` argument of the generation methods is accepted for
        interface compatibility but is not included in the request payload.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatNebula
            from langchain_core.messages import SystemMessage, HumanMessage

            chat = ChatNebula(max_new_tokens=1024, temperature=0.5)

            messages = [
                SystemMessage(
                    content="You are a helpful assistant."
                ),
                HumanMessage(
                    "Answer the following question. How can I help save the world."
                ),
            ]
            chat.invoke(messages)
    """

    max_new_tokens: int = 1024
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """When true, ``_generate`` / ``_agenerate`` go through the streaming endpoint."""

    nebula_api_url: str = "https://api-nebula.symbl.ai"
    """Base URL that the ``/v1/model/chat`` endpoint paths are appended to."""

    nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    def __init__(self, **kwargs: Any) -> None:
        """Resolve the API key, then initialise the model fields.

        An explicit ``nebula_api_key`` keyword takes precedence over the
        ``NEBULA_API_KEY`` environment variable; with neither, the key is
        ``None``.
        """
        if "nebula_api_key" in kwargs:
            api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
        elif "NEBULA_API_KEY" in os.environ:
            api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
        else:
            api_key = None
        super().__init__(nebula_api_key=api_key, **kwargs)  # type: ignore[call-arg]

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "nebula-chat"

    @property
    def _api_key(self) -> str:
        """Return the plain-text API key, or ``""`` when none is configured."""
        if self.nebula_api_key:
            return self.nebula_api_key.get_secret_value()
        return ""

    @property
    def _request_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with every Nebula request."""
        return {"ApiKey": self._api_key, "Content-Type": "application/json"}

    def _request_body(self, messages: List[BaseMessage], extra: Dict[str, Any]) -> str:
        """Serialise the JSON request body shared by all four call paths.

        Args:
            messages: Conversation to send; formatted by
                ``_format_nebula_messages``.
            extra: Caller-supplied keyword arguments. They are merged last,
                so they override the model defaults and the formatted fields.

        Returns:
            The JSON-encoded payload with every ``None``-valued key removed.
        """
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **_format_nebula_messages(messages=messages),
            **extra,
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        return json.dumps(payload)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to Nebula's chat endpoint.

        Yields one chunk per streamed line that parses as JSON, using the
        line's ``"delta"`` value as the chunk content. ``stop`` is unused.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = self._request_headers
        json_payload = self._request_body(messages, kwargs)

        # The context manager closes the streamed response (releasing the
        # connection) even when the consumer stops iterating early.
        with requests.request(
            "POST", url, headers=headers, data=json_payload, stream=True
        ) as response:
            response.raise_for_status()

            for chunk_response in response.iter_lines():
                # Drop the first six characters of the line before parsing
                # (presumably the SSE ``data: `` prefix — confirm against the
                # API); lines that are not JSON afterwards are skipped.
                chunk_decoded = chunk_response.decode()[6:]
                try:
                    chunk = json.loads(chunk_decoded)
                except JSONDecodeError:
                    continue
                token = chunk["delta"]
                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=cg_chunk)
                yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronous counterpart of ``_stream``. ``stop`` is unused.

        Raises:
            aiohttp.ClientResponseError: If the endpoint answers with an
                error status.
        """
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = self._request_headers
        json_payload = self._request_body(messages, kwargs)

        async with ClientSession() as session:
            # No ``stream=True`` here: that is a ``requests`` option which
            # aiohttp rejects as an unexpected keyword. aiohttp exposes the
            # body incrementally through ``response.content`` by default.
            async with session.post(
                url, data=json_payload, headers=headers
            ) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    # Same line handling as in ``_stream``.
                    chunk_decoded = chunk_response.decode()[6:]
                    try:
                        chunk = json.loads(chunk_decoded)
                    except JSONDecodeError:
                        continue
                    token = chunk["delta"]
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(token, chunk=cg_chunk)
                    yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a chat result, via streaming when ``self.streaming`` is set.

        In the non-streaming path the decoded JSON response is returned as
        ``llm_output`` and its ``"messages"`` value becomes the AI message
        content. ``stop`` is not sent to the API.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = self._request_headers
        json_payload = self._request_body(messages, kwargs)

        response = requests.request("POST", url, headers=headers, data=json_payload)
        response.raise_for_status()
        data = response.json()

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous counterpart of ``_generate``. ``stop`` is not sent.

        Raises:
            aiohttp.ClientResponseError: If the endpoint answers with an
                error status.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = self._request_headers
        json_payload = self._request_body(messages, kwargs)

        async with ClientSession() as session:
            async with session.post(
                url, data=json_payload, headers=headers
            ) as response:
                response.raise_for_status()
                data = await response.json()

                return ChatResult(
                    generations=[
                        ChatGeneration(message=AIMessage(content=data["messages"]))
                    ],
                    llm_output=data,
                )