Source code for advanced_alchemy.filters

"""Collection filter datastructures."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import abc  # noqa: TC003
from dataclasses import dataclass
from datetime import datetime  # noqa: TC003
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, cast

from sqlalchemy import BinaryExpression, Delete, Select, Update, all_, and_, any_, or_, text
from typing_extensions import TypeVar

if TYPE_CHECKING:
    from typing import Callable

    from sqlalchemy import ColumnElement
    from sqlalchemy.orm import InstrumentedAttribute
    from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
    from typing_extensions import TypeAlias

    from advanced_alchemy import base


# Public names of this module; this is also the set exported by ``from ... import *``.
__all__ = (
    "BeforeAfter",
    "CollectionFilter",
    "FilterTypes",
    "InAnyFilter",
    "LimitOffset",
    "NotInCollectionFilter",
    "NotInSearchFilter",
    "OnBeforeAfter",
    "OrderBy",
    "PaginationFilter",
    "SearchFilter",
    "StatementFilter",
    "StatementFilterT",
    "StatementTypeT",
)

# Element type of the ``values`` collection held by CollectionFilter / NotInCollectionFilter.
T = TypeVar("T")
# The mapped model class handed to ``append_to_statement`` (annotated as ``type[ModelT]``).
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")
# Any StatementFilter subclass; the bound is a string because StatementFilter is defined below.
StatementFilterT = TypeVar("StatementFilterT", bound="StatementFilter")
# Statement kinds a filter accepts and returns (the same type comes back out).
# The bound is a string: ReturningDelete / ReturningUpdate are only imported under TYPE_CHECKING.
StatementTypeT = TypeVar(
    "StatementTypeT",
    bound="ReturningDelete[tuple[Any]] |  ReturningUpdate[tuple[Any]] | Select[tuple[Any]] | Select[Any] | Update | Delete",
)
# Union of the concrete filter classes, written as a string. With ``from __future__ import
# annotations`` this annotation is never evaluated at runtime, so the TYPE_CHECKING-only
# ``TypeAlias`` import is sufficient. The abstract bases InAnyFilter / PaginationFilter are not
# members of this union.
FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter"
"""Aggregate type alias of the types supported for collection filtering."""


class StatementFilter(ABC):
    """Abstract base for objects that add clauses to a SQLAlchemy statement.

    Subclasses implement :meth:`append_to_statement` and receive the statement plus the model
    class whose attributes they refer to.
    """

    @abstractmethod
    def append_to_statement(
        self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
    ) -> StatementTypeT:
        """Return ``statement`` with this filter applied.

        The base implementation hands ``statement`` back untouched, so subclasses may delegate
        to it via ``super()`` when they have nothing to add.

        Args:
            statement: The statement to extend.
            model: The mapped model class that the filter's attribute names refer to.
            *args: Extra positional options for subclasses.
            **kwargs: Extra keyword options for subclasses.

        Returns:
            The (possibly extended) statement, of the same type as ``statement``.
        """
        return statement

    @staticmethod
    def _get_instrumented_attr(model: Any, key: str | InstrumentedAttribute[Any]) -> InstrumentedAttribute[Any]:
        """Resolve ``key`` to an attribute of ``model``.

        A string is looked up on ``model`` with ``getattr``. Any other value is assumed to
        already be the attribute and is returned unchanged.
        """
        if not isinstance(key, str):
            return key
        resolved = getattr(model, key)
        return cast("InstrumentedAttribute[Any]", resolved)


@dataclass
class BeforeAfter(StatementFilter):
    """Filter rows on a ``datetime`` column using exclusive bounds.

    Either bound may be ``None``, in which case it contributes no condition.
    """

    field_name: str
    """Model attribute (column) the bounds are applied to."""
    before: datetime | None
    """Keep rows whose value is strictly earlier than this, when set."""
    after: datetime | None
    """Keep rows whose value is strictly later than this, when set."""

    def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
        """Return ``statement`` narrowed by the configured exclusive bounds.

        Args:
            statement: The statement to extend with ``WHERE`` conditions.
            model: The mapped model class that owns ``field_name``.

        Returns:
            ``statement`` with zero, one or two additional conditions.
        """
        column = self._get_instrumented_attr(model, self.field_name)
        criteria: list[ColumnElement[bool]] = []
        if self.before is not None:
            criteria.append(column < self.before)
        if self.after is not None:
            criteria.append(column > self.after)
        # Apply each condition with its own ``.where`` call, "before" first, then "after".
        for criterion in criteria:
            statement = cast("StatementTypeT", statement.where(criterion))  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
        return statement
@dataclass
class OnBeforeAfter(StatementFilter):
    """Filter rows on a ``datetime`` column using inclusive bounds.

    Either bound may be ``None``, in which case it contributes no condition.
    """

    field_name: str
    """Model attribute (column) the bounds are applied to."""
    on_or_before: datetime | None
    """Keep rows whose value is on or earlier than this, when set."""
    on_or_after: datetime | None
    """Keep rows whose value is on or later than this, when set."""

    def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
        """Return ``statement`` narrowed by the configured inclusive bounds.

        Args:
            statement: The statement to extend with ``WHERE`` conditions.
            model: The mapped model class that owns ``field_name``.

        Returns:
            ``statement`` with zero, one or two additional conditions.
        """
        column = self._get_instrumented_attr(model, self.field_name)
        criteria: list[ColumnElement[bool]] = []
        if self.on_or_before is not None:
            criteria.append(column <= self.on_or_before)
        if self.on_or_after is not None:
            criteria.append(column >= self.on_or_after)
        # One ``.where`` call per condition, upper bound first.
        for criterion in criteria:
            statement = cast("StatementTypeT", statement.where(criterion))  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
        return statement


class InAnyFilter(StatementFilter, ABC):
    """Base for filters whose ``append_to_statement`` accepts a ``prefer_any`` flag."""
@dataclass
class CollectionFilter(InAnyFilter, Generic[T]):
    """Data required to construct a ``WHERE ... IN (...)`` clause."""

    field_name: str
    """Name of the model attribute to filter on."""
    values: abc.Collection[T] | None
    """Values for ``IN`` clause.

    An empty collection yields an always-false condition (no rows), whereas ``None`` leaves the
    statement untouched so that every row is returned.
    """

    def append_to_statement(
        self,
        statement: StatementTypeT,
        model: type[ModelT],
        prefer_any: bool = False,
    ) -> StatementTypeT:
        """Restrict ``statement`` to rows whose ``field_name`` value is in ``values``.

        Args:
            statement: The statement to extend.
            model: The mapped model class that owns ``field_name``.
            prefer_any: When ``True``, compare with ``field = ANY(:values)`` instead of ``IN``.

        Returns:
            The filtered statement, or ``statement`` itself when ``values`` is ``None``.
        """
        field = self._get_instrumented_attr(model, self.field_name)
        candidates = self.values
        if candidates is None:
            return statement
        if not candidates:
            # Literal always-false predicate: an empty membership test matches nothing.
            clause = text("1=-1")
        elif prefer_any:
            clause = any_(candidates) == field  # type: ignore[arg-type]
        else:
            clause = field.in_(candidates)
        return cast("StatementTypeT", statement.where(clause))
@dataclass
class NotInCollectionFilter(InAnyFilter, Generic[T]):
    """Data required to construct a ``WHERE ... NOT IN (...)`` clause."""

    field_name: str
    """Name of the model attribute to filter on."""
    values: abc.Collection[T] | None
    """Values for ``NOT IN`` clause. An empty list or ``None`` will return all rows."""

    def append_to_statement(
        self,
        statement: StatementTypeT,
        model: type[ModelT],
        prefer_any: bool = False,
    ) -> StatementTypeT:
        """Exclude rows whose ``field_name`` value appears in ``values``.

        Args:
            statement: The statement to extend.
            model: The mapped model class that owns ``field_name``.
            prefer_any: When ``True``, render the exclusion as an array comparison
                (``field != ALL(:values)``) instead of ``field NOT IN (...)``.

        Returns:
            The filtered statement, or ``statement`` itself when ``values`` is empty or ``None``.
        """
        field = self._get_instrumented_attr(model, self.field_name)
        if not self.values:
            return statement
        if prefer_any:
            # ``NOT IN`` is equivalent to ``<> ALL``. The previous ``<> ANY`` form matched every
            # row that differs from at least one value, i.e. nearly all rows once ``values`` has
            # two or more distinct entries.
            return cast("StatementTypeT", statement.where(all_(self.values) != field))  # type: ignore[arg-type]
        return cast("StatementTypeT", statement.where(field.not_in(self.values)))


class PaginationFilter(StatementFilter, ABC):
    """Base class for filters that paginate a statement (e.g. ``LimitOffset``)."""
@dataclass
class LimitOffset(PaginationFilter):
    """Limit/offset pagination window applied to ``SELECT`` statements."""

    limit: int
    """Value for ``LIMIT`` clause of query."""
    offset: int
    """Value for ``OFFSET`` clause of query."""

    def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
        """Apply ``LIMIT``/``OFFSET`` to ``statement`` when it is a ``Select``.

        Args:
            statement: The statement to paginate.
            model: Unused; accepted for interface compatibility with other filters.

        Returns:
            The paginated select, or ``statement`` unchanged for any other statement type.
        """
        if not isinstance(statement, Select):
            return statement
        windowed = statement.limit(self.limit).offset(self.offset)
        return cast("StatementTypeT", windowed)
@dataclass
class OrderBy(StatementFilter):
    """Data required to construct a ``ORDER BY ...`` clause."""

    field_name: str
    """Name of the model attribute to sort on."""
    sort_order: Literal["asc", "desc"] = "asc"
    """Sort direction; only ``"desc"`` sorts descending, any other value sorts ascending."""

    def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
        """Append an ``ORDER BY`` on ``field_name`` to a ``Select``.

        Args:
            statement: The statement to order.
            model: The mapped model class that owns ``field_name``.

        Returns:
            The ordered select, or ``statement`` unchanged if it is not a ``Select``.
        """
        if not isinstance(statement, Select):
            return statement
        column = self._get_instrumented_attr(model, self.field_name)
        ordering = column.desc() if self.sort_order == "desc" else column.asc()
        return cast("StatementTypeT", statement.order_by(ordering))
@dataclass
class SearchFilter(StatementFilter):
    """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""

    field_name: str | set[str]
    """Name of the model attribute to search on."""
    value: str
    """Search value."""
    ignore_case: bool | None = False
    """Should the search be case insensitive."""

    @property
    def _operator(self) -> Callable[..., ColumnElement[bool]]:
        """Combiner for the per-field clauses: ``or_``, so a match on any one field is enough."""
        return or_

    @property
    def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
        """Getter for the column method to call: ``ilike`` if ``ignore_case`` is truthy, else ``like``."""
        return attrgetter("ilike" if self.ignore_case else "like")

    @property
    def normalized_field_names(self) -> set[str]:
        """``field_name`` as a set; a single string name is wrapped in a one-element set."""
        return {self.field_name} if isinstance(self.field_name, str) else self.field_name

    def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]:
        """Build one pattern-match clause per searched field.

        Args:
            model: The mapped model class that owns the searched attributes.

        Returns:
            One clause per name in :attr:`normalized_field_names`.
        """
        # The pattern and the comparison getter are the same for every field, so compute them once.
        # NOTE(review): ``value`` is not escaped, so any ``%`` or ``_`` it contains acts as a LIKE
        # wildcard -- confirm callers are fine with that.
        search_text = f"%{self.value}%"
        compare = self._func
        return [
            compare(self._get_instrumented_attr(model, field_name))(search_text)
            for field_name in self.normalized_field_names
        ]

    def append_to_statement(
        self,
        statement: StatementTypeT,
        model: type[ModelT],
    ) -> StatementTypeT:
        """Add the combined search condition to ``statement``.

        Args:
            statement: The statement to filter.
            model: The mapped model class that owns the searched attributes.

        Returns:
            ``statement`` with one ``WHERE`` condition joining all per-field clauses via :attr:`_operator`.
        """
        where_clause = self._operator(*self.get_search_clauses(model))
        return cast("StatementTypeT", statement.where(where_clause))
@dataclass
class NotInSearchFilter(SearchFilter):
    """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""

    @property
    def _operator(self) -> Callable[..., ColumnElement[bool]]:
        """Combine per-field clauses with ``and_``: a row must avoid the pattern on every field."""
        return and_

    @property
    def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
        """Getter for ``not_ilike`` when ``ignore_case`` is truthy, otherwise ``not_like``."""
        if self.ignore_case:
            return attrgetter("not_ilike")
        return attrgetter("not_like")