Source code for langchain_core.structured_query

"""Internal representation of a structured query language."""

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union

from langchain_core.pydantic_v1 import BaseModel


class Visitor(ABC):
    """Defines interface for IR translation using a visitor pattern.

    Concrete visitors implement one ``visit_*`` method for each expression
    type (operation, comparison, structured query). The two class attributes
    below optionally restrict which operators and comparators
    ``_validate_func`` accepts; leaving one as ``None`` disables that check.
    """

    allowed_comparators: Optional[Sequence[Comparator]] = None
    """Allowed comparators for the visitor."""
    allowed_operators: Optional[Sequence[Operator]] = None
    """Allowed operators for the visitor."""

    def _validate_func(self, func: Union[Operator, Comparator]) -> None:
        """Check that ``func`` is permitted by this visitor.

        Args:
            func: Operator or Comparator to validate.

        Raises:
            ValueError: If ``func`` is an Operator missing from
                ``allowed_operators``, or a Comparator missing from
                ``allowed_comparators``. A ``None`` allow-list skips the
                corresponding check.
        """
        if isinstance(func, Operator) and self.allowed_operators is not None:
            if func not in self.allowed_operators:
                # The message must name *operators* here; it previously said
                # "comparators" while listing the allowed operators.
                raise ValueError(
                    f"Received disallowed operator {func}. Allowed "
                    f"operators are {self.allowed_operators}"
                )
        if isinstance(func, Comparator) and self.allowed_comparators is not None:
            if func not in self.allowed_comparators:
                raise ValueError(
                    f"Received disallowed comparator {func}. Allowed "
                    f"comparators are {self.allowed_comparators}"
                )

    @abstractmethod
    def visit_operation(self, operation: Operation) -> Any:
        """Translate an Operation.

        Args:
            operation: Operation to translate.
        """

    @abstractmethod
    def visit_comparison(self, comparison: Comparison) -> Any:
        """Translate a Comparison.

        Args:
            comparison: Comparison to translate.
        """

    @abstractmethod
    def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
        """Translate a StructuredQuery.

        Args:
            structured_query: StructuredQuery to translate.
        """
def _to_snake_case(name: str) -> str: """Convert a name into snake_case.""" snake_case = "" for i, char in enumerate(name): if char.isupper() and i != 0: snake_case += "_" + char.lower() else: snake_case += char.lower() return snake_case
class Expr(BaseModel):
    """Base class for all expressions."""

    def accept(self, visitor: Visitor) -> Any:
        """Accept a visitor.

        The visitor method is looked up by name: ``visit_`` followed by the
        snake_case form of this instance's class name.

        Args:
            visitor: visitor to accept.

        Returns:
            result of visiting.
        """
        class_name = self.__class__.__name__
        method_name = f"visit_{_to_snake_case(class_name)}"
        handler = getattr(visitor, method_name)
        return handler(self)
class Operator(str, Enum):
    """Enumerator of the operations.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    AND = "and"
    OR = "or"
    NOT = "not"
class Comparator(str, Enum):
    """Enumerator of the comparison operators.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    EQ = "eq"
    NE = "ne"
    GT = "gt"
    GTE = "gte"
    LT = "lt"
    LTE = "lte"
    CONTAIN = "contain"
    LIKE = "like"
    IN = "in"
    NIN = "nin"
class FilterDirective(Expr, ABC):
    """Filtering expression.

    Common base for the expression types that can appear in a filter.
    """
class Comparison(FilterDirective):
    """Comparison to a value.

    Parameters:
        comparator: The comparator to use.
        attribute: The attribute to compare.
        value: The value to compare to.
    """

    comparator: Comparator
    attribute: str
    value: Any

    def __init__(
        self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
    ) -> None:
        """Build the comparison; the three fields may be passed positionally.

        Any extra keyword arguments are forwarded to the parent initializer
        unchanged.
        """
        field_values = {
            "comparator": comparator,
            "attribute": attribute,
            "value": value,
        }
        super().__init__(**field_values, **kwargs)
class Operation(FilterDirective):
    """Logical operation over other directives.

    Parameters:
        operator: The operator to use.
        arguments: The arguments to the operator.
    """

    operator: Operator
    arguments: List[FilterDirective]

    def __init__(
        self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
    ):
        """Build the operation; both fields may be passed positionally.

        Any extra keyword arguments are forwarded to the parent initializer
        unchanged.
        """
        field_values = {"operator": operator, "arguments": arguments}
        super().__init__(**field_values, **kwargs)
class StructuredQuery(Expr):
    """Structured query."""

    query: str
    """Query string."""
    filter: Optional[FilterDirective]
    """Filtering expression."""
    limit: Optional[int]
    """Limit on the number of results."""

    def __init__(
        self,
        query: str,
        filter: Optional[FilterDirective],
        limit: Optional[int] = None,
        **kwargs: Any,
    ):
        """Build the query; ``query`` and ``filter`` may be passed positionally.

        ``limit`` defaults to ``None``. Any extra keyword arguments are
        forwarded to the parent initializer unchanged.
        """
        field_values = {"query": query, "filter": filter, "limit": limit}
        super().__init__(**field_values, **kwargs)