from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from langchain_aws.utilities.redis import TokenEscaper
# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"
class InMemoryDBFilterOperator(Enum):
    """Operators understood by InMemoryDB filter fields and expressions.

    The comparison members are keys of each field class's ``OPERATORS`` /
    ``OPERATOR_MAP`` tables, while ``AND`` and ``OR`` are used to join two
    ``InMemoryDBFilterExpression`` objects via ``&`` and ``|``.
    """

    # --- comparisons -------------------------------------------------------
    EQ = 1  # ==
    NE = 2  # !=
    LT = 3  # <
    GT = 4  # >
    LE = 5  # <=
    GE = 6  # >=
    # --- expression combinators --------------------------------------------
    OR = 7  # expr | expr
    AND = 8  # expr & expr
    # --- text / tag specific -----------------------------------------------
    LIKE = 9  # text % pattern
    IN = 10  # tag membership
class InMemoryDBFilter:
    """Namespace of factory helpers that build InMemoryDB filter fields by name.

    Example:
        >>> expr = InMemoryDBFilter.tag("brand") == "nike"
    """

    @staticmethod
    def text(field: str) -> "InMemoryDBText":
        """Return a new ``InMemoryDBText`` field bound to the index field *field*."""
        return InMemoryDBText(field=field)

    @staticmethod
    def num(field: str) -> "InMemoryDBNum":
        """Return a new ``InMemoryDBNum`` field bound to the index field *field*."""
        return InMemoryDBNum(field=field)

    @staticmethod
    def tag(field: str) -> "InMemoryDBTag":
        """Return a new ``InMemoryDBTag`` field bound to the index field *field*."""
        return InMemoryDBTag(field=field)
class InMemoryDBFilterField:
    """Base class for InMemoryDBFilterFields.

    Subclasses declare ``OPERATORS`` (the operators they accept) and overload
    Python operators so that applying one records a value/operator pair through
    ``_set_value``.
    """

    # Shared escaper instance; InMemoryDBTag uses it to escape tag values.
    escaper: "TokenEscaper" = TokenEscaper()
    # Supported operators -> display symbol used in validation error messages.
    OPERATORS: Dict[InMemoryDBFilterOperator, str] = {}

    def __init__(self, field: str):
        """Create a filter field for the index field named *field*.

        Args:
            field (str): Name of the field in the index to query against.
        """
        self._field = field
        self._value: Any = None
        # Placeholder until an operator method records the real operator.
        self._operator: InMemoryDBFilterOperator = InMemoryDBFilterOperator.EQ

    def equals(self, other: "InMemoryDBFilterField") -> bool:
        """Return True if *other* is an identical filter field.

        Python's ``==`` is overloaded by subclasses to build filter expressions,
        so use this method for real equality checks. Two fields are equal when
        they are the same field type and share field name, operator and value
        (the operator matters: ``field > 5`` and ``field < 5`` differ).

        Args:
            other: The field to compare against.

        Returns:
            bool: Whether both fields describe the same filter.
        """
        if not isinstance(other, type(self)):
            return False
        return (
            self._field == other._field
            and self._operator == other._operator
            and self._value == other._value
        )

    def _set_value(
        self,
        val: Any,
        val_type: Tuple[Any, ...],
        operator: InMemoryDBFilterOperator,
    ) -> None:
        """Validate and store the right-hand operand of a filter operator.

        Args:
            val: The operand supplied to the overloaded operator.
            val_type: Tuple of accepted types, passed to ``isinstance``.
            operator: The operator being applied.

        Raises:
            ValueError: If *operator* is not in this class's ``OPERATORS``.
            TypeError: If *val* is not an instance of one of *val_type*.
        """
        # check that the operator is supported by this class
        if operator not in self.OPERATORS:
            raise ValueError(
                f"Operator {operator} not supported by {self.__class__.__name__}. "
                + f"Supported operators are {list(self.OPERATORS.values())}."
            )

        if not isinstance(val, val_type):
            raise TypeError(
                f"Right side argument passed to operator {self.OPERATORS[operator]} "
                f"with left side "
                f"argument {self.__class__.__name__} must be of type {val_type}, "
                f"received value {val}"
            )
        self._value = val
        self._operator = operator
def check_operator_misuse(func: Callable) -> Callable:
    """Decorator to check for misuse of equality operators.

    Filter fields overload ``==``/``!=`` to build filter expressions, so
    comparing a field against another field of the same type is almost always a
    mistake. The wrapped method raises ``ValueError`` in that case and points
    the caller at ``.equals()`` instead.

    Args:
        func: The operator method to guard.

    Returns:
        Callable: A wrapper with the same metadata as *func*.
    """

    @wraps(func)
    def wrapper(instance: Any, *args: Any, **kwargs: Any) -> Any:
        own_type = type(instance)
        # A truthy ``other=`` keyword is inspected directly; otherwise the first
        # positional argument of the same type is used. If none is found, the
        # (falsy) keyword value is kept as the candidate.
        candidate = kwargs.get("other")
        if not candidate:
            candidate = next(
                (arg for arg in args if isinstance(arg, own_type)), candidate
            )

        if isinstance(candidate, own_type):
            raise ValueError(
                "Equality operators are overridden for FilterExpression creation. Use "
                ".equals() for equality checks"
            )
        return func(instance, *args, **kwargs)

    return wrapper
class InMemoryDBTag(InMemoryDBFilterField):
    """InMemoryDBFilterField representing a tag in a InMemoryDB index."""

    # Supported operators -> display symbol used in validation error messages.
    OPERATORS: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: "==",
        InMemoryDBFilterOperator.NE: "!=",
        InMemoryDBFilterOperator.IN: "==",
    }
    # Operator -> query template, filled with (field name, "|"-joined escaped tags).
    OPERATOR_MAP: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: "@%s:{%s}",
        InMemoryDBFilterOperator.NE: "(-@%s:{%s})",
        InMemoryDBFilterOperator.IN: "@%s:{%s}",
    }
    SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))

    def __init__(self, field: str):
        """Create a InMemoryDBTag FilterField.

        Args:
            field (str): The name of the InMemoryDBTag field in the index to be
                queried against.
        """
        super().__init__(field)

    def _set_tag_value(
        self,
        other: Union[List[str], Set[str], Tuple[str], str],
        operator: InMemoryDBFilterOperator,
    ) -> None:
        """Normalize *other* to a list of tag strings and store it with *operator*.

        - list / set / tuple: falsy elements (e.g. ``""`` or ``None``) are
          dropped and the rest are converted with ``str``.
        - ``None`` or ``""``: stored as an empty list, which renders as ``"*"``.
        - a non-empty ``str``: wrapped in a one-element list.
        - anything else is passed through unchanged so that ``_set_value``
          rejects it.

        Raises:
            ValueError: If ``str()`` of a collection element raises ValueError.
            TypeError: If *other* is not a supported type (from ``_set_value``).
        """
        if isinstance(other, (list, set, tuple)):
            try:
                # "if val" clause removes non-truthy values from the collection
                other = [str(val) for val in other if val]
            except ValueError as err:
                raise ValueError("All tags within collection must be strings") from err
        elif other is None:
            other = []
        elif isinstance(other, str):
            # The empty string means "no tag constraint", just like None.
            other = [other] if other else []
        # Any other value -- including falsy scalars such as 0, False or {} --
        # is left as-is so the type check in _set_value raises, instead of
        # silently producing a match-all "*" filter.
        self._set_value(other, self.SUPPORTED_VAL_TYPES, operator)  # type: ignore

    @check_operator_misuse
    def __eq__(
        self, other: Union[List[str], Set[str], Tuple[str], str]
    ) -> "InMemoryDBFilterExpression":
        """Create a InMemoryDBTag equality filter expression.

        Args:
            other (Union[List[str], Set[str], Tuple[str], str]):
                The tag(s) to filter on. Multiple tags are joined with "|" in
                the rendered query.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBTag
            >>> filter = InMemoryDBTag("brand") == "nike"
        """
        self._set_tag_value(other, InMemoryDBFilterOperator.EQ)
        return InMemoryDBFilterExpression(str(self))

    @check_operator_misuse
    def __ne__(
        self, other: Union[List[str], Set[str], Tuple[str], str]
    ) -> "InMemoryDBFilterExpression":
        """Create a InMemoryDBTag inequality filter expression.

        Args:
            other (Union[List[str], Set[str], Tuple[str], str]):
                The tag(s) to filter on. Multiple tags are joined with "|" in
                the rendered query.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBTag
            >>> filter = InMemoryDBTag("brand") != "nike"
        """
        self._set_tag_value(other, InMemoryDBFilterOperator.NE)
        return InMemoryDBFilterExpression(str(self))

    @property
    def _formatted_tag_value(self) -> str:
        """Escape each stored tag and join them with "|"."""
        return "|".join([self.escaper.escape(tag) for tag in self._value])

    def __str__(self) -> str:
        """Return the query syntax for a InMemoryDBTag filter expression.

        An empty tag list renders as the match-all ``"*"``.
        """
        if not self._value:
            return "*"

        return self.OPERATOR_MAP[self._operator] % (
            self._field,
            self._formatted_tag_value,
        )
class InMemoryDBNum(InMemoryDBFilterField):
    """InMemoryDBFilterField representing a numeric field in a InMemoryDB index."""

    # Supported operators -> display symbol used in validation error messages.
    OPERATORS: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: "==",
        InMemoryDBFilterOperator.NE: "!=",
        InMemoryDBFilterOperator.LT: "<",
        InMemoryDBFilterOperator.GT: ">",
        InMemoryDBFilterOperator.LE: "<=",
        InMemoryDBFilterOperator.GE: ">=",
    }
    # Operator -> range-query template. EQ/NE are filled with
    # (field, value, value); the one-sided bounds with (field, value).
    # GT/LT prefix the value with "(", GE/LE do not.
    OPERATOR_MAP: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: "@%s:[%s %s]",
        InMemoryDBFilterOperator.NE: "(-@%s:[%s %s])",
        InMemoryDBFilterOperator.GT: "@%s:[(%s +inf]",
        InMemoryDBFilterOperator.LT: "@%s:[-inf (%s]",
        InMemoryDBFilterOperator.GE: "@%s:[%s +inf]",
        InMemoryDBFilterOperator.LE: "@%s:[-inf %s]",
    }
    SUPPORTED_VAL_TYPES = (int, float, type(None))

    def __str__(self) -> str:
        """Return the query syntax for a InMemoryDBNum filter expression.

        A ``None`` value renders as the match-all ``"*"``.
        """
        if self._value is None:
            return "*"
        template = self.OPERATOR_MAP[self._operator]
        if self._operator in (InMemoryDBFilterOperator.EQ, InMemoryDBFilterOperator.NE):
            # Exact match is expressed as the closed range [value value].
            return template % (self._field, self._value, self._value)
        return template % (self._field, self._value)

    def _set_num_value(
        self, other: Union[int, float], operator: InMemoryDBFilterOperator
    ) -> None:
        """Validate and store a numeric right-hand operand for *operator*.

        Raises:
            TypeError: If *other* is a ``bool`` or not an int, float or None.
        """
        # bool subclasses int, so the isinstance check in _set_value would
        # accept True/False and render "True"/"False" inside the range query.
        if isinstance(other, bool):
            raise TypeError(
                f"Right side argument passed to operator {self.OPERATORS[operator]} "
                f"with left side argument {self.__class__.__name__} must be a "
                f"number, received boolean value {other}"
            )
        self._set_value(other, self.SUPPORTED_VAL_TYPES, operator)  # type: ignore

    @check_operator_misuse
    def __eq__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric equality filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("zipcode") == 90210
        """
        self._set_num_value(other, InMemoryDBFilterOperator.EQ)
        return InMemoryDBFilterExpression(str(self))

    @check_operator_misuse
    def __ne__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric inequality filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("zipcode") != 90210
        """
        self._set_num_value(other, InMemoryDBFilterOperator.NE)
        return InMemoryDBFilterExpression(str(self))

    def __gt__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric greater than filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("age") > 18
        """
        self._set_num_value(other, InMemoryDBFilterOperator.GT)
        return InMemoryDBFilterExpression(str(self))

    def __lt__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric less than filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("age") < 18
        """
        self._set_num_value(other, InMemoryDBFilterOperator.LT)
        return InMemoryDBFilterExpression(str(self))

    def __ge__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric greater than or equal to filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("age") >= 18
        """
        self._set_num_value(other, InMemoryDBFilterOperator.GE)
        return InMemoryDBFilterExpression(str(self))

    def __le__(self, other: Union[int, float]) -> "InMemoryDBFilterExpression":
        """Create a Numeric less than or equal to filter expression.

        Args:
            other (Union[int, float]): The value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBNum
            >>> filter = InMemoryDBNum("age") <= 18
        """
        self._set_num_value(other, InMemoryDBFilterOperator.LE)
        return InMemoryDBFilterExpression(str(self))
class InMemoryDBText(InMemoryDBFilterField):
    """Filter field for a text field in a InMemoryDB index.

    Applying ``==``, ``!=`` or ``%`` to an instance records the operand and
    returns an ``InMemoryDBFilterExpression`` rendered from ``OPERATOR_MAP``.
    """

    # Supported operators -> display symbol used in validation error messages.
    OPERATORS: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: "==",
        InMemoryDBFilterOperator.NE: "!=",
        InMemoryDBFilterOperator.LIKE: "%",
    }
    # Operator -> query template, filled with (field name, raw text value).
    OPERATOR_MAP: Dict[InMemoryDBFilterOperator, str] = {
        InMemoryDBFilterOperator.EQ: '@%s:("%s")',
        InMemoryDBFilterOperator.NE: '(-@%s:"%s")',
        InMemoryDBFilterOperator.LIKE: "@%s:(%s)",
    }
    SUPPORTED_VAL_TYPES = (str, type(None))

    def _make_expression(
        self, value: str, operator: InMemoryDBFilterOperator
    ) -> "InMemoryDBFilterExpression":
        """Validate and store *value* for *operator*, then wrap the rendered query."""
        self._set_value(value, self.SUPPORTED_VAL_TYPES, operator)  # type: ignore
        return InMemoryDBFilterExpression(str(self))

    @check_operator_misuse
    def __eq__(self, other: str) -> "InMemoryDBFilterExpression":
        """Build an exact-match filter expression for this text field.

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBText
            >>> filter = InMemoryDBText("job") == "engineer"
        """
        return self._make_expression(other, InMemoryDBFilterOperator.EQ)

    @check_operator_misuse
    def __ne__(self, other: str) -> "InMemoryDBFilterExpression":
        """Build a negated exact-match filter expression for this text field.

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBText
            >>> filter = InMemoryDBText("job") != "engineer"
        """
        return self._make_expression(other, InMemoryDBFilterOperator.NE)

    def __mod__(self, other: str) -> "InMemoryDBFilterExpression":
        """Build a "LIKE" filter expression; *other* is inserted unescaped.

        Args:
            other (str): The text pattern to filter on.

        Example:
            >>> from langchain_aws.vectorstores.inmemorydb import InMemoryDBText
            >>> filter = InMemoryDBText("job") % "engine*"  # suffix wild card match
            >>> filter = InMemoryDBText("job") % "%%engine%%"  # fuzzy match w/ LD
            >>> filter = InMemoryDBText("job") % "engineer|doctor"  # contains either
            >>> filter = InMemoryDBText("job") % "engineer doctor"  # contains both
        """
        return self._make_expression(other, InMemoryDBFilterOperator.LIKE)

    def __str__(self) -> str:
        """Render the stored value; an empty or ``None`` value gives ``"*"``."""
        if self._value:
            return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
        return "*"
class InMemoryDBFilterExpression:
    """Logical expression of InMemoryDBFilterFields.

    InMemoryDBFilterExpressions can be combined using the & and | operators to
    create complex logical expressions that evaluate to the InMemoryDB Query
    language.

    This presents an interface by which users can create complex queries
    without having to know the InMemoryDB Query language.

    Filter expressions are not initialized directly. Instead, they are produced
    by applying comparison operators (``==``, ``!=``, ``<``, ``%``, ...) to
    InMemoryDBFilterFields, and the resulting expressions are combined using
    the & and | operators.

    Examples:

        >>> from langchain_aws.vectorstores.inmemorydb import (
        ...     InMemoryDBTag, InMemoryDBNum
        ... )
        >>> brand_is_nike = InMemoryDBTag("brand") == "nike"
        >>> price_is_under_100 = InMemoryDBNum("price") < 100
        >>> filter = brand_is_nike & price_is_under_100
        >>> print(str(filter))
        (@brand:{nike} @price:[-inf (100])
    """
[docs] def __init__(
self,
_filter: Optional[str] = None,
operator: Optional[InMemoryDBFilterOperator] = None,
left: Optional["InMemoryDBFilterExpression"] = None,
right: Optional["InMemoryDBFilterExpression"] = None,
):
self._filter = _filter
self._operator = operator
self._left = left
self._right = right
def __and__(
self, other: "InMemoryDBFilterExpression"
) -> "InMemoryDBFilterExpression":
return InMemoryDBFilterExpression(
operator=InMemoryDBFilterOperator.AND, left=self, right=other
)
def __or__(
self, other: "InMemoryDBFilterExpression"
) -> "InMemoryDBFilterExpression":
return InMemoryDBFilterExpression(
operator=InMemoryDBFilterOperator.OR, left=self, right=other
)
def __str__(self) -> str:
# top level check that allows recursive calls to __str__
if not self._filter and not self._operator:
raise ValueError("Improperly initialized InMemoryDBFilterExpression")
# if there's an operator, combine expressions accordingly
if self._operator:
if not isinstance(self._left, InMemoryDBFilterExpression) or not isinstance(
self._right, InMemoryDBFilterExpression
):
raise TypeError(
"Improper combination of filters."
"Both left and right should be type FilterExpression"
)
operator_str = (
" | " if self._operator == InMemoryDBFilterOperator.OR else " "
)
return self.format_expression(self._left, self._right, operator_str)
# check that base case, the filter is set
if not self._filter:
raise ValueError("Improperly initialized InMemoryDBFilterExpression")
return self._filter