import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_community.llms.utils import enforce_stop_tokens
[docs]
class TrainResult(TypedDict):
    """Outcome of one fine-tuning call.

    Attributes:
        loss: Loss reported for the submitted samples. The training methods
            in this module compute it as the response's ``sumLoss`` divided
            by its ``numberOfTrainableTokens``.
    """

    loss: float
[docs]
class GradientLLM(BaseLLM):
"""Gradient.ai LLM Endpoints.
GradientLLM is a class to interact with LLMs on gradient.ai.
To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
or alternatively provide them as keywords to the constructor of this class.
The endpoint URL is read from the ``gradient_api_url`` keyword or the
``GRADIENT_API_URL`` environment variable.
Example:
.. code-block:: python
from langchain_community.llms import GradientLLM
GradientLLM(
model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
model_kwargs={
"max_generated_token_count": 128,
"temperature": 0.75,
"top_p": 0.95,
"top_k": 20,
"stop": [],
},
gradient_workspace_id="12345614fc0_workspace",
gradient_access_token="gradientai-access_token",
)
"""
model_id: str = Field(alias="model", min_length=2)
"Underlying gradient.ai model id (base or fine-tuned)."
gradient_workspace_id: Optional[str] = None
"Underlying gradient.ai workspace_id."
gradient_access_token: Optional[str] = None
"""gradient.ai API Token, which can be generated by going to
https://auth.gradient.ai/select-workspace
and selecting "Access tokens" under the profile drop-down.
"""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
gradient_api_url: str = "https://api.gradient.ai/api"
"""Endpoint URL to use."""
aiosession: Optional[aiohttp.ClientSession] = None #: :meta private:
"""ClientSession, private, subject to change in upcoming releases."""
# Pydantic model configuration: fields may be populated by field name as well
# as by alias (``model_id`` or ``model``), and unknown fields are rejected.
model_config = ConfigDict(
populate_by_name=True,
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve credentials and the endpoint URL before field validation.

    Each setting is taken from the constructor arguments first and from the
    matching environment variable second.

    Args:
        values: Raw constructor arguments; updated in place.

    Returns:
        The same ``values`` mapping with ``gradient_access_token``,
        ``gradient_workspace_id`` and ``gradient_api_url`` filled in.

    Raises:
        ValueError: Raised by ``get_from_dict_or_env`` when the access token
            or the workspace id is found in neither place.
    """
    values["gradient_access_token"] = get_from_dict_or_env(
        values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
    )
    values["gradient_workspace_id"] = get_from_dict_or_env(
        values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
    )
    # A "before" validator sees the raw input, so the field's declared default
    # has not been applied yet. Pass it explicitly; otherwise the lookup fails
    # whenever the URL is neither passed in nor set in the environment.
    values["gradient_api_url"] = get_from_dict_or_env(
        values,
        "gradient_api_url",
        "GRADIENT_API_URL",
        default=cls.model_fields["gradient_api_url"].default,
    )
    return values
@model_validator(mode="after")
def post_init(self) -> Self:
    """Validate credentials and sampling parameters after construction.

    Also logs a notice when the optional ``gradientai`` package cannot be
    imported.

    Returns:
        The validated instance.

    Raises:
        ValueError: If the access token is missing or shorter than 10
            characters, if the workspace id is missing or shorter than 3
            characters, or if ``temperature``, ``top_p``, ``top_k`` or
            ``max_generated_token_count`` in ``model_kwargs`` is out of range.
    """
    # The import is only probed so the notice below can be emitted.
    try:
        import gradientai  # noqa
    except ImportError:
        logging.warning(
            "DeprecationWarning: `GradientLLM` will use "
            "`pip install gradientai` in future releases of langchain."
        )
    except Exception:
        # Best effort: any other failure while importing the optional
        # package must not prevent the model from being constructed.
        pass

    # These checks could be moved to a dedicated post-init validation step.
    if self.gradient_access_token is None or len(self.gradient_access_token) < 10:
        raise ValueError("env variable `GRADIENT_ACCESS_TOKEN` must be set")

    # Measure the workspace id itself (the original checked the token length).
    if self.gradient_workspace_id is None or len(self.gradient_workspace_id) < 3:
        raise ValueError("env variable `GRADIENT_WORKSPACE_ID` must be set")

    if self.model_kwargs:
        kw = self.model_kwargs
        # Absent keys fall back to in-range placeholders so only values that
        # were explicitly supplied can fail validation.
        if not 0 <= kw.get("temperature", 0.5) <= 1:
            raise ValueError("`temperature` must be in the range [0.0, 1.0]")
        if not 0 <= kw.get("top_p", 0.5) <= 1:
            raise ValueError("`top_p` must be in the range [0.0, 1.0]")
        if 0 >= kw.get("top_k", 0.5):
            raise ValueError("`top_k` must be positive")
        if 0 >= kw.get("max_generated_token_count", 1):
            raise ValueError("`max_generated_token_count` must be positive")
    return self
@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Return the parameters that identify this LLM configuration.

    Returns:
        A mapping with the endpoint URL under ``gradient_api_url`` and the
        configured model keyword arguments (an empty dict when none are set)
        under ``model_kwargs``.
    """
    return {
        "gradient_api_url": self.gradient_api_url,
        "model_kwargs": self.model_kwargs or {},
    }
@property
def _llm_type(self) -> str:
    """Return the identifier string for this LLM implementation."""
    llm_type = "gradient"
    return llm_type
def _kwargs_post_fine_tune_request(
    self, inputs: Sequence[str], kwargs: Mapping[str, Any]
) -> Mapping[str, Any]:
    """Assemble the keyword arguments for a fine-tune POST request.

    Args:
        inputs: Training texts; each becomes one sample in the payload.
        kwargs: Per-call overrides merged over ``self.model_kwargs``. Only
            the ``multipliers`` entry is read here.

    Returns:
        A mapping with ``url``, ``headers`` and ``json`` entries, suitable
        for unpacking into a ``post`` call.
    """
    merged = {**(self.model_kwargs or {}), **kwargs}
    multipliers = merged.get("multipliers")

    if multipliers is None:
        samples = tuple({"inputs": text} for text in inputs)
    else:
        # zip() stops at the shorter of the two, so surplus inputs or
        # multipliers are dropped without an error.
        samples = tuple(
            {
                "inputs": text,
                "fineTuningParameters": {"multiplier": weight},
            }
            for text, weight in zip(inputs, multipliers)
        )

    request_headers = {
        "authorization": f"Bearer {self.gradient_access_token}",
        "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
        "accept": "application/json",
        "content-type": "application/json",
    }
    return {
        "url": f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
        "headers": request_headers,
        "json": {"samples": samples},
    }
def _kwargs_post_request(
    self, prompt: str, kwargs: Mapping[str, Any]
) -> Mapping[str, Any]:
    """Assemble the keyword arguments for a completion POST request.

    Args:
        prompt: Text sent to the model as the ``query``.
        kwargs: Per-call overrides merged over ``self.model_kwargs``.

    Returns:
        A mapping with ``url``, ``headers`` and ``json`` entries, suitable
        for unpacking into a ``post`` call. Sampling settings that are not
        configured are sent as ``None``.
    """
    merged = {**(self.model_kwargs or {}), **kwargs}
    payload = {
        "query": prompt,
        "maxGeneratedTokenCount": merged.get("max_generated_token_count"),
        "temperature": merged.get("temperature"),
        "topK": merged.get("top_k"),
        "topP": merged.get("top_p"),
    }
    request_headers = {
        "authorization": f"Bearer {self.gradient_access_token}",
        "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
        "accept": "application/json",
        "content-type": "application/json",
    }
    return {
        "url": f"{self.gradient_api_url}/models/{self.model_id}/complete",
        "headers": request_headers,
        "json": payload,
    }
def _call(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Call to Gradients API `model/{id}/complete`.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words; applied to the returned text
            on the client side.
        run_manager: Accepted for interface compatibility; not used here.
        **kwargs: Per-call overrides merged over ``model_kwargs``.

    Returns:
        The string generated by the model.

    Raises:
        Exception: If the request fails at the transport level or the
            endpoint answers with a status other than 200.
    """
    # Only the network call can raise a RequestException, so keep the try
    # body to that single statement.
    try:
        response = requests.post(**self._kwargs_post_request(prompt, kwargs))
    except requests.exceptions.RequestException as e:
        # Chain explicitly so the transport error is kept as the cause.
        raise Exception(
            f"RequestException while calling Gradient Endpoint: {e}"
        ) from e

    if response.status_code != 200:
        raise Exception(
            f"Gradient returned an unexpected response with status "
            f"{response.status_code}: {response.text}"
        )

    text = response.json()["generatedOutput"]
    if stop is not None:
        # Apply stop tokens when making calls to Gradient
        text = enforce_stop_tokens(text, stop)
    return text
async def _acall(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Async Call to Gradients API `model/{id}/complete`.

    Uses ``self.aiosession`` when one is configured; otherwise opens a
    temporary ``aiohttp.ClientSession`` for this single request.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words; applied to the returned text
            on the client side.
        run_manager: Accepted for interface compatibility; not used here.
        **kwargs: Per-call overrides merged over ``model_kwargs``.

    Returns:
        The string generated by the model.

    Raises:
        Exception: If the endpoint answers with a status other than 200.
    """
    request_kwargs = self._kwargs_post_request(prompt=prompt, kwargs=kwargs)

    async def _post(session: aiohttp.ClientSession) -> str:
        # Send the request on the given session and extract the completion.
        async with session.post(**request_kwargs) as response:
            if response.status != 200:
                # aiohttp's ``text`` is a coroutine method: it has to be
                # awaited to get the body rather than a bound-method repr.
                body = await response.text()
                raise Exception(
                    f"Gradient returned an unexpected response with status "
                    f"{response.status}: {body}"
                )
            return (await response.json())["generatedOutput"]

    if self.aiosession:
        text = await _post(self.aiosession)
    else:
        async with aiohttp.ClientSession() as session:
            text = await _post(session)

    if stop is not None:
        # Apply stop tokens when making calls to Gradient
        text = enforce_stop_tokens(text, stop)
    return text
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Generate one completion per prompt.

    A single prompt (or none) is handled inline; several prompts are
    dispatched to a thread pool of at most 8 workers.

    Args:
        prompts: Prompts to complete.
        stop: Optional stop words forwarded to ``_call``.
        run_manager: Optional callback manager forwarded to ``_call``.
        **kwargs: Extra keyword arguments forwarded to ``_call``.

    Returns:
        An ``LLMResult`` holding one single-element generation list per
        prompt, in prompt order.
    """

    def _complete_one(prompt: str) -> List[Generation]:
        completion = self._call(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )
        return [Generation(text=completion)]

    if len(prompts) > 1:
        with ThreadPoolExecutor(min(8, len(prompts))) as pool:
            generations = list(pool.map(_complete_one, prompts))
    else:
        generations = [_complete_one(prompt) for prompt in prompts]
    return LLMResult(generations=generations)
async def _agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Generate one completion per prompt, awaiting all requests together.

    Args:
        prompts: Prompts to complete.
        stop: Optional stop words forwarded to ``_acall``.
        run_manager: Optional callback manager forwarded to ``_acall``.
        **kwargs: Extra keyword arguments forwarded to ``_acall``.

    Returns:
        An ``LLMResult`` holding one single-element generation list per
        prompt, in prompt order.
    """
    pending = [
        self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
        for prompt in prompts
    ]
    completions = await asyncio.gather(*pending)
    return LLMResult(
        generations=[[Generation(text=completion)] for completion in completions]
    )
[docs]
def train_unsupervised(
    self,
    inputs: Sequence[str],
    **kwargs: Any,
) -> TrainResult:
    """Fine-tune the model on raw text samples.

    Args:
        inputs: Training texts, sent as one sample each.
        **kwargs: Per-call overrides merged over ``model_kwargs``; a
            ``multipliers`` entry, if present, is paired with ``inputs``.

    Returns:
        A ``TrainResult`` whose ``loss`` is the response's ``sumLoss``
        divided by its ``numberOfTrainableTokens``.

    Raises:
        Exception: If the request fails at the transport level or the
            endpoint answers with a status other than 200.
    """
    # Only the network call can raise a RequestException, so keep the try
    # body to that single statement.
    try:
        response = requests.post(
            **self._kwargs_post_fine_tune_request(inputs, kwargs)
        )
    except requests.exceptions.RequestException as e:
        # Chain explicitly so the transport error is kept as the cause.
        raise Exception(
            f"RequestException while calling Gradient Endpoint: {e}"
        ) from e

    if response.status_code != 200:
        raise Exception(
            f"Gradient returned an unexpected response with status "
            f"{response.status_code}: {response.text}"
        )

    response_json = response.json()
    loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
    return TrainResult(loss=loss)
[docs]
async def atrain_unsupervised(
    self,
    inputs: Sequence[str],
    **kwargs: Any,
) -> TrainResult:
    """Fine-tune the model on raw text samples (async).

    Uses ``self.aiosession`` when one is configured; otherwise opens a
    temporary ``aiohttp.ClientSession`` for this single request.

    Args:
        inputs: Training texts, sent as one sample each.
        **kwargs: Per-call overrides merged over ``model_kwargs``; a
            ``multipliers`` entry, if present, is paired with ``inputs``.

    Returns:
        A ``TrainResult`` whose ``loss`` is the response's ``sumLoss``
        divided by its ``numberOfTrainableTokens``.

    Raises:
        Exception: If the endpoint answers with a status other than 200.
    """
    request_kwargs = self._kwargs_post_fine_tune_request(inputs, kwargs)

    async def _post(session: aiohttp.ClientSession) -> float:
        # Send the request on the given session and compute the mean loss.
        async with session.post(**request_kwargs) as response:
            if response.status != 200:
                # aiohttp's ``text`` is a coroutine method: it has to be
                # awaited to get the body rather than a bound-method repr.
                body = await response.text()
                raise Exception(
                    f"Gradient returned an unexpected response with status "
                    f"{response.status}: {body}"
                )
            response_json = await response.json()
            return response_json["sumLoss"] / response_json["numberOfTrainableTokens"]

    if self.aiosession:
        loss = await _post(self.aiosession)
    else:
        async with aiohttp.ClientSession() as session:
            loss = await _post(session)
    return TrainResult(loss=loss)