mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-02 08:09:13 +00:00
feat: optimise AND queries (#679)
* feat: optimise tag constraint
* feat: use fewer subqueries
* refactoring: __entry_satisfies_ast was unnecessary
* feat: reduce time consumption of counting total results massively
* feat: log the time it takes to fetch the results
* Revert "feat: reduce time consumption of counting total results massively"
This reverts commit 30af514681.
* feat: log the time it takes to count the results
* feat: optimise __entry_has_all_tags
This commit is contained in:
@@ -596,8 +596,11 @@ class Library:
|
||||
statement = statement.where(Entry.suffix.in_(extensions))
|
||||
|
||||
statement = statement.distinct(Entry.id)
|
||||
start_time = time.time()
|
||||
query_count = select(func.count()).select_from(statement.alias("entries"))
|
||||
count_all: int = session.execute(query_count).scalar()
|
||||
end_time = time.time()
|
||||
logger.info(f"finished counting ({format_timespan(end_time-start_time)})")
|
||||
|
||||
sort_on: ColumnExpressionArgument = Entry.id
|
||||
match search.sorting_mode:
|
||||
@@ -613,9 +616,14 @@ class Library:
|
||||
query_full=str(statement.compile(compile_kwargs={"literal_binds": True})),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
items = session.scalars(statement).fetchall()
|
||||
end_time = time.time()
|
||||
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")
|
||||
|
||||
res = SearchResult(
|
||||
total_count=count_all,
|
||||
items=list(session.scalars(statement)),
|
||||
items=list(items),
|
||||
)
|
||||
|
||||
session.expunge_all()
|
||||
|
||||
@@ -5,12 +5,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, distinct, func, or_, select, text
|
||||
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
from src.core.query_lang import BaseVisitor
|
||||
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
|
||||
from src.core.query_lang.ast import ANDList, Constraint, ConstraintType, Not, ORList, Property
|
||||
|
||||
from .joins import TagEntry
|
||||
from .models import Entry, Tag, TagAlias
|
||||
@@ -33,7 +32,7 @@ WITH RECURSIVE ChildTags AS (
|
||||
FROM tag_parents tp
|
||||
INNER JOIN ChildTags c ON tp.child_id = c.child_id
|
||||
)
|
||||
SELECT * FROM ChildTags;
|
||||
SELECT child_id FROM ChildTags;
|
||||
""") # noqa: E501
|
||||
|
||||
|
||||
@@ -44,17 +43,17 @@ def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
|
||||
return [item]
|
||||
|
||||
|
||||
class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
def __init__(self, lib: Library) -> None:
|
||||
super().__init__()
|
||||
self.lib = lib
|
||||
|
||||
def visit_or_list(self, node: ORList) -> ColumnExpressionArgument:
|
||||
def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
|
||||
return or_(*[self.visit(element) for element in node.elements])
|
||||
|
||||
def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
|
||||
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
|
||||
tag_ids: list[int] = []
|
||||
bool_expressions: list[ColumnExpressionArgument] = []
|
||||
bool_expressions: list[ColumnElement[bool]] = []
|
||||
|
||||
# Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately
|
||||
for term in node.terms:
|
||||
@@ -74,7 +73,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
tag_ids.append(ids[0])
|
||||
continue
|
||||
|
||||
bool_expressions.append(self.__entry_satisfies_ast(term))
|
||||
bool_expressions.append(self.visit(term))
|
||||
|
||||
# If there are at least two tag ids use a relational division query
|
||||
# to efficiently check all of them
|
||||
@@ -88,15 +87,15 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
|
||||
return and_(*bool_expressions)
|
||||
|
||||
def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
|
||||
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
|
||||
"""Returns a Boolean Expression that is true, if the Entry satisfies the constraint."""
|
||||
if len(node.properties) != 0:
|
||||
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
|
||||
|
||||
if node.type == ConstraintType.Tag:
|
||||
return Entry.tags.any(Tag.id.in_(self.__get_tag_ids(node.value)))
|
||||
return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value))
|
||||
elif node.type == ConstraintType.TagID:
|
||||
return Entry.tags.any(Tag.id == int(node.value))
|
||||
return self.__entry_matches_tag_ids([int(node.value)])
|
||||
elif node.type == ConstraintType.Path:
|
||||
return Entry.path.op("GLOB")(node.value)
|
||||
elif node.type == ConstraintType.MediaType:
|
||||
@@ -120,8 +119,17 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def visit_property(self, node: Property) -> None:
|
||||
raise NotImplementedError("This should never be reached!")
|
||||
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
|
||||
return ~self.__entry_satisfies_ast(node.child)
|
||||
def visit_not(self, node: Not) -> ColumnElement[bool]:
|
||||
return ~self.visit(node.child)
|
||||
|
||||
def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]:
|
||||
"""Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501
|
||||
return (
|
||||
select(1)
|
||||
.correlate(Entry)
|
||||
.where(and_(TagEntry.entry_id == Entry.id, TagEntry.tag_id.in_(tag_ids)))
|
||||
.exists()
|
||||
)
|
||||
|
||||
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
|
||||
"""Given a tag name find the ids of all tags that this name could refer to."""
|
||||
@@ -146,24 +154,17 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
|
||||
return outp
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
# Relational Division Query
|
||||
return Entry.id.in_(
|
||||
select(Entry.id)
|
||||
.outerjoin(TagEntry)
|
||||
select(TagEntry.entry_id)
|
||||
.where(TagEntry.tag_id.in_(tag_ids))
|
||||
.group_by(Entry.id)
|
||||
.group_by(TagEntry.entry_id)
|
||||
.having(func.count(distinct(TagEntry.tag_id)) == len(tag_ids))
|
||||
)
|
||||
|
||||
def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the partial query."""
|
||||
return self.__entry_satisfies_expression(self.visit(partial_query))
|
||||
|
||||
def __entry_satisfies_expression(
|
||||
self, expr: ColumnExpressionArgument
|
||||
) -> BinaryExpression[bool]:
|
||||
def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the column expression.
|
||||
|
||||
Executed on: Entry ⟕ TagEntry (Entry LEFT OUTER JOIN TagEntry).
|
||||
|
||||
Reference in New Issue
Block a user