diff --git a/tagstudio/src/core/library/alchemy/enums.py b/tagstudio/src/core/library/alchemy/enums.py index ffc8b40f..163b1c57 100644 --- a/tagstudio/src/core/library/alchemy/enums.py +++ b/tagstudio/src/core/library/alchemy/enums.py @@ -80,8 +80,6 @@ class FilterState: search_mode: SearchMode = SearchMode.AND # TODO this can be removed? # these should be erased on update - # entry id - id: int | None = None # whole path path: Path | str | None = None # file name @@ -102,7 +100,6 @@ class FilterState: self.ast = Parser(query).parse() else: self.name = self.name and self.name.strip() - self.id = int(self.id) if str(self.id).isnumeric() else self.id if self.page_index is None: # TODO QTLANG can this just be a default value? self.page_index = 0 @@ -112,7 +109,7 @@ class FilterState: @property def summary(self): """Show query summary.""" - return self.name or self.path or self.id + return self.name or self.path @property def limit(self): diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index 7e83a3ff..67859a52 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -308,6 +308,29 @@ class Library: make_transient(entry) return entry + def get_entry_full(self, entry_id: int) -> Entry | None: + """Load entry an join with all joins and all tags.""" + with Session(self.engine) as session: + statement = select(Entry).where(Entry.id == entry_id) + statement = ( + statement.outerjoin(Entry.text_fields) + .outerjoin(Entry.datetime_fields) + .outerjoin(Entry.tag_box_fields) + ) + statement = statement.options( + selectinload(Entry.text_fields), + selectinload(Entry.datetime_fields), + selectinload(Entry.tag_box_fields) + .joinedload(TagBoxField.tags) + .options(selectinload(Tag.aliases), selectinload(Tag.subtags)), + ) + entry = session.scalar(statement) + if not entry: + return None + session.expunge(entry) + make_transient(entry) + return entry + @property def entries_count(self) -> int: with Session(self.engine) as session: @@ -425,8 +448,6 @@ class Library: .outerjoin(TagAlias) .where(SQLBoolExpressionBuilder().visit(search.ast)) ) - elif search.id: - statement = statement.where(Entry.id == search.id) elif search.name: statement = select(Entry).where( and_( @@ -442,11 +463,10 @@ class Library: extensions = self.prefs(LibraryPrefs.EXTENSION_LIST) is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) - if not search.id: # if `id` is set, we don't need to filter by extensions - if extensions and is_exclude_list: - statement = statement.where(Entry.suffix.notin_(extensions)) - elif extensions: - statement = statement.where(Entry.suffix.in_(extensions)) + if extensions and is_exclude_list: + statement = statement.where(Entry.suffix.notin_(extensions)) + elif extensions: + statement = statement.where(Entry.suffix.in_(extensions)) statement = statement.options( selectinload(Entry.text_fields), diff --git a/tagstudio/src/qt/widgets/item_thumb.py b/tagstudio/src/qt/widgets/item_thumb.py index de611d7c..8991ad83 100644 --- a/tagstudio/src/qt/widgets/item_thumb.py +++ b/tagstudio/src/qt/widgets/item_thumb.py @@ -453,9 +453,7 @@ class ItemThumb(FlowWidget): entry, toggle_value, tag_id, _FieldID.TAGS_META.name, create_field=True ) # update the entry - self.driver.frame_content[idx] = self.lib.search_library( - FilterState(id=entry.id) # TODO TSQLANG don't search, get entry directly by id - ).items[0] # self.lib.get_entry(entry.id) + self.driver.frame_content[idx] = self.lib.get_entry(entry.id) self.driver.update_badges(update_items) diff --git a/tagstudio/src/qt/widgets/preview_panel.py b/tagstudio/src/qt/widgets/preview_panel.py index 87be218e..f8ad05ff 100644 --- a/tagstudio/src/qt/widgets/preview_panel.py +++ b/tagstudio/src/qt/widgets/preview_panel.py @@ -286,16 +286,13 @@ class PreviewPanel(QWidget): def update_selected_entry(self, driver: "QtDriver"): for grid_idx in driver.selected: entry = driver.frame_content[grid_idx] - results = self.lib.search_library( - FilterState(id=entry.id) - ) # TODO TSQLANG don't search, get entry directly by id + result = self.lib.get_entry_full(entry.id) logger.info( "found item", - entries=len(results.items), grid_idx=grid_idx, lookup_id=entry.id, ) - self.driver.frame_content[grid_idx] = results[0] + self.driver.frame_content[grid_idx] = result def remove_field_prompt(self, name: str) -> str: return f'Are you sure you want to remove field "{name}"?' @@ -555,16 +552,13 @@ class PreviewPanel(QWidget): # TODO - Entry reload is maybe not necessary for grid_idx in self.driver.selected: entry = self.driver.frame_content[grid_idx] - results = self.lib.search_library( - FilterState(id=entry.id) - ) # TODO TSQLANG don't search, get entry by directly by ID + result = self.lib.get_entry_full(entry.id) logger.info( "found item", - entries=len(results.items), grid_idx=grid_idx, lookup_id=entry.id, ) - self.driver.frame_content[grid_idx] = results[0] + self.driver.frame_content[grid_idx] = result if len(self.driver.selected) == 1: # 1 Selected Entry diff --git a/tagstudio/tests/conftest.py b/tagstudio/tests/conftest.py index 39c3f753..baae929e 100644 --- a/tagstudio/tests/conftest.py +++ b/tagstudio/tests/conftest.py @@ -123,7 +123,7 @@ def entry_min(library): @pytest.fixture -def entry_full(library): +def entry_full(library: Library): yield next(library.get_entries(with_joins=True)) diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index 15289da2..8bfb7d0f 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -147,11 +147,11 @@ def test_tag_search(library): ) -def test_get_entry(library, entry_min): +def test_get_entry(library: Library, entry_min): assert entry_min.id - results = library.search_library(FilterState(id=entry_min.id)) - assert len(results) == results.total_count == 1 - assert results[0].tags + result = library.get_entry_full(entry_min.id) + assert result + assert result.tags def test_entries_count(library): @@ -186,7 +186,7 @@ def test_add_field_to_entry(library): assert len(entry.tag_box_fields) == 3 -def test_add_field_tag(library, entry_full, generate_tag): +def test_add_field_tag(library: Library, entry_full, generate_tag): # Given tag_name = "xxx" tag = generate_tag(tag_name) @@ -196,8 +196,8 @@ def test_add_field_tag(library, entry_full, generate_tag): library.add_field_tag(entry_full, tag, tag_field.type_key) # Then - results = library.search_library(FilterState(id=entry_full.id)) - tag_field = results[0].tag_box_fields[0] + results = library.get_entry_full(entry_full.id) + tag_field = results.tag_box_fields[0] assert [x.name for x in tag_field.tags if x.name == tag_name] @@ -347,7 +347,8 @@ def test_update_entry_with_multiple_identical_fields(library, entry_full): assert entry.text_fields[1].value == "new value" -def test_mirror_entry_fields(library, entry_full): +def test_mirror_entry_fields(library: Library, entry_full): + # new entry target_entry = Entry( folder=library.folder, path=Path("xxx"), @@ -360,16 +361,19 @@ def test_mirror_entry_fields(library, entry_full): ], ) + # insert new entry and get id entry_id = library.add_entries([target_entry])[0] - results = library.search_library(FilterState(id=entry_id)) - new_entry = results[0] + # get new entry from library + new_entry = library.get_entry_full(entry_id) + # mirror fields onto new entry library.mirror_entry_fields(new_entry, entry_full) - results = library.search_library(FilterState(id=entry_id)) - entry = results[0] + # get new entry from library again + entry = library.get_entry_full(entry_id) + # make sure fields are there after getting it from the library again assert len(entry.fields) == 4 assert {x.type_key for x in entry.fields} == { _FieldID.TITLE.name, @@ -416,12 +420,10 @@ def test_search_file_name(library, query_name, has_result): (222, 0), ], ) -def test_search_entry_id(library, query_name, has_result): - results = library.search_library( - FilterState(id=query_name), - ) +def test_search_entry_id(library: Library, query_name: int, has_result): + result = library.get_entry(query_name) - assert results.total_count == has_result + assert (result is not None) == has_result def test_update_field_order(library, entry_full):