mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-30 23:00:51 +00:00
feat: instead of hardcoding child tag ids into main query, include subquery
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, distinct, func, or_, select, text
|
||||
from sqlalchemy import and_, column, distinct, func, or_, select, text, union_all
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from sqlalchemy.sql.expression import (
|
||||
BinaryExpression,
|
||||
ColumnExpressionArgument,
|
||||
CompoundSelect,
|
||||
)
|
||||
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
from src.core.query_lang import BaseVisitor
|
||||
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
|
||||
@@ -28,7 +32,7 @@ WITH RECURSIVE Subtags AS (
|
||||
FROM tag_subtags ts
|
||||
INNER JOIN Subtags s ON ts.child_id = s.child_id
|
||||
)
|
||||
SELECT * FROM Subtags;
|
||||
SELECT child_id FROM Subtags
|
||||
""") # noqa: E501
|
||||
|
||||
|
||||
@@ -59,7 +63,10 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
tag_ids.append(int(term.value))
|
||||
continue
|
||||
case ConstraintType.Tag:
|
||||
if len(ids := self.__get_tag_ids(term.value)) == 1:
|
||||
if (
|
||||
isinstance((ids := self.__get_tag_ids(term.value)), list)
|
||||
and len(ids) == 1
|
||||
):
|
||||
tag_ids.append(ids[0])
|
||||
continue
|
||||
|
||||
@@ -113,7 +120,9 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
|
||||
return ~self.__entry_satisfies_ast(node.child)
|
||||
|
||||
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
|
||||
def __get_tag_ids(
|
||||
self, tag_name: str, include_children: bool = True
|
||||
) -> list[int] | CompoundSelect:
|
||||
"""Given a tag name find the ids of all tags that this name could refer to."""
|
||||
with Session(self.lib.engine) as session:
|
||||
tag_ids = list(
|
||||
@@ -131,10 +140,13 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
)
|
||||
if not include_children:
|
||||
return tag_ids
|
||||
outp = []
|
||||
for tag_id in tag_ids:
|
||||
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
|
||||
return outp
|
||||
queries = [
|
||||
CHILDREN_QUERY.bindparams(tag_id=id).columns(column("child_id")) for id in tag_ids
|
||||
]
|
||||
outp = union_all(*queries)
|
||||
# if only one tag is found return that a list with that tag instead,
|
||||
# in order to make use of the optimisations in __entry_has_all_tags
|
||||
return t if len(t := list(session.scalars(outp))) == 1 else outp
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
|
||||
Reference in New Issue
Block a user