# Source code for langchain_azure_ai.tools.ai_services.document_intelligence
"""Tool that queries the Azure AI Document Intelligence API."""
from __future__ import annotations
import logging
from typing import Annotated, Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import ArgsSchema, BaseTool
from pydantic import BaseModel, PrivateAttr, SkipValidation, model_validator
from langchain_azure_ai._resources import AIServicesService
from langchain_azure_ai.utils.utils import detect_file_src_type
# The Azure SDK is an optional dependency of this package; fail fast with an
# actionable installation hint when it is missing.
try:
    from azure.ai.documentintelligence import DocumentIntelligenceClient
    from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
    from azure.core.credentials import AzureKeyCredential
except ImportError as e:
    # Note the trailing space after "the": adjacent string literals are
    # concatenated verbatim, so without it the words run together.
    raise ImportError(
        "To use Azure AI Document Intelligence tool, please install the "
        "'azure-ai-documentintelligence' package: "
        "`pip install azure-ai-documentintelligence` or install the 'tools' "
        "extra: `pip install langchain-azure-ai[tools]`"
    ) from e

logger = logging.getLogger(__name__)
# [docs]
class DocumentInput(BaseModel):
    """Argument schema for the Azure AI Document Intelligence tool.

    The tool takes a single argument: the location of the document to analyze.
    """

    document_path: str
    """Local file path or remote URL of the document to analyze."""
# [docs]
class AzureAIDocumentIntelligenceTool(BaseTool, AIServicesService):
    """Tool that queries the Azure AI Document Intelligence API.

    The tool accepts a local file path or a remote URL (see ``DocumentInput``),
    sends the document to the configured Document Intelligence model, and
    returns a plain-text rendering of the extracted content, tables, and
    key-value pairs.
    """

    _client: DocumentIntelligenceClient = PrivateAttr()  # pyright: ignore[reportUndefinedVariable]

    name: str = "azure_ai_document_intelligence"
    """The name of the tool."""

    description: str = (
        "A tool that uses Azure AI Document Intelligence API to analyze "
        "documents. Useful for when you need to extract text, tables, and "
        "key-value pairs from documents. Input should be a url or path to "
        "a document."
    )
    """The description of the tool."""

    args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = DocumentInput
    """The input args schema for the tool."""

    model_id: str = "prebuilt-layout"
    """The model ID to use for document analysis. Defaults to the
    ``prebuilt-layout`` model."""

    @model_validator(mode="after")
    def initialize_client(self) -> AzureAIDocumentIntelligenceTool:
        """Initialize the Azure AI Document Intelligence client.

        ``endpoint``, ``credential`` and ``client_kwargs`` are read from
        inherited fields (they are not declared on this class). A plain string
        credential is wrapped in ``AzureKeyCredential``; any other credential
        object is passed through unchanged.

        Returns:
            This tool instance with ``_client`` populated.
        """
        credential = (
            AzureKeyCredential(self.credential)
            if isinstance(self.credential, str)
            else self.credential
        )
        self._client = DocumentIntelligenceClient(
            endpoint=self.endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
            **self.client_kwargs,  # type: ignore[arg-type]
        )
        return self

    def _parse_tables(self, tables: List[Any]) -> List[Any]:
        """Convert analyzed tables into nested lists of cell text.

        Each table becomes a ``row_count`` x ``column_count`` grid of strings.
        Grid positions with no cell stay as ``""``. Each cell's text is written
        only at its own ``row_index``/``column_index``.
        """
        result = []
        for table in tables:
            rc, cc = table.row_count, table.column_count
            grid = [["" for _ in range(cc)] for _ in range(rc)]
            for cell in table.cells:
                grid[cell.row_index][cell.column_index] = cell.content
            result.append(grid)
        return result

    def _parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Any]:
        """Convert analyzed key-value pairs into ``(key, value)`` text tuples.

        A missing key or value is represented as ``""``.
        """
        return [
            (
                kv_pair.key.content if kv_pair.key else "",
                kv_pair.value.content if kv_pair.value else "",
            )
            for kv_pair in kv_pairs
        ]

    def _document_analysis(self, document_path: str) -> Dict[str, Any]:
        """Analyze a document with the Document Intelligence client.

        Args:
            document_path: Local file path or remote URL of the document.

        Returns:
            A dict that may contain the keys ``"content"`` (str),
            ``"tables"`` (list of string grids), and ``"key_value_pairs"``
            (list of ``(key, value)`` tuples). A key is present only when the
            service returned a non-``None`` value for it.

        Raises:
            ValueError: If ``detect_file_src_type`` classifies
                ``document_path`` as neither ``"local"`` nor ``"remote"``.
        """
        document_src_type = detect_file_src_type(document_path)
        if document_src_type == "local":
            with open(document_path, "rb") as document:
                # ``bytes_source`` is typed as ``bytes``: send the document's
                # content, not the open file object.
                request = AnalyzeDocumentRequest(bytes_source=document.read())
        elif document_src_type == "remote":
            request = AnalyzeDocumentRequest(url_source=document_path)
        else:
            raise ValueError(f"Invalid document path: {document_path}")

        poller = self._client.begin_analyze_document(self.model_id, request)
        result = poller.result()

        res_dict: Dict[str, Any] = {}
        if result.content is not None:
            res_dict["content"] = result.content
        if result.tables is not None:
            res_dict["tables"] = self._parse_tables(result.tables)
        if result.key_value_pairs is not None:
            res_dict["key_value_pairs"] = self._parse_kv_pairs(
                result.key_value_pairs
            )
        return res_dict

    def _format_document_analysis_result(self, document_analysis_result: Dict) -> str:
        """Format the document analysis result into a readable string.

        Emits one line for the content, one line per table, and one line per
        key-value pair, in that order. Newlines inside each entry are replaced
        with spaces so that every entry occupies exactly one output line.
        """
        formatted_result = []
        if "content" in document_analysis_result:
            formatted_result.append(
                f"Content: {document_analysis_result['content']}".replace("\n", " ")
            )
        if "tables" in document_analysis_result:
            for i, table in enumerate(document_analysis_result["tables"]):
                formatted_result.append(f"Table {i}: {table}".replace("\n", " "))
        if "key_value_pairs" in document_analysis_result:
            for key, value in document_analysis_result["key_value_pairs"]:
                formatted_result.append(f"{key}: {value}".replace("\n", " "))
        return "\n".join(formatted_result)

    def _run(
        self,
        document_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Analyze ``document_path`` and return the formatted result.

        Args:
            document_path: Local file path or remote URL of the document.
            run_manager: Callback manager supplied by LangChain; not used here.

        Returns:
            The formatted analysis. If the analysis produced no content,
            tables, or key-value pairs, returns a fixed message instead.

        Raises:
            RuntimeError: Wraps any error raised during analysis or
                formatting. The original exception is chained as
                ``__cause__``.
        """
        try:
            document_analysis_result = self._document_analysis(document_path)
            if not document_analysis_result:
                return "No good document analysis result was found"
            return self._format_document_analysis_result(document_analysis_result)
        except Exception as e:
            raise RuntimeError(f"Error while running {self.name}: {e}") from e