Source code for langchain_community.tools.powerbi.tool

"""Tools for interacting with a Power BI dataset."""

import logging
from time import perf_counter
from typing import Any, Dict, Optional, Tuple

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import ConfigDict, Field, model_validator

from langchain_community.chat_models.openai import _import_tiktoken
from langchain_community.tools.powerbi.prompt import (
    BAD_REQUEST_RESPONSE,
    DEFAULT_FEWSHOT_EXAMPLES,
    RETRY_RESPONSE,
)
from langchain_community.utilities.powerbi import PowerBIDataset, json_to_md

logger = logging.getLogger(__name__)


[docs] class QueryPowerBITool(BaseTool): # type: ignore[override] """Tool for querying a Power BI Dataset.""" name: str = "query_powerbi" description: str = """ Input to this tool is a detailed question about the dataset, output is a result from the dataset. It will try to answer the question using the dataset, and if it cannot, it will ask for clarification. Example Input: "How many rows are in table1?" """ # noqa: E501 llm_chain: Any = None powerbi: PowerBIDataset = Field(exclude=True) examples: Optional[str] = DEFAULT_FEWSHOT_EXAMPLES session_cache: Dict[str, Any] = Field(default_factory=dict, exclude=True) max_iterations: int = 5 output_token_limit: int = 4000 tiktoken_model_name: Optional[str] = None # "cl100k_base" model_config = ConfigDict( arbitrary_types_allowed=True, ) @model_validator(mode="before") @classmethod def validate_llm_chain_input_variables( # pylint: disable=E0213 cls, values: dict ) -> dict: """Make sure the LLM chain has the correct input variables.""" llm_chain = values["llm_chain"] for var in llm_chain.prompt.input_variables: if var not in ["tool_input", "tables", "schemas", "examples"]: raise ValueError( "LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301 llm_chain.prompt.input_variables, ) return values def _check_cache(self, tool_input: str) -> Optional[str]: """Check if the input is present in the cache. If the value is a bad request, overwrite with the escalated version, if not present return None.""" if tool_input not in self.session_cache: return None return self.session_cache[tool_input] def _run( self, tool_input: str, run_manager: Optional[CallbackManagerForToolRun] = None, **kwargs: Any, ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): logger.debug("Found cached result for %s: %s", tool_input, cache) return cache try: logger.info("Running PBI Query Tool with input: %s", tool_input) query = self.llm_chain.predict( tool_input=tool_input, tables=self.powerbi.get_table_names(), schemas=self.powerbi.get_schemas(), examples=self.examples, callbacks=run_manager.get_child() if run_manager else None, ) except Exception as exc: # pylint: disable=broad-except self.session_cache[tool_input] = f"Error on call to LLM: {exc}" return self.session_cache[tool_input] if query == "I cannot answer this": self.session_cache[tool_input] = query return self.session_cache[tool_input] logger.info("PBI Query:\n%s", query) start_time = perf_counter() pbi_result = self.powerbi.run(command=query) end_time = perf_counter() logger.debug("PBI Result: %s", pbi_result) logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}") result, error = self._parse_output(pbi_result) if error is not None and "TokenExpired" in error: self.session_cache[tool_input] = ( "Authentication token expired or invalid, please try reauthenticate." ) return self.session_cache[tool_input] iterations = kwargs.get("iterations", 0) if error and iterations < self.max_iterations: return self._run( tool_input=RETRY_RESPONSE.format( tool_input=tool_input, query=query, error=error ), run_manager=run_manager, iterations=iterations + 1, ) self.session_cache[tool_input] = ( result if result else BAD_REQUEST_RESPONSE.format(error=error) ) return self.session_cache[tool_input] async def _arun( self, tool_input: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, **kwargs: Any, ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): logger.debug("Found cached result for %s: %s", tool_input, cache) return f"{cache}, from cache, you have already asked this question." try: logger.info("Running PBI Query Tool with input: %s", tool_input) query = await self.llm_chain.apredict( tool_input=tool_input, tables=self.powerbi.get_table_names(), schemas=self.powerbi.get_schemas(), examples=self.examples, callbacks=run_manager.get_child() if run_manager else None, ) except Exception as exc: # pylint: disable=broad-except self.session_cache[tool_input] = f"Error on call to LLM: {exc}" return self.session_cache[tool_input] if query == "I cannot answer this": self.session_cache[tool_input] = query return self.session_cache[tool_input] logger.info("PBI Query: %s", query) start_time = perf_counter() pbi_result = await self.powerbi.arun(command=query) end_time = perf_counter() logger.debug("PBI Result: %s", pbi_result) logger.debug(f"PBI Query duration: {end_time - start_time:0.6f}") result, error = self._parse_output(pbi_result) if error is not None and ("TokenExpired" in error or "TokenError" in error): self.session_cache[tool_input] = ( "Authentication token expired or invalid, please try to reauthenticate or check the scope of the credential." # noqa: E501 ) return self.session_cache[tool_input] iterations = kwargs.get("iterations", 0) if error and iterations < self.max_iterations: return await self._arun( tool_input=RETRY_RESPONSE.format( tool_input=tool_input, query=query, error=error ), run_manager=run_manager, iterations=iterations + 1, ) self.session_cache[tool_input] = ( result if result else BAD_REQUEST_RESPONSE.format(error=error) ) return self.session_cache[tool_input] def _parse_output( self, pbi_result: Dict[str, Any] ) -> Tuple[Optional[str], Optional[Any]]: """Parse the output of the query to a markdown table.""" if "results" in pbi_result: rows = pbi_result["results"][0]["tables"][0]["rows"] if len(rows) == 0: logger.info("0 records in result, query was valid.") return ( None, "0 rows returned, this might be correct, but please validate if all filter values were correct?", # noqa: E501 ) result = json_to_md(rows) too_long, length = self._result_too_large(result) if too_long: return ( f"Result too large, please try to be more specific or use the `TOPN` function. The result is {length} tokens long, the limit is {self.output_token_limit} tokens.", # noqa: E501 None, ) return result, None if "error" in pbi_result: if ( "pbi.error" in pbi_result["error"] and "details" in pbi_result["error"]["pbi.error"] ): return None, pbi_result["error"]["pbi.error"]["details"][0]["detail"] return None, pbi_result["error"] return None, pbi_result def _result_too_large(self, result: str) -> Tuple[bool, int]: """Tokenize the output of the query.""" if self.tiktoken_model_name: tiktoken_ = _import_tiktoken() encoding = tiktoken_.encoding_for_model(self.tiktoken_model_name) length = len(encoding.encode(result)) logger.info("Result length: %s", length) return length > self.output_token_limit, length return False, 0
[docs] class InfoPowerBITool(BaseTool): # type: ignore[override] """Tool for getting metadata about a PowerBI Dataset.""" name: str = "schema_powerbi" description: str = """ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling list_tables_powerbi first! Example Input: "table1, table2, table3" """ # noqa: E501 powerbi: PowerBIDataset = Field(exclude=True) model_config = ConfigDict( arbitrary_types_allowed=True, ) def _run( self, tool_input: str, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Get the schema for tables in a comma-separated list.""" return self.powerbi.get_table_info(tool_input.split(", ")) async def _arun( self, tool_input: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: return await self.powerbi.aget_table_info(tool_input.split(", "))
[docs] class ListPowerBITool(BaseTool): # type: ignore[override] """Tool for getting tables names.""" name: str = "list_tables_powerbi" description: str = "Input is an empty string, output is a comma separated list of tables in the database." # noqa: E501 # pylint: disable=C0301 powerbi: PowerBIDataset = Field(exclude=True) model_config = ConfigDict( arbitrary_types_allowed=True, ) def _run( self, tool_input: Optional[str] = None, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names()) async def _arun( self, tool_input: Optional[str] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names())