From 2f63aca30fbe09894c576334874fa0421d5eeb4f Mon Sep 17 00:00:00 2001
From: Jann Stute
Date: Thu, 2 Jan 2025 00:02:00 +0100
Subject: [PATCH] feat: optimise tag constraint

---
Review notes (free-text area, not part of the commit):
  * The Tag and TagID constraints no longer use `TagBoxField.tags.any(...)`;
    both now go through one helper, `__entry_matches_tag_ids`, which builds
    an EXISTS subquery over TagField correlated to TagBoxField.
  * `exists` is added to the sqlalchemy import, but the helper calls the
    `.exists()` method on the select and no hunk here references the
    imported function. It looks unused -- confirm, or drop it from the
    import line.
  * Line breaks and indentation below were restored from a whitespace-
    collapsed copy. They agree with the hunk headers and the diffstat;
    verify against the target file before applying.

 tagstudio/src/core/library/alchemy/visitors.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py
index 5eed4580..c60def00 100644
--- a/tagstudio/src/core/library/alchemy/visitors.py
+++ b/tagstudio/src/core/library/alchemy/visitors.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 import structlog
-from sqlalchemy import and_, distinct, func, or_, select, text
+from sqlalchemy import and_, distinct, exists, func, or_, select, text
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
 from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
@@ -82,9 +82,9 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
             raise NotImplementedError("Properties are not implemented yet")  # TODO TSQLANG
 
         if node.type == ConstraintType.Tag:
-            return TagBoxField.tags.any(Tag.id.in_(self.__get_tag_ids(node.value)))
+            return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value))
         elif node.type == ConstraintType.TagID:
-            return TagBoxField.tags.any(Tag.id == int(node.value))
+            return self.__entry_matches_tag_ids([int(node.value)])
         elif node.type == ConstraintType.Path:
             return Entry.path.op("GLOB")(node.value)
         elif node.type == ConstraintType.MediaType:
@@ -113,6 +113,15 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
     def visit_not(self, node: Not) -> ColumnExpressionArgument:
         return ~self.__entry_satisfies_ast(node.child)
 
+    def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnExpressionArgument:
+        """Returns a boolean expression that is true if the entry has at least one of the supplied tags."""  # noqa: E501
+        return (
+            select(1)
+            .correlate(TagBoxField)
+            .where(and_(TagField.field_id == TagBoxField.id, TagField.tag_id.in_(tag_ids)))
+            .exists()
+        )
+
     def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
         """Given a tag name find the ids of all tags that this name could refer to."""
         with Session(self.lib.engine) as session: