diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py
index 1756bb08..82299003 100644
--- a/tagstudio/src/core/library/alchemy/visitors.py
+++ b/tagstudio/src/core/library/alchemy/visitors.py
@@ -1,6 +1,7 @@
 from typing import TYPE_CHECKING
 
-from sqlalchemy import and_, distinct, func, or_, select
+import structlog
+from sqlalchemy import and_, distinct, func, or_, select, text
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
 from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
@@ -16,6 +17,26 @@ if TYPE_CHECKING:
 else:
     Library = None  # don't import .library because of circular imports
 
+logger = structlog.get_logger(__name__)
+
+# NOTE: the tag_subtags columns are named backwards: tag_subtags.child_id holds
+# the PARENT tag id and tag_subtags.parent_id holds the CHILD tag id. That
+# inversion applies to the entire query below.
+#
+# UNION (rather than UNION ALL) makes SQLite discard rows it has already
+# produced, so the recursion terminates even if the tag hierarchy contains a
+# cycle, and every tag id is returned at most once.
+CHILDREN_QUERY = text("""
+WITH RECURSIVE Subtags AS (
+    SELECT :tag_id AS child_id
+    UNION
+    SELECT ts.parent_id AS child_id
+    FROM tag_subtags ts
+    INNER JOIN Subtags s ON ts.child_id = s.child_id
+)
+SELECT child_id FROM Subtags;
+""")
+
 
 def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
     for s in FILETYPE_EQUIVALENTS:
@@ -98,16 +119,44 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
     def visit_not(self, node: Not) -> ColumnExpressionArgument:
         return ~self.__entry_satisfies_ast(node.child)
 
-    def __get_tag_ids(self, tag_name: str) -> list[int]:
-        """Given a tag name find the ids of all tags that this name could refer to."""
-        with Session(self.lib.engine, expire_on_commit=False) as session:
-            return list(
+    def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
+        """Given a tag name find the ids of all tags that this name could refer to.
+
+        A tag matches when its name, its shorthand or one of its aliases matches
+        `tag_name` via a case-insensitive ILIKE comparison.
+
+        Args:
+            tag_name: The name to resolve.
+            include_children: If True, also return every id that CHILDREN_QUERY
+                yields for each matching tag (the tag itself and its subtags).
+
+        Returns:
+            The resolved tag ids, each at most once.
+        """
+        with Session(self.lib.engine) as session:
+            tag_ids = list(
                 session.scalars(
                     select(Tag.id)
                     .where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
                     .union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
                 )
             )
+            if len(tag_ids) > 1:
+                logger.debug(
+                    f'Tag Constraint "{tag_name}" is ambiguous, {len(tag_ids)} matching tags found',
+                    tag_ids=tag_ids,
+                    include_children=include_children,
+                )
+            if not include_children:
+                return tag_ids
+            # Matched tags can share descendants (or be descendants of one another),
+            # so collect the ids as dict keys: repeats are dropped while the
+            # first-seen order is kept.
+            unique_ids: dict[int, None] = {}
+            for tag_id in tag_ids:
+                for found_id in session.scalars(CHILDREN_QUERY, {"tag_id": tag_id}):
+                    unique_ids[found_id] = None
+            return list(unique_ids)
 
     def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
         """Returns Binary Expression that is true if the Entry has all provided tag ids."""