From f190547cd30d1d91e99ed4a9eda50321e4e131ac Mon Sep 17 00:00:00 2001 From: Jann Stute Date: Thu, 2 Jan 2025 00:16:18 +0100 Subject: [PATCH] feat: use less subqueries --- .../src/core/library/alchemy/visitors.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py index c60def00..de1e70a8 100644 --- a/tagstudio/src/core/library/alchemy/visitors.py +++ b/tagstudio/src/core/library/alchemy/visitors.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING import structlog -from sqlalchemy import and_, distinct, exists, func, or_, select, text +from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories from src.core.query_lang import BaseVisitor from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property @@ -39,17 +38,17 @@ def get_filetype_equivalency_list(item: str) -> list[str] | set[str]: return [item] -class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): +class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]): def __init__(self, lib: Library) -> None: super().__init__() self.lib = lib - def visit_or_list(self, node: ORList) -> ColumnExpressionArgument: + def visit_or_list(self, node: ORList) -> ColumnElement[bool]: return or_(*[self.visit(element) for element in node.elements]) - def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: + def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: tag_ids: list[int] = [] - bool_expressions: list[ColumnExpressionArgument] = [] + bool_expressions: list[ColumnElement[bool]] = [] # Search for TagID / unambigous Tag Constraints and store the respective tag ids seperately for term in node.terms: @@ -77,7 +76,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): return and_(*bool_expressions) - def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument: + def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: if len(node.properties) != 0: raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG @@ -110,10 +109,10 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): def visit_property(self, node: Property) -> None: raise NotImplementedError("This should never be reached!") - def visit_not(self, node: Not) -> ColumnExpressionArgument: + def visit_not(self, node: Not) -> ColumnElement[bool]: return ~self.__entry_satisfies_ast(node.child) - def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnExpressionArgument: + def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]: """Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501 return ( select(1) @@ -145,7 +144,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id}))) return outp - def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]: + def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry has all provided tag ids.""" # Relational Division Query return Entry.id.in_( @@ -157,13 +156,11 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]): .having(func.count(distinct(TagField.tag_id)) == len(tag_ids)) ) - def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]: + def __entry_satisfies_ast(self, partial_query: AST) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry satisfies the partial query.""" - return self.__entry_satisfies_expression(self.visit(partial_query)) + return self.visit(partial_query) - def __entry_satisfies_expression( - self, expr: ColumnExpressionArgument - ) -> BinaryExpression[bool]: + def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry satisfies the column expression.""" return Entry.id.in_( select(Entry.id).outerjoin(Entry.tag_box_fields).outerjoin(TagField).where(expr)