mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-01 15:49:09 +00:00
feat: implement parent tag search
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import and_, distinct, func, or_, select
|
||||
import structlog
|
||||
from sqlalchemy import and_, distinct, func, or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
@@ -16,6 +17,20 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
Library = None # don't import .library because of circular imports
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
CHILDREN_QUERY = text("""
|
||||
-- Note for this entire query that tag_subtags.child_id is the parent id and tag_subtags.parent_id is the child id due to bad naming
|
||||
WITH RECURSIVE Subtags AS (
|
||||
SELECT :tag_id AS child_id
|
||||
UNION ALL
|
||||
SELECT ts.parent_id AS child_id
|
||||
FROM tag_subtags ts
|
||||
INNER JOIN Subtags s ON ts.child_id = s.child_id
|
||||
)
|
||||
SELECT * FROM Subtags;
|
||||
""") # noqa: E501
|
||||
|
||||
|
||||
def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
|
||||
for s in FILETYPE_EQUIVALENTS:
|
||||
@@ -98,16 +113,28 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
|
||||
return ~self.__entry_satisfies_ast(node.child)
|
||||
|
||||
def __get_tag_ids(self, tag_name: str) -> list[int]:
|
||||
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
|
||||
"""Given a tag name find the ids of all tags that this name could refer to."""
|
||||
with Session(self.lib.engine, expire_on_commit=False) as session:
|
||||
return list(
|
||||
with Session(self.lib.engine) as session:
|
||||
tag_ids = list(
|
||||
session.scalars(
|
||||
select(Tag.id)
|
||||
.where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
|
||||
.union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
|
||||
)
|
||||
)
|
||||
if len(tag_ids) > 1:
|
||||
logger.debug(
|
||||
f'Tag Constraint "{tag_name}" is ambiguos, {len(tag_ids)} matching tags found',
|
||||
tag_ids=tag_ids,
|
||||
include_children=include_children,
|
||||
)
|
||||
if not include_children:
|
||||
return tag_ids
|
||||
outp = []
|
||||
for tag_id in tag_ids:
|
||||
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
|
||||
return outp
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
|
||||
Reference in New Issue
Block a user