# Source code for langchain_community.tools.edenai.edenai_base_tool
from __future__ import annotations
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from langchain_core.utils import secret_from_env
from pydantic import Field, SecretStr
logger = logging.getLogger(__name__)
# [docs]
class EdenaiTool(BaseTool):  # type: ignore[override]
    """Base class shared by all the Eden AI tools.

    Subclasses choose the endpoint through ``feature`` / ``subfeature`` and
    implement ``_run`` and ``_parse_response``.

    You should have the environment variable ``EDENAI_API_KEY`` set with your
    API token (or pass ``edenai_api_key`` explicitly).
    You can find your token here: https://app.edenai.run/admin/account/settings
    """

    feature: str
    subfeature: str
    edenai_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("EDENAI_API_KEY", default=None)
    )
    is_async: bool = False

    providers: List[str]
    """provider to use for the API call."""

    @staticmethod
    def get_user_agent() -> str:
        """Return the ``User-Agent`` header value sent with every request."""
        # Imported lazily, at call time rather than at module import time.
        from langchain_community import __version__

        return f"langchain/{__version__}"

    def _api_key_value(self) -> str:
        """Return the plain-text API key, or ``""`` when no key is configured.

        ``edenai_api_key`` is a ``SecretStr``; formatting it directly yields a
        masked placeholder, so it must be unwrapped before being placed in a
        request header.
        """
        if self.edenai_api_key:
            return self.edenai_api_key.get_secret_value()
        return ""

    def _call_eden_ai(self, query_params: Dict[str, Any]) -> str:
        """POST a request to the configured Eden AI endpoint and parse the reply.

        Args:
            query_params: Fields merged over the default payload; keys given
                here override the defaults.

        Returns:
            The string produced by ``_parse_response`` from the decoded JSON
            body of the response.

        Raises:
            ValueError: On a 4xx status, or when the provider reports a
                failure (see ``_raise_on_error``).
            RuntimeError: If decoding or parsing the response body fails.
            Exception: On a 5xx or any other non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self._api_key_value()}",
            "User-Agent": self.get_user_agent(),
        }
        url = f"https://api.edenai.run/v2/{self.feature}/{self.subfeature}"

        payload: Dict[str, Any] = {
            "providers": str(self.providers),
            "response_as_dict": False,
            "attributes_as_list": True,
            "show_original_response": False,
        }
        payload.update(query_params)

        response = requests.post(url, json=payload, headers=headers)
        self._raise_on_error(response)

        try:
            return self._parse_response(response.json())
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"An error occurred while running tool: {e}") from e

    def _raise_on_error(self, response: requests.Response) -> None:
        """Raise if the HTTP call or the underlying provider reported a failure.

        Args:
            response: The response returned by ``requests``.

        Raises:
            Exception: On a 5xx status, or any other status that is not 200.
            ValueError: On a 4xx status, or when a synchronous call succeeded
                at the HTTP level but the provider entry has status ``"fail"``.
        """
        if response.status_code >= 500:
            raise Exception(f"EdenAI Server: Error {response.status_code}")
        elif response.status_code >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {response.text}")
        elif response.status_code != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )

        # case where edenai call succeeded but provider returned an error
        # (eg: rate limit, server error, etc.)
        if self.is_async is False:
            # async call are different and only return a job_id,
            # not the provider response directly
            provider_response = response.json()[0]
            if provider_response.get("status") == "fail":
                err_msg = provider_response["error"]["message"]
                raise ValueError(err_msg)

    @abstractmethod
    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Run the tool on ``query``; implemented by each concrete tool."""
        pass

    @abstractmethod
    def _parse_response(self, response: Any) -> str:
        """Take a dict response and condense its data in a human readable string"""
        pass

    def _get_edenai(self, url: str) -> requests.Response:
        """Send an authenticated GET request to ``url``.

        Args:
            url: The full URL to fetch.

        Returns:
            The response, after it has passed ``_raise_on_error``.
        """
        headers = {
            "accept": "application/json",
            # Unwrap the SecretStr: formatting the field itself would send the
            # masked placeholder instead of the real token.
            "authorization": f"Bearer {self._api_key_value()}",
            "User-Agent": self.get_user_agent(),
        }
        response = requests.get(url, headers=headers)
        self._raise_on_error(response)
        return response

    def _parse_json_multilevel(
        self, extracted_data: dict, formatted_list: list, level: int = 0
    ) -> None:
        """Flatten a nested dict into indented ``key : value`` lines.

        Lines are appended to ``formatted_list`` in place. String values are
        emitted on the same line as their key, with newlines replaced by
        commas. Lists and dicts emit a ``key : `` header and are expanded one
        level deeper. Values of any other type are skipped.

        Args:
            extracted_data: The mapping to flatten.
            formatted_list: Output list that receives the formatted lines.
            level: Current nesting depth; one space of indent per level.
        """
        for section, subsections in extracted_data.items():
            indentation = " " * level
            if isinstance(subsections, str):
                subsections = subsections.replace("\n", ",")
                formatted_list.append(f"{indentation}{section} : {subsections}")
            elif isinstance(subsections, list):
                formatted_list.append(f"{indentation}{section} : ")
                self._list_handling(subsections, formatted_list, level + 1)
            elif isinstance(subsections, dict):
                formatted_list.append(f"{indentation}{section} : ")
                self._parse_json_multilevel(subsections, formatted_list, level + 1)

    def _list_handling(
        self, subsection_list: list, formatted_list: list, level: int
    ) -> None:
        """Append the formatted items of ``subsection_list`` to ``formatted_list``.

        Dict items are flattened at the current level, nested lists are
        expanded one level deeper, and any other item is emitted on its own
        indented line.

        Args:
            subsection_list: The list to expand.
            formatted_list: Output list that receives the formatted lines.
            level: Current nesting depth; one space of indent per level.
        """
        for list_item in subsection_list:
            if isinstance(list_item, dict):
                self._parse_json_multilevel(list_item, formatted_list, level)
            elif isinstance(list_item, list):
                self._list_handling(list_item, formatted_list, level + 1)
            else:
                formatted_list.append(f"{' ' * level}{list_item}")