"""Internal representation of a structured query language."""from__future__importannotationsfromabcimportABC,abstractmethodfromenumimportEnumfromtypingimportAny,List,Optional,Sequence,Unionfromlangchain_core.pydantic_v1importBaseModel
class Visitor(ABC):
    """Defines interface for IR translation using a visitor pattern.

    Concrete visitors implement the abstract ``visit_*`` methods below and may
    restrict which operators and comparators they accept through the
    ``allowed_operators`` / ``allowed_comparators`` class attributes.
    """

    allowed_comparators: Optional[Sequence[Comparator]] = None
    """Allowed comparators for the visitor (``None`` allows all)."""
    allowed_operators: Optional[Sequence[Operator]] = None
    """Allowed operators for the visitor (``None`` allows all)."""

    def _validate_func(self, func: Union[Operator, Comparator]) -> None:
        """Check that ``func`` is permitted by this visitor.

        An allow-list of ``None`` means every member of that kind is accepted.

        Args:
            func: The operator or comparator about to be translated.

        Raises:
            ValueError: If ``func`` is not in the corresponding allow-list.
        """
        if isinstance(func, Operator) and self.allowed_operators is not None:
            if func not in self.allowed_operators:
                raise ValueError(
                    f"Received disallowed operator {func}. Allowed "
                    f"operators are {self.allowed_operators}"
                )
        if isinstance(func, Comparator) and self.allowed_comparators is not None:
            if func not in self.allowed_comparators:
                raise ValueError(
                    f"Received disallowed comparator {func}. Allowed "
                    f"comparators are {self.allowed_comparators}"
                )

    @abstractmethod
    def visit_operation(self, operation: Operation) -> Any:
        """Translate an Operation.

        Args:
            operation: Operation to translate.

        Returns:
            The translated form; its type is defined by the concrete visitor.
        """

    @abstractmethod
    def visit_comparison(self, comparison: Comparison) -> Any:
        """Translate a Comparison.

        Args:
            comparison: Comparison to translate.

        Returns:
            The translated form; its type is defined by the concrete visitor.
        """

    @abstractmethod
    def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
        """Translate a StructuredQuery.

        Args:
            structured_query: StructuredQuery to translate.

        Returns:
            The translated form; its type is defined by the concrete visitor.
        """
def_to_snake_case(name:str)->str:"""Convert a name into snake_case."""snake_case=""fori,charinenumerate(name):ifchar.isupper()andi!=0:snake_case+="_"+char.lower()else:snake_case+=char.lower()returnsnake_case
class Expr(BaseModel):
    """Common base for every node in the query IR."""

    def accept(self, visitor: Visitor) -> Any:
        """Dispatch this node to the matching ``visit_*`` method of ``visitor``.

        The method name is ``"visit_"`` followed by the snake_case form of this
        node's class name (e.g. a ``Comparison`` calls ``visit_comparison``).

        Args:
            visitor: The visitor that should handle this node.

        Returns:
            Whatever the selected visitor method returns.
        """
        handler_name = "visit_" + _to_snake_case(self.__class__.__name__)
        handler = getattr(visitor, handler_name)
        return handler(self)
class Operator(str, Enum):
    """Logical operators available for combining filter directives.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase value (``Operator.AND == "and"``).
    """

    AND = "and"
    OR = "or"
    NOT = "not"
class Comparator(str, Enum):
    """Comparison operators usable in a ``Comparison`` directive.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    value (``Comparator.GTE == "gte"``).
    """

    # Equality / inequality.
    EQ = "eq"
    NE = "ne"
    # Ordering.
    GT = "gt"
    GTE = "gte"
    LT = "lt"
    LTE = "lte"
    # Textual matching.
    CONTAIN = "contain"
    LIKE = "like"
    # Set membership.
    IN = "in"
    NIN = "nin"
class Comparison(FilterDirective):
    """A filter that compares one attribute against a value.

    Parameters:
        comparator: Which comparator to apply.
        attribute: Name of the attribute being compared.
        value: The value the attribute is compared against.
    """

    comparator: Comparator
    attribute: str
    value: Any

    def __init__(
        self,
        comparator: Comparator,
        attribute: str,
        value: Any,
        **kwargs: Any,
    ) -> None:
        # Accept the three fields positionally; the base constructor is
        # always invoked with keywords only.
        super().__init__(
            comparator=comparator,
            attribute=attribute,
            value=value,
            **kwargs,
        )
class Operation(FilterDirective):
    """Logical operation over other directives.

    Parameters:
        operator: The operator to use.
        arguments: The directives the operator is applied to.
    """

    operator: Operator
    arguments: List[FilterDirective]

    def __init__(
        self,
        operator: Operator,
        arguments: List[FilterDirective],
        **kwargs: Any,
    ) -> None:
        """Create an operation, accepting ``operator`` and ``arguments`` positionally.

        Args:
            operator: The operator to use.
            arguments: The directives the operator is applied to.
            **kwargs: Additional keyword arguments forwarded to the base
                constructor.
        """
        super().__init__(operator=operator, arguments=arguments, **kwargs)
class StructuredQuery(Expr):
    """A query string paired with an optional filter and result limit."""

    query: str
    """The query text."""
    filter: Optional[FilterDirective]
    """Optional filter expression applied to the query."""
    limit: Optional[int]
    """Optional cap on how many results are returned."""

    def __init__(
        self,
        query: str,
        filter: Optional[FilterDirective],
        limit: Optional[int] = None,
        **kwargs: Any,
    ):
        # Allow positional construction; the base constructor receives
        # every field as a keyword argument.
        super().__init__(
            query=query,
            filter=filter,
            limit=limit,
            **kwargs,
        )