Source code for langchain_azure_ai.tools.ai_services.image_analysis
"""Tool that queries the Azure AI Services Image Analysis API."""
from __future__ import annotations

import json
import logging
from typing import Annotated, Any, Dict, List, Optional

from azure.core.exceptions import HttpResponseError
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import ArgsSchema, BaseTool
from langchain_core.utils import pre_init
from pydantic import BaseModel, PrivateAttr, SkipValidation, model_validator

from langchain_azure_ai._resources import AIServicesService
from langchain_azure_ai.utils.utils import detect_file_src_type
# The Image Analysis SDK is an optional dependency. Fail at import time with
# actionable install instructions instead of a bare "No module named ..." error.
try:
    from azure.ai.vision.imageanalysis import ImageAnalysisClient
    from azure.ai.vision.imageanalysis.models import VisualFeatures
except ImportError as e:
    # Each literal ends with a space so the concatenated message reads
    # correctly; chaining with ``from e`` keeps the underlying import failure
    # visible in the traceback.
    raise ImportError(
        "To use Azure AI Image Analysis tool, please install the "
        "'azure-ai-vision-imageanalysis' package: "
        "`pip install azure-ai-vision-imageanalysis` or install the 'tools' "
        "extra: `pip install langchain-azure-ai[tools]`"
    ) from e

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
[docs]
class ImageInput(BaseModel):
    """Argument schema accepted by the Azure AI Image Analysis tool."""

    image_path: str
    """Location of the image to analyze, given as a path or a URL."""
[docs]
class AzureAIImageAnalysisTool(BaseTool, AIServicesService):
    """Tool that queries the Azure AI Services Image Analysis API.

    The image is read from disk when ``detect_file_src_type`` reports it as
    ``"local"`` and analyzed by URL when it is reported as ``"remote"``. The
    selected parts of the service response are returned as a JSON string.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/cognitive-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40
    """

    _client: ImageAnalysisClient = PrivateAttr()

    name: str = "azure_ai_image_analysis"

    description: str = (
        "A wrapper around Azure AI Services Image Analysis. "
        "Useful for when you need to analyze images. "
        "Input should be a url to an image."
    )

    args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = ImageInput
    """The input args schema for the tool."""

    # A list of features: ``validate_environment`` both stores a list here and
    # iterates over user-supplied values, so a single enum member never fit.
    visual_features: Optional[List[VisualFeatures]] = None
    """The visual features to request. ``None`` selects a default set."""

    @pre_init
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the environment is set up correctly.

        Runs the parent class validation first, then either fills in the
        default set of visual features (when none were supplied) or checks
        every supplied feature against the values of ``VisualFeatures``.

        Args:
            values: The raw field values for the model.

        Returns:
            The field values, with ``visual_features`` populated.

        Raises:
            ValueError: If a supplied feature matches no ``VisualFeatures``
                value.
        """
        values = super().validate_environment(values)

        if values["visual_features"] is None:
            values["visual_features"] = [
                VisualFeatures.TAGS,
                VisualFeatures.OBJECTS,
                VisualFeatures.CAPTION,
                VisualFeatures.DENSE_CAPTIONS,
                VisualFeatures.READ,
                VisualFeatures.SMART_CROPS,
                VisualFeatures.PEOPLE,
            ]
        else:
            for feature in values["visual_features"]:
                if not any(item.value == feature for item in VisualFeatures):
                    raise ValueError(
                        f"Invalid visual feature: {feature}. "
                        f"Valid features are: {[f.value for f in VisualFeatures]}"
                    )

        return values

    @model_validator(mode="after")
    def initialize_client(self) -> AzureAIImageAnalysisTool:
        """Initialize the Azure AI Image Analysis client.

        A string credential is wrapped in ``AzureKeyCredential``; any other
        credential object is passed to the client unchanged.

        Returns:
            This instance, with ``_client`` set.
        """
        from azure.ai.vision.imageanalysis import ImageAnalysisClient
        from azure.core.credentials import AzureKeyCredential

        credential = (
            AzureKeyCredential(self.credential)
            if isinstance(self.credential, str)
            else self.credential
        )

        self._client = ImageAnalysisClient(
            endpoint=self.endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
            **self.client_kwargs,
        )
        return self

    def _image_analysis(self, image_path: str) -> Dict:
        """Send the image to the service and return the raw result as a dict.

        Args:
            image_path: The path or URL of the image to analyze.

        Returns:
            ``result.as_dict()`` on success. If the service raises
            ``HttpResponseError``, a dict with the keys ``status_code``,
            ``error_code``, ``error_message`` and ``error_details`` instead.

        Raises:
            ValueError: If the detected source type is neither ``"local"`` nor
                ``"remote"``.
        """
        image_src_type = detect_file_src_type(image_path)
        # Library code must not write to stdout; route diagnostics to the logger.
        logger.debug("Image source type detected: %s", image_src_type)

        try:
            if image_src_type == "local":
                with open(image_path, "rb") as f:
                    image_data = f.read()
                result = self._client.analyze(
                    image_data=image_data,
                    visual_features=self.visual_features,  # type: ignore[arg-type]
                )
            elif image_src_type == "remote":
                result = self._client.analyze_from_url(
                    image_url=image_path,
                    visual_features=self.visual_features,  # type: ignore[arg-type]
                )
            else:
                raise ValueError(f"Invalid image path: {image_path}")
        except HttpResponseError as e:
            # Leave a trace of the failure: the dict returned below carries none
            # of the "*Result" keys, so the formatter renders it as "{}".
            logger.warning(
                "Image analysis request failed with status %s", e.status_code
            )
            return {
                "status_code": e.status_code,
                "error_code": e.error.code if e.error else None,
                "error_message": e.error.message if e.error else None,
                "error_details": e.error.details if e.error else None,
            }

        return result.as_dict()

    def _format_image_analysis_result(self, results: Dict) -> str:
        """Render the selected parts of an analysis result as indented JSON.

        Only the result sections present in ``results`` are emitted, under the
        keys ``tags``, ``objects``, ``read``, ``people``, ``smartCrops`` and
        ``captions``.

        Args:
            results: The dict produced by ``_image_analysis``.

        Returns:
            A JSON string (``indent=2``) with the extracted sections.
        """
        # NOTE(review): DENSE_CAPTIONS is requested by default but no
        # "denseCaptionsResult" section is emitted here - confirm whether that
        # omission is intentional.
        output: Dict[str, Any] = {}

        if "tagsResult" in results:
            output["tags"] = results["tagsResult"]["values"]

        if "objectsResult" in results:
            output["objects"] = results["objectsResult"]["values"]

        if "readResult" in results:
            # One comma-joined string per text block, built from its lines.
            output["read"] = [
                ", ".join(line["text"] for line in block["lines"])
                for block in results["readResult"]["blocks"]
            ]

        if "peopleResult" in results:
            output["people"] = results["peopleResult"]["values"]

        if "smartCropsResult" in results:
            output["smartCrops"] = results["smartCropsResult"]["values"]

        if "captionResult" in results:
            caption = results["captionResult"]
            # NOTE(review): the SDK's caption payload appears to be a single
            # object without a "captions" key - confirm against the SDK model.
            # Fall back to the payload itself rather than raising KeyError.
            output["captions"] = caption.get("captions", caption)

        return json.dumps(output, indent=2)

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Args:
            image_path: The path or URL of the image to analyze.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The formatted analysis result, or a fixed message when the
            analysis result is empty.

        Raises:
            RuntimeError: If analysis or formatting raises any exception; the
                original exception is attached as the cause.
        """
        try:
            image_analysis_result = self._image_analysis(image_path)
            if not image_analysis_result:
                return "No good image analysis result was found"

            return self._format_image_analysis_result(image_analysis_result)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Error while running {self.name}: {e}") from e