feat: instead of hardcoding child tag ids into main query, include subquery

This commit is contained in:
Jann Stute
2024-12-30 23:40:14 +01:00
parent b79115915d
commit 2615e7dab4

View File

@@ -1,9 +1,13 @@
from typing import TYPE_CHECKING
import structlog
from sqlalchemy import and_, distinct, func, or_, select, text
from sqlalchemy import and_, column, distinct, func, or_, select, text, union_all
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
from sqlalchemy.sql.expression import (
BinaryExpression,
ColumnExpressionArgument,
CompoundSelect,
)
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
from src.core.query_lang import BaseVisitor
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
@@ -28,7 +32,7 @@ WITH RECURSIVE Subtags AS (
FROM tag_subtags ts
INNER JOIN Subtags s ON ts.child_id = s.child_id
)
SELECT * FROM Subtags;
SELECT child_id FROM Subtags
""") # noqa: E501
@@ -59,7 +63,10 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
tag_ids.append(int(term.value))
continue
case ConstraintType.Tag:
if len(ids := self.__get_tag_ids(term.value)) == 1:
if (
isinstance((ids := self.__get_tag_ids(term.value)), list)
and len(ids) == 1
):
tag_ids.append(ids[0])
continue
@@ -113,7 +120,9 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
# Visitor hook for `Not` nodes: builds the condition for the node's child via
# __entry_satisfies_ast and returns its negation (the `~` operator).
def visit_not(self, node: Not) -> ColumnExpressionArgument:
return ~self.__entry_satisfies_ast(node.child)
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
def __get_tag_ids(
self, tag_name: str, include_children: bool = True
) -> list[int] | CompoundSelect:
"""Given a tag name find the ids of all tags that this name could refer to."""
with Session(self.lib.engine) as session:
tag_ids = list(
@@ -131,10 +140,13 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
)
if not include_children:
return tag_ids
outp = []
for tag_id in tag_ids:
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
return outp
queries = [
CHILDREN_QUERY.bindparams(tag_id=id).columns(column("child_id")) for id in tag_ids
]
outp = union_all(*queries)
# if only one tag is found, return a list with that tag instead,
# in order to make use of the optimisations in __entry_has_all_tags
return t if len(t := list(session.scalars(outp))) == 1 else outp
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""