From 958ba86ef782e6fa16b305e9bd950a9b45bd99bd Mon Sep 17 00:00:00 2001 From: Jann Stute Date: Wed, 27 Nov 2024 23:50:53 +0100 Subject: [PATCH] rudimentary search field integration --- tagstudio/src/core/library/alchemy/enums.py | 44 +++++++--------- tagstudio/src/core/library/alchemy/library.py | 10 +++- .../src/core/library/alchemy/visitors.py | 50 +++++++++++++++++++ tagstudio/src/qt/ts_qt.py | 9 ++++ 4 files changed, 86 insertions(+), 27 deletions(-) create mode 100644 tagstudio/src/core/library/alchemy/visitors.py diff --git a/tagstudio/src/core/library/alchemy/enums.py b/tagstudio/src/core/library/alchemy/enums.py index ce525019..2ed19c8a 100644 --- a/tagstudio/src/core/library/alchemy/enums.py +++ b/tagstudio/src/core/library/alchemy/enums.py @@ -2,6 +2,9 @@ import enum from dataclasses import dataclass from pathlib import Path +from src.core.query_lang import AST as Query # noqa: N811 +from src.core.query_lang import Parser + class TagColor(enum.IntEnum): DEFAULT = 1 @@ -84,36 +87,27 @@ class FilterState: # a generic query to be parsed query: str | None = None + ast: Query = None + def __post_init__(self): # strip values automatically - if query := (self.query and self.query.strip()): - # parse the value - if ":" in query: - kind, _, value = query.partition(":") - value = value.replace('"', "") - else: - # default to tag search - kind, value = "tag", query - if kind == "tag_id": - self.tag_id = int(value) - elif kind == "tag": - self.tag = value - elif kind == "path": - self.path = value - elif kind == "name": - self.name = value - elif kind == "id": - self.id = int(self.id) if str(self.id).isnumeric() else self.id - elif kind == "filetype": - self.filetype = value - elif kind == "mediatype": - self.mediatype = value + query = None + if self.query: + query = self.query + elif self.tag: + query = self.tag.strip() + self.tag = None + elif self.tag_id: + query = f"tag_id:{self.tag_id}" + self.tag_id = None + elif self.path: + query = f"path:'{str(self.path).strip()}'" + + if query: + self.ast = Parser(query).parse() else: - self.tag = self.tag and self.tag.strip() - self.tag_id = int(self.tag_id) if str(self.tag_id).isnumeric() else self.tag_id - self.path = self.path and str(self.path).strip() self.name = self.name and self.name.strip() self.id = int(self.id) if str(self.id).isnumeric() else self.id diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index f9ba256a..f6215ca7 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -49,6 +49,7 @@ from .fields import ( ) from .joins import TagField, TagSubtag from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType +from .visitors import SQLBoolExpressionBuilder logger = structlog.get_logger(__name__) @@ -417,7 +418,13 @@ class Library: with Session(self.engine, expire_on_commit=False) as session: statement = select(Entry) - if search.tag: + if search.ast: + statement = ( + statement.join(Entry.tag_box_fields) + .join(TagBoxField.tags) + .where(SQLBoolExpressionBuilder().visit(search.ast)) + ) + elif search.tag: SubtagAlias = aliased(Tag) # noqa: N806 statement = ( statement.join(Entry.tag_box_fields) @@ -439,7 +446,6 @@ class Library: .join(TagBoxField.tags) .where(Tag.id == search.tag_id) ) - elif search.id: statement = statement.where(Entry.id == search.id) elif search.name: diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py new file mode 100644 index 00000000..9d6c3b8b --- /dev/null +++ b/tagstudio/src/core/library/alchemy/visitors.py @@ -0,0 +1,50 @@ +from sqlalchemy import and_, or_ +from sqlalchemy.orm import aliased +from sqlalchemy.sql.expression import ColumnExpressionArgument +from src.core.media_types import MediaCategories +from src.core.query_lang import BaseVisitor +from src.core.query_lang.ast import ANDList, Constraint, ConstraintType, ORList, Property + +from .models import Entry, Tag, TagAlias + + +class SQLBoolExpressionBuilder(BaseVisitor): + def visit_or_list(self, node: ORList) -> ColumnExpressionArgument: + return or_(*[self.visit(element) for element in node.elements]) + + def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument: + return and_(*[self.visit(term) for term in node.terms]) + + def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument: + if len(node.properties) != 0: + raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG + + if node.type == ConstraintType.Tag: + return or_( + Tag.name.ilike(node.value), + Tag.shorthand.ilike(node.value), + TagAlias.name.ilike(node.value), + aliased(Tag).name.ilike(node.value), + ) + elif node.type == ConstraintType.TagID: + return Tag.id == int(node.value) + elif node.type == ConstraintType.Path: + return Entry.path.ilike(node.value.replace("*", "%")) # TODO TSQLANG this is broken + elif node.type == ConstraintType.MediaType: + extensions: set[str] = set[str]() + for media_cat in MediaCategories.ALL_CATEGORIES: + if node.value == media_cat.name: + extensions = extensions | media_cat.extensions + break + return Entry.suffix.in_( + map(lambda x: x.replace(".", ""), extensions) + ) # TODO audio doesn't work on mp3 files (might be my library) + elif node.type == ConstraintType.FileType: + return Entry.suffix.ilike( + node.value + ) # TODO TSQLANG this is broken for mp3, but works for png (might be my library) + + raise NotImplementedError("This type of constraint is not implemented yet") + + def visit_property(self, node: Property) -> None: + return diff --git a/tagstudio/src/qt/ts_qt.py b/tagstudio/src/qt/ts_qt.py index f934892a..62d27fae 100644 --- a/tagstudio/src/qt/ts_qt.py +++ b/tagstudio/src/qt/ts_qt.py @@ -65,6 +65,7 @@ from src.core.constants import ( ) from src.core.driver import DriverMixin from src.core.enums import LibraryPrefs, MacroID, SettingItems +from src.core.library.alchemy import Library from src.core.library.alchemy.enums import ( FieldTypeEnum, FilterState, @@ -130,6 +131,8 @@ class QtDriver(DriverMixin, QObject): preview_panel: PreviewPanel + lib: Library + def __init__(self, backend, args): super().__init__() # prevent recursive badges update when multiple items selected @@ -1100,8 +1103,12 @@ class QtDriver(DriverMixin, QObject): if filter: self.filter = dataclasses.replace(self.filter, **dataclasses.asdict(filter)) + # inform user about running search self.main_window.statusbar.showMessage(f'Searching Library: "{self.filter.summary}"') self.main_window.statusbar.repaint() + + # search the library + start_time = time.time() results = self.lib.search_library(self.filter) @@ -1109,6 +1116,8 @@ class QtDriver(DriverMixin, QObject): logger.info("items to render", count=len(results)) end_time = time.time() + + # inform user about completed search if self.filter.summary: # fmt: off self.main_window.statusbar.showMessage(