mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-01 15:49:09 +00:00
feat: use less subqueries
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, distinct, exists, func, or_, select, text
|
||||
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
from src.core.query_lang import BaseVisitor
|
||||
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
|
||||
@@ -39,17 +38,17 @@ def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
|
||||
return [item]
|
||||
|
||||
|
||||
class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
def __init__(self, lib: Library) -> None:
|
||||
super().__init__()
|
||||
self.lib = lib
|
||||
|
||||
def visit_or_list(self, node: ORList) -> ColumnExpressionArgument:
|
||||
def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
|
||||
return or_(*[self.visit(element) for element in node.elements])
|
||||
|
||||
def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
|
||||
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
|
||||
tag_ids: list[int] = []
|
||||
bool_expressions: list[ColumnExpressionArgument] = []
|
||||
bool_expressions: list[ColumnElement[bool]] = []
|
||||
|
||||
# Search for TagID / unambigous Tag Constraints and store the respective tag ids seperately
|
||||
for term in node.terms:
|
||||
@@ -77,7 +76,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
|
||||
return and_(*bool_expressions)
|
||||
|
||||
def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
|
||||
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
|
||||
if len(node.properties) != 0:
|
||||
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
|
||||
|
||||
@@ -110,10 +109,10 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def visit_property(self, node: Property) -> None:
|
||||
raise NotImplementedError("This should never be reached!")
|
||||
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
|
||||
def visit_not(self, node: Not) -> ColumnElement[bool]:
|
||||
return ~self.__entry_satisfies_ast(node.child)
|
||||
|
||||
def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnExpressionArgument:
|
||||
def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]:
|
||||
"""Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501
|
||||
return (
|
||||
select(1)
|
||||
@@ -145,7 +144,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
|
||||
return outp
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
# Relational Division Query
|
||||
return Entry.id.in_(
|
||||
@@ -157,13 +156,11 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
.having(func.count(distinct(TagField.tag_id)) == len(tag_ids))
|
||||
)
|
||||
|
||||
def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]:
|
||||
def __entry_satisfies_ast(self, partial_query: AST) -> ColumnElement[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the partial query."""
|
||||
return self.__entry_satisfies_expression(self.visit(partial_query))
|
||||
return self.visit(partial_query)
|
||||
|
||||
def __entry_satisfies_expression(
|
||||
self, expr: ColumnExpressionArgument
|
||||
) -> BinaryExpression[bool]:
|
||||
def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the column expression."""
|
||||
return Entry.id.in_(
|
||||
select(Entry.id).outerjoin(Entry.tag_box_fields).outerjoin(TagField).where(expr)
|
||||
|
||||
Reference in New Issue
Block a user