from enum import Enum
from typing import Any, Iterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from pydantic import BaseModel, ConfigDict
from langchain_community.llms.utils import enforce_stop_tokens
class Device(str, Enum):
"""The device to use for inference, cuda or cpu"""
cuda = "cuda"
cpu = "cpu"
class ReaderConfig(BaseModel):
"""Configuration for the reader to be deployed in Titan Takeoff API."""
model_config = ConfigDict(
protected_namespaces=(),
)
model_name: str
"""The name of the model to use"""
device: Device = Device.cuda
"""The device to use for inference, cuda or cpu"""
consumer_group: str = "primary"
"""The consumer group to place the reader into"""
tensor_parallel: Optional[int] = None
"""The number of gpus you would like your model to be split across"""
max_seq_length: int = 512
"""The maximum sequence length to use for inference, defaults to 512"""
max_batch_size: int = 4
"""The max batch size for continuous batching of requests"""
class TitanTakeoff(LLM):
"""Titan Takeoff API LLMs.
Titan Takeoff is a wrapper to interface with the Takeoff Inference API for
generative text-to-text language models.
You can use this wrapper to send requests to a generative language model
and to deploy readers with Takeoff.
Examples:
This is an example of how to deploy a generative language model and send
it requests.
.. code-block:: python
# Import the TitanTakeoff class from community package
import time
from langchain_community.llms import TitanTakeoff
# Specify the generative model reader you'd like to deploy
reader_1 = {
"model_name": "TheBloke/Llama-2-7b-Chat-AWQ",
"device": "cuda",
"tensor_parallel": 1,
"consumer_group": "llama"
}
# For every reader you pass into the models arg, Takeoff will spin
# up a reader according to the specs you provide. If you don't
# specify the arg, no models are spun up and it is assumed you have
# already deployed them separately.
llm = TitanTakeoff(models=[reader_1])
# Wait for the reader to be deployed, time needed depends on the
# model size and your internet speed
time.sleep(60)
# Returns the generated text from the `llama` consumer group,
# where we just spun up the Llama 7B model
print(llm.invoke(
"Where can I see football?", consumer_group="llama"
))
# You can also send generation parameters to the model, any of the
# following can be passed in as kwargs:
# https://docs.titanml.co/docs/next/apis/Takeoff%20inference_REST_API/generate#request
# for instance:
print(llm.invoke(
"Where can I see football?", consumer_group="llama", max_new_tokens=100
))
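# You can also stream tokens as they are generated; a usage sketch
# (assumes the same `llama` consumer group spun up above):
for token in llm.stream(
"Where can I see football?", consumer_group="llama"
):
    print(token, end="", flush=True)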
"""
base_url: str = "http://localhost"
"""The base URL of the Titan Takeoff (Pro) server. Default = "http://localhost"."""
port: int = 3000
"""The port of the Titan Takeoff (Pro) server. Default = 3000."""
mgmt_port: int = 3001
"""The management port of the Titan Takeoff (Pro) server. Default = 3001."""
streaming: bool = False
"""Whether to stream the output. Default = False."""
client: Any = None
"""Takeoff Client Python SDK used to interact with Takeoff API"""
def __init__(
self,
base_url: str = "http://localhost",
port: int = 3000,
mgmt_port: int = 3001,
streaming: bool = False,
models: List[ReaderConfig] = [],
):
"""Initialize the Titan Takeoff language wrapper.
Args:
base_url (str, optional): The base URL where the Takeoff
Inference Server is listening. Defaults to `http://localhost`.
port (int, optional): What port is Takeoff Inference API
listening on. Defaults to 3000.
mgmt_port (int, optional): What port is Takeoff Management API
listening on. Defaults to 3001.
streaming (bool, optional): Whether to use the generate_stream
endpoint instead of generate by default. Defaults to False. In
practice the result is not significantly different, as the streamed
response is buffered and returned much like the non-streamed
response, but the run manager callbacks are fired per generated
token.
models (List[ReaderConfig], optional): Any readers you'd like to
spin up. Defaults to [].
Raises:
ImportError: If you haven't installed takeoff-client, you will
get an ImportError. To remedy, run `pip install 'takeoff-client>=0.4.0'`.
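Example:
A minimal construction sketch; it assumes a Takeoff server is already
running locally on the default ports.
.. code-block:: python
llm = TitanTakeoff(base_url="http://localhost", port=3000, mgmt_port=3001)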
"""
super().__init__( # type: ignore[call-arg]
base_url=base_url, port=port, mgmt_port=mgmt_port, streaming=streaming
)
try:
from takeoff_client import TakeoffClient
except ImportError:
raise ImportError(
"takeoff-client is required for TitanTakeoff. "
"Please install it with `pip install 'takeoff-client>=0.4.0'`."
)
self.client = TakeoffClient(
self.base_url, port=self.port, mgmt_port=self.mgmt_port
)
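# Spin up one reader per supplied config via the Takeoff management API
# before the wrapper is used for generation.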
for model in models:
self.client.create_reader(model)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "titan_takeoff"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Titan Takeoff (Pro) generate endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
run_manager: Optional callback manager to use when streaming.
Returns:
The string generated by the model.
Example:
.. code-block:: python
model = TitanTakeoff()
prompt = "What is the capital of the United Kingdom?"
# Use of model(prompt), i.e. `__call__`, was deprecated in LangChain 0.1.7;
# use model.invoke(prompt) instead.
response = model.invoke(prompt)
"""
if self.streaming:
text_output = ""
for chunk in self._stream(
prompt=prompt,
stop=stop,
run_manager=run_manager,
**kwargs,
):
text_output += chunk.text
return text_output
response = self.client.generate(prompt, **kwargs)
text = response["text"]
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Call out to Titan Takeoff (Pro) stream endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
run_manager: Optional callback manager to use when streaming.
Yields:
GenerationChunk objects, each containing a string token.
Example:
.. code-block:: python
model = TitanTakeoff()
prompt = "What is the capital of the United Kingdom?"
response = model.stream(prompt)
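# A usage sketch: stream() yields string tokens as they arrive
for token in response:
    print(token, end="")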
# OR
model = TitanTakeoff(streaming=True)
response = model.invoke(prompt)
"""
response = self.client.generate_stream(prompt, **kwargs)
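# The stream is delivered as server-sent events; the loop below buffers
# the event payloads, strips any "data:" framing markers, and yields the
# remaining text as GenerationChunks.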
buffer = ""
for text in response:
buffer += text.data
if "data:" in buffer:
# Remove the first instance of "data:" from the buffer.
if buffer.startswith("data:"):
buffer = ""
if len(buffer.split("data:", 1)) == 2:
content, _ = buffer.split("data:", 1)
buffer = content.rstrip("\n")
# Trim the buffer to only have content after the "data:" part.
if buffer: # Ensure that there's content to process.
chunk = GenerationChunk(text=buffer)
buffer = "" # Reset buffer for the next set of data.
if run_manager:
run_manager.on_llm_new_token(token=chunk.text)
yield chunk
# Yield any remaining content in the buffer.
if buffer:
chunk = GenerationChunk(text=buffer.replace("</s>", ""))
if run_manager:
run_manager.on_llm_new_token(token=chunk.text)
yield chunk