From eda7f5200f61e27b5757343da0a147f4c15fa3f6 Mon Sep 17 00:00:00 2001 From: Jann Stute Date: Thu, 28 Nov 2024 13:34:56 +0100 Subject: [PATCH] refactoring --- tagstudio/src/core/library/alchemy/enums.py | 69 +++++++++++-------- tagstudio/src/core/library/alchemy/library.py | 38 +--------- tagstudio/src/core/utils/dupe_files.py | 2 +- tagstudio/src/qt/ts_qt.py | 14 ++-- tagstudio/src/qt/widgets/item_thumb.py | 2 +- tagstudio/src/qt/widgets/tag_box.py | 2 +- tagstudio/tests/macros/test_missing_files.py | 2 +- tagstudio/tests/qt/test_qt_driver.py | 4 +- tagstudio/tests/test_filter_state.py | 8 +-- tagstudio/tests/test_library.py | 26 +++---- 10 files changed, 73 insertions(+), 94 deletions(-) diff --git a/tagstudio/src/core/library/alchemy/enums.py b/tagstudio/src/core/library/alchemy/enums.py index 14e9667f..ffc8b40f 100644 --- a/tagstudio/src/core/library/alchemy/enums.py +++ b/tagstudio/src/core/library/alchemy/enums.py @@ -1,9 +1,9 @@ import enum -from dataclasses import dataclass +from dataclasses import dataclass, replace from pathlib import Path from src.core.query_lang import AST as Query # noqa: N811 -from src.core.query_lang import Parser +from src.core.query_lang import Constraint, ConstraintType, Parser class TagColor(enum.IntEnum): @@ -77,27 +77,17 @@ class FilterState: # these should remain page_index: int | None = None page_size: int | None = None - search_mode: SearchMode = SearchMode.AND # TODO - actually implement this + search_mode: SearchMode = SearchMode.AND # TODO this can be removed? # these should be erased on update - # tag name - tag: str | None = None - # tag ID - tag_id: int | None = None - # entry id id: int | None = None # whole path path: Path | str | None = None # file name name: str | None = None - # file type - filetype: str | None = None - mediatype: str | None = None - - # a generic query to be parsed - query: str | None = None + # Abstract Syntax Tree Of the current Search Query ast: Query = None def __post_init__(self): @@ -105,34 +95,24 @@ class FilterState: query = None - if self.query is not None: - query = self.query - elif self.tag is not None: - query = self.tag.strip() - self.tag = None - elif self.tag_id is not None: - query = f"tag_id:{self.tag_id}" - self.tag_id = None - elif self.path is not None: + if self.path is not None: query = f"path:'{str(self.path).strip()}'" - self.query = query - - if query: + if query is not None: self.ast = Parser(query).parse() else: self.name = self.name and self.name.strip() self.id = int(self.id) if str(self.id).isnumeric() else self.id - if self.page_index is None: + if self.page_index is None: # TODO QTLANG can this just be a default value? self.page_index = 0 - if self.page_size is None: + if self.page_size is None: # TODO QTLANG can this just be a default value? self.page_size = 500 @property def summary(self): """Show query summary.""" - return self.query or self.tag or self.name or self.tag_id or self.path or self.id + return self.name or self.path or self.id @property def limit(self): @@ -142,6 +122,37 @@ class FilterState: def offset(self): return self.page_size * self.page_index + @classmethod + def show_all(cls) -> "FilterState": + return FilterState() + + @classmethod + def from_search_query(cls, search_query: str) -> "FilterState": + return cls(ast=Parser(search_query).parse()) + + @classmethod + def from_tag_id(cls, tag_id: int | str) -> "FilterState": + return cls(ast=Constraint(ConstraintType.TagID, str(tag_id), [])) + + @classmethod + def from_path(cls, path: Path | str) -> "FilterState": + return cls(ast=Constraint(ConstraintType.Path, str(path).strip(), [])) + + @classmethod + def from_mediatype(cls, mediatype: str) -> "FilterState": + return cls(ast=Constraint(ConstraintType.MediaType, mediatype, [])) + + @classmethod + def from_filetype(cls, filetype: str) -> "FilterState": + return cls(ast=Constraint(ConstraintType.FileType, filetype, [])) + + @classmethod + def from_tag_name(cls, tag_name: str) -> "FilterState": + return cls(ast=Constraint(ConstraintType.Tag, tag_name, [])) + + def with_page_size(self, page_size: int) -> "FilterState": + return replace(self, page_size=page_size) + class FieldTypeEnum(enum.Enum): TEXT_LINE = "Text Line" diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index e5260a55..7e83a3ff 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -425,28 +425,6 @@ class Library: .outerjoin(TagAlias) .where(SQLBoolExpressionBuilder().visit(search.ast)) ) - elif search.tag: - SubtagAlias = aliased(Tag) # noqa: N806 - statement = ( - statement.join(Entry.tag_box_fields) - .join(TagBoxField.tags) - .outerjoin(Tag.aliases) - .outerjoin(SubtagAlias, Tag.subtags) - .where( - or_( - Tag.name.ilike(search.tag), - Tag.shorthand.ilike(search.tag), - TagAlias.name.ilike(search.tag), - SubtagAlias.name.ilike(search.tag), - ) - ) - ) - elif search.tag_id: - statement = ( - statement.join(Entry.tag_box_fields) - .join(TagBoxField.tags) - .where(Tag.id == search.tag_id) - ) elif search.id: statement = statement.where(Entry.id == search.id) elif search.name: @@ -460,18 +438,6 @@ class Library: elif search.path: search_str = str(search.path).replace("*", "%") statement = statement.where(Entry.path.ilike(search_str)) - elif search.filetype: - statement = statement.where(Entry.suffix.ilike(f"{search.filetype}")) - elif search.mediatype: - extensions: set[str] = set[str]() - for media_cat in MediaCategories.ALL_CATEGORIES: - if search.mediatype == media_cat.name: - extensions = extensions | media_cat.extensions - break - # just need to map it to search db - suffixes do not have '.' - statement = statement.where( - Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions)) - ) extensions = self.prefs(LibraryPrefs.EXTENSION_LIST) is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) @@ -490,7 +456,9 @@ class Library: .options(selectinload(Tag.aliases), selectinload(Tag.subtags)), ) - query_count = select(func.count()).select_from(statement.alias("entries")) + query_count = select(func.count()).select_from( + statement.alias("entries") + ) # TODO this should count the number of *unique* results count_all: int = session.execute(query_count).scalar() statement = statement.limit(search.limit).offset(search.offset) diff --git a/tagstudio/src/core/utils/dupe_files.py b/tagstudio/src/core/utils/dupe_files.py index 2d0a074b..3c1d55d1 100644 --- a/tagstudio/src/core/utils/dupe_files.py +++ b/tagstudio/src/core/utils/dupe_files.py @@ -50,7 +50,7 @@ class DupeRegistry: continue results = self.library.search_library( - FilterState(path=path_relative), + FilterState.from_path(path_relative), ) if not results: diff --git a/tagstudio/src/qt/ts_qt.py b/tagstudio/src/qt/ts_qt.py index 62d27fae..11fc9809 100644 --- a/tagstudio/src/qt/ts_qt.py +++ b/tagstudio/src/qt/ts_qt.py @@ -141,7 +141,7 @@ class QtDriver(DriverMixin, QObject): self.rm: ResourceManager = ResourceManager() self.args = args self.frame_content = [] - self.filter = FilterState() + self.filter = FilterState.show_all() self.pages_count = 0 self.scrollbar_pos = 0 @@ -469,7 +469,7 @@ class QtDriver(DriverMixin, QObject): ] self.item_thumbs: list[ItemThumb] = [] self.thumb_renderers: list[ThumbRenderer] = [] - self.filter = FilterState() + self.filter = FilterState.show_all() self.init_library_window() path_result = self.evaluate_path(self.args.open) @@ -510,13 +510,17 @@ class QtDriver(DriverMixin, QObject): # Search Button search_button: QPushButton = self.main_window.searchButton search_button.clicked.connect( - lambda: self.filter_items(FilterState(query=self.main_window.searchField.text())) + lambda: self.filter_items( + FilterState.from_search_query(self.main_window.searchField.text()) + ) ) # Search Field search_field: QLineEdit = self.main_window.searchField search_field.returnPressed.connect( # TODO - parse search field for filters - lambda: self.filter_items(FilterState(query=self.main_window.searchField.text())) + lambda: self.filter_items( + FilterState.from_search_query(self.main_window.searchField.text()) + ) ) # Search Type Selector search_type_selector: QComboBox = self.main_window.comboBox_2 @@ -1142,7 +1146,7 @@ class QtDriver(DriverMixin, QObject): def set_search_type(self, mode: SearchMode = SearchMode.AND): self.filter_items( - FilterState( + FilterState( # TODO TSQLANG deal with this search_mode=mode, path=self.main_window.searchField.text(), ) diff --git a/tagstudio/src/qt/widgets/item_thumb.py b/tagstudio/src/qt/widgets/item_thumb.py index 35206609..de611d7c 100644 --- a/tagstudio/src/qt/widgets/item_thumb.py +++ b/tagstudio/src/qt/widgets/item_thumb.py @@ -455,7 +455,7 @@ class ItemThumb(FlowWidget): # update the entry self.driver.frame_content[idx] = self.lib.search_library( FilterState(id=entry.id) # TODO TSQLANG don't search, get entry directly by id - ).items[0] + ).items[0] # self.lib.get_entry(entry.id) self.driver.update_badges(update_items) diff --git a/tagstudio/src/qt/widgets/tag_box.py b/tagstudio/src/qt/widgets/tag_box.py index 3116618a..d3b3c0bf 100755 --- a/tagstudio/src/qt/widgets/tag_box.py +++ b/tagstudio/src/qt/widgets/tag_box.py @@ -99,7 +99,7 @@ class TagBoxWidget(FieldWidget): tag_widget.on_click.connect( lambda tag_id=tag.id: ( self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"), - self.driver.filter_items(FilterState(query=f"tag_id:{tag_id}")), + self.driver.filter_items(FilterState.from_tag_id(tag_id)), ) ) diff --git a/tagstudio/tests/macros/test_missing_files.py b/tagstudio/tests/macros/test_missing_files.py index e90c0077..213aa18a 100644 --- a/tagstudio/tests/macros/test_missing_files.py +++ b/tagstudio/tests/macros/test_missing_files.py @@ -26,5 +26,5 @@ def test_refresh_missing_files(library: Library): assert list(registry.fix_missing_files()) == [1, 2] # `bar.md` should be relinked to new correct path - results = library.search_library(FilterState(path="bar.md")) + results = library.search_library(FilterState.from_path("bar.md")) assert results[0].path == pathlib.Path("bar.md") diff --git a/tagstudio/tests/qt/test_qt_driver.py b/tagstudio/tests/qt/test_qt_driver.py index d405030c..a8a484d1 100644 --- a/tagstudio/tests/qt/test_qt_driver.py +++ b/tagstudio/tests/qt/test_qt_driver.py @@ -79,7 +79,7 @@ def test_library_state_update(qt_driver): assert len(qt_driver.frame_content) == 2 # filter by tag - state = FilterState(tag="foo", page_size=10) + state = FilterState.from_tag_name("foo").with_page_size(10) qt_driver.filter_items(state) assert qt_driver.filter.page_size == 10 assert len(qt_driver.frame_content) == 1 @@ -94,7 +94,7 @@ def test_library_state_update(qt_driver): assert list(entry.tags)[0].name == "foo" # When state property is changed, previous one is overwritten - state = FilterState(path="*bar.md") + state = FilterState.from_path("*bar.md") qt_driver.filter_items(state) assert len(qt_driver.frame_content) == 1 entry = qt_driver.frame_content[0] diff --git a/tagstudio/tests/test_filter_state.py b/tagstudio/tests/test_filter_state.py index f97f5f32..cd3c7c85 100644 --- a/tagstudio/tests/test_filter_state.py +++ b/tagstudio/tests/test_filter_state.py @@ -2,13 +2,13 @@ import pytest from src.core.library.alchemy.enums import FilterState -def test_filter_state_query(): +def test_filter_state_query(): # TODO TSQLANG can this test be removed? # Given query = "tag:foo" - state = FilterState(query=query) + state = FilterState.from_search_query(query) # When - assert state.tag == "foo" + # assert state.tag == "foo" @pytest.mark.parametrize( @@ -21,7 +21,7 @@ def test_filter_state_query(): ("id", int), ], ) -def test_filter_state_attrs_compare(attribute, comparator): +def test_filter_state_attrs_compare(attribute, comparator): # TODO TSQLANG rework this test # When state = FilterState(**{attribute: "2"}) diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index bab51ef3..15289da2 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -115,7 +115,7 @@ def test_library_search(library, generate_tag, entry_full): tag = list(entry_full.tags)[0] results = library.search_library( - FilterState(tag=tag.name), + FilterState.from_tag_name(tag.name), ) assert results.total_count == 1 @@ -159,11 +159,7 @@ def test_entries_count(library): new_ids = library.add_entries(entries) assert len(new_ids) == 10 - results = library.search_library( - FilterState( - page_size=5, - ) - ) + results = library.search_library(FilterState.show_all().with_page_size(5)) assert results.total_count == 12 assert len(results) == 5 @@ -234,7 +230,7 @@ def test_search_filter_extensions(library, is_exclude): # When results = library.search_library( - FilterState(), + FilterState.show_all(), ) # Then @@ -255,7 +251,7 @@ def test_search_library_case_insensitive(library): # When results = library.search_library( - FilterState(tag=tag.name.upper()), + FilterState.from_tag_name(tag.name.upper()), ) # Then @@ -285,7 +281,7 @@ def test_save_windows_path(library, generate_tag): # library.add_tag(tag) library.add_field_tag(entry, tag, create_field=True) - results = library.search_library(FilterState(tag=tag_name)) + results = library.search_library(FilterState.from_tag_name(tag_name)) assert results # path should be saved in posix format @@ -474,36 +470,36 @@ def test_library_prefs_multiple_identical_vals(): def test_path_search_glob_after(library: Library): - results = library.search_library(FilterState(path="foo*")) + results = library.search_library(FilterState.from_path("foo*")) assert results.total_count == 1 assert len(results.items) == 1 def test_path_search_glob_in_front(library: Library): - results = library.search_library(FilterState(path="*bar.md")) + results = library.search_library(FilterState.from_path("*bar.md")) assert results.total_count == 1 assert len(results.items) == 1 def test_path_search_glob_both_sides(library: Library): - results = library.search_library(FilterState(path="*one/two*")) + results = library.search_library(FilterState.from_path("*one/two*")) assert results.total_count == 1 assert len(results.items) == 1 @pytest.mark.parametrize(["filetype", "num_of_filetype"], [("md", 1), ("txt", 1), ("png", 0)]) def test_filetype_search(library, filetype, num_of_filetype): - results = library.search_library(FilterState(filetype=filetype)) + results = library.search_library(FilterState.from_filetype(filetype)) assert len(results.items) == num_of_filetype @pytest.mark.parametrize(["filetype", "num_of_filetype"], [("png", 2), ("apng", 1), ("ng", 0)]) def test_filetype_return_one_filetype(file_mediatypes_library, filetype, num_of_filetype): - results = file_mediatypes_library.search_library(FilterState(filetype=filetype)) + results = file_mediatypes_library.search_library(FilterState.from_filetype(filetype)) assert len(results.items) == num_of_filetype @pytest.mark.parametrize(["mediatype", "num_of_mediatype"], [("plaintext", 2), ("image", 0)]) def test_mediatype_search(library, mediatype, num_of_mediatype): - results = library.search_library(FilterState(mediatype=mediatype)) + results = library.search_library(FilterState.from_mediatype(mediatype)) assert len(results.items) == num_of_mediatype