diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index f3472442..9b2d3aa1 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs from pathlib import Path +from typing import TYPE_CHECKING from uuid import uuid4 from warnings import catch_warnings @@ -69,6 +70,10 @@ from .joins import TagEntry, TagParent from .models import Entry, Folder, Namespace, Preferences, Tag, TagAlias, TagColorGroup, ValueType from .visitors import SQLBoolExpressionBuilder +if TYPE_CHECKING: + from sqlalchemy import Select + + logger = structlog.get_logger(__name__) TAG_CHILDREN_QUERY = text(""" @@ -259,7 +264,7 @@ class Library: for k, v in field.items(): # Old tag fields get added as tags if k in LEGACY_TAG_FIELD_IDS: - self.add_tags_to_entry(entry_id=entry.id + 1, tag_ids=v) + self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v) else: self.add_field_to_entry( entry_id=(entry.id + 1), # JSON IDs start at 0 instead of 1 @@ -513,30 +518,49 @@ class Library: self, entry_id: int, with_fields: bool = True, with_tags: bool = True ) -> Entry | None: """Load entry and join with all joins and all tags.""" + # NOTE: TODO: Currently this method makes multiple separate queries to the db and combines + # those into a final Entry object (if using "with" args). This was done due to it being + # much more efficient than the existing join query, however there likely exists a single + # query that can accomplish the same task without exhibiting the same slowdown. with Session(self.engine) as session: - statement = select(Entry).where(Entry.id == entry_id) + tags: set[Tag] | None = None + tag_stmt: Select[tuple[Tag]] + entry_stmt = select(Entry).where(Entry.id == entry_id).limit(1) if with_fields: - statement = ( - statement.outerjoin(Entry.text_fields) + entry_stmt = ( + entry_stmt.outerjoin(Entry.text_fields) .outerjoin(Entry.datetime_fields) .options(selectinload(Entry.text_fields), selectinload(Entry.datetime_fields)) ) + # if with_tags: + # entry_stmt = entry_stmt.outerjoin(Entry.tags).options(selectinload(Entry.tags)) if with_tags: - statement = ( - statement.outerjoin(Entry.tags) - .outerjoin(TagAlias) - .options( - selectinload(Entry.tags).options( - joinedload(Tag.aliases), - joinedload(Tag.parent_tags), - ) + tag_stmt = select(Tag).where( + and_( + TagEntry.tag_id == Tag.id, + TagEntry.entry_id == entry_id, ) ) - entry = session.scalar(statement) + + start_time = time.time() + entry = session.scalar(entry_stmt) + if with_tags: + tags = set(session.scalars(tag_stmt)) # pyright: ignore [reportPossiblyUnboundVariable] + end_time = time.time() + logger.info( + f"[Library] Time it took to get entry: " + f"{format_timespan(end_time-start_time, max_units=5)}", + with_fields=with_fields, + with_tags=with_tags, + ) if not entry: return None session.expunge(entry) make_transient(entry) + + # Recombine the separately queried tags with the base entry object. + if with_tags and tags: + entry.tags = tags return entry def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]: @@ -1089,41 +1113,49 @@ class Library: session.rollback() return None - def add_tags_to_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool: - """Add one or more tags to an entry.""" - tag_ids = [tag_ids] if isinstance(tag_ids, int) else tag_ids + def add_tags_to_entries( + self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int] + ) -> bool: + """Add one or more tags to one or more entries.""" + entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids + tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids with Session(self.engine, expire_on_commit=False) as session: - for tag_id in tag_ids: - try: - session.add(TagEntry(tag_id=tag_id, entry_id=entry_id)) - session.flush() - except IntegrityError: - session.rollback() + for tag_id in tag_ids_: + for entry_id in entry_ids_: + try: + session.add(TagEntry(tag_id=tag_id, entry_id=entry_id)) + session.flush() + except IntegrityError: + session.rollback() try: session.commit() except IntegrityError as e: - logger.warning("[add_tags_to_entry]", warning=e) + logger.warning("[Library][add_tags_to_entries]", warning=e) session.rollback() return False return True - def remove_tags_from_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool: - """Remove one or more tags from an entry.""" + def remove_tags_from_entries( + self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int] + ) -> bool: + """Remove one or more tags from one or more entries.""" + entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids with Session(self.engine, expire_on_commit=False) as session: try: for tag_id in tag_ids_: - tag_entry = session.scalars( - select(TagEntry).where( - and_( - TagEntry.tag_id == tag_id, - TagEntry.entry_id == entry_id, + for entry_id in entry_ids_: + tag_entry = session.scalars( + select(TagEntry).where( + and_( + TagEntry.tag_id == tag_id, + TagEntry.entry_id == entry_id, + ) ) - ) - ).first() - if tag_entry: - session.delete(tag_entry) - session.commit() + ).first() + if tag_entry: + session.delete(tag_entry) + session.flush() session.commit() return True except IntegrityError as e: @@ -1331,7 +1363,7 @@ class Library: value=field.value, ) tag_ids = [tag.id for tag in from_entry.tags] - self.add_tags_to_entry(into_entry.id, tag_ids) + self.add_tags_to_entries(into_entry.id, tag_ids) self.remove_entries([from_entry.id]) @property diff --git a/tagstudio/src/qt/modals/folders_to_tags.py b/tagstudio/src/qt/modals/folders_to_tags.py index d37941ae..0b50fa46 100644 --- a/tagstudio/src/qt/modals/folders_to_tags.py +++ b/tagstudio/src/qt/modals/folders_to_tags.py @@ -75,7 +75,7 @@ def folders_to_tags(library: Library): tag = add_folders_to_tree(library, tree, folders).tag if tag and not entry.has_tag(tag): - library.add_tags_to_entry(entry.id, tag.id) + library.add_tags_to_entries(entry.id, tag.id) logger.info("Done") diff --git a/tagstudio/src/qt/ts_qt.py b/tagstudio/src/qt/ts_qt.py index 3e3c8390..9afd7f16 100644 --- a/tagstudio/src/qt/ts_qt.py +++ b/tagstudio/src/qt/ts_qt.py @@ -952,8 +952,7 @@ class QtDriver(DriverMixin, QObject): self.preview_panel.update_widgets() def add_tags_to_selected_callback(self, tag_ids: list[int]): - for entry_id in self.selected: - self.lib.add_tags_to_entry(entry_id, tag_ids) + self.lib.add_tags_to_entries(self.selected, tag_ids) def delete_files_callback(self, origin_path: str | Path, origin_id: int | None = None): """Callback to send on or more files to the system trash. @@ -1359,7 +1358,7 @@ class QtDriver(DriverMixin, QObject): exists = True if not exists: self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value) - self.lib.add_tags_to_entry(id, self.copy_buffer["tags"]) + self.lib.add_tags_to_entries(id, self.copy_buffer["tags"]) if len(self.selected) > 1: if TAG_ARCHIVED in self.copy_buffer["tags"]: self.update_badges({BadgeType.ARCHIVED: True}, origin_id=0, add_tags=False) @@ -1650,14 +1649,41 @@ class QtDriver(DriverMixin, QObject): the items. Defaults to True. """ item_ids = self.selected if (not origin_id or origin_id in self.selected) else [origin_id] + pending_entries: dict[BadgeType, list[int]] = {} + logger.info( + "[QtDriver][update_badges] Updating ItemThumb badges", + badge_values=badge_values, + origin_id=origin_id, + add_tags=add_tags, + ) for it in self.item_thumbs: if it.item_id in item_ids: for badge_type, value in badge_values.items(): if add_tags: + if not pending_entries.get(badge_type): + pending_entries[badge_type] = [] + pending_entries[badge_type].append(it.item_id) it.toggle_item_tag(it.item_id, value, BADGE_TAGS[badge_type]) it.assign_badge(badge_type, value) + if not add_tags: + return + + logger.info( + "[QtDriver][update_badges] Adding tags to updated entries", + pending_entries=pending_entries, + ) + for badge_type, value in badge_values.items(): + if value: + self.lib.add_tags_to_entries( + pending_entries.get(badge_type, []), BADGE_TAGS[badge_type] + ) + else: + self.lib.remove_tags_from_entries( + pending_entries.get(badge_type, []), BADGE_TAGS[badge_type] + ) + def filter_items(self, filter: FilterState | None = None) -> None: if not self.lib.library_dir: logger.info("Library not loaded") diff --git a/tagstudio/src/qt/widgets/item_thumb.py b/tagstudio/src/qt/widgets/item_thumb.py index 2ec04966..a55660a5 100644 --- a/tagstudio/src/qt/widgets/item_thumb.py +++ b/tagstudio/src/qt/widgets/item_thumb.py @@ -499,15 +499,11 @@ class ItemThumb(FlowWidget): toggle_value: bool, tag_id: int, ): - logger.info("toggle_item_tag", entry_id=entry_id, toggle_value=toggle_value, tag_id=tag_id) - - if toggle_value: - self.lib.add_tags_to_entry(entry_id, tag_id) - else: - self.lib.remove_tags_from_entry(entry_id, tag_id) - - if self.driver.preview_panel.is_open: - self.driver.preview_panel.update_widgets(update_preview=False) + if entry_id in self.driver.selected and self.driver.preview_panel.is_open: + if len(self.driver.selected) == 1: + self.driver.preview_panel.fields.update_toggled_tag(tag_id, toggle_value) + else: + pass def mouseMoveEvent(self, event): # noqa: N802 if event.buttons() is not Qt.MouseButton.LeftButton: diff --git a/tagstudio/src/qt/widgets/preview/field_containers.py b/tagstudio/src/qt/widgets/preview/field_containers.py index 5cc76bd5..88ecd7d9 100644 --- a/tagstudio/src/qt/widgets/preview/field_containers.py +++ b/tagstudio/src/qt/widgets/preview/field_containers.py @@ -114,13 +114,18 @@ class FieldContainers(QWidget): logger.warning("[FieldContainers] Updating Selection", entry_id=entry_id) self.cached_entries = [self.lib.get_entry_full(entry_id)] - entry_ = self.cached_entries[0] - container_len: int = len(entry_.fields) - container_index = 0 + entry = self.cached_entries[0] + self.update_granular(entry.tags, entry.fields, update_badges) + def update_granular( + self, entry_tags: set[Tag], entry_fields: list[BaseField], update_badges: bool = True + ): + """Individually update elements of the item preview.""" + container_len: int = len(entry_fields) + container_index = 0 # Write tag container(s) - if entry_.tags: - categories = self.get_tag_categories(entry_.tags) + if entry_tags: + categories = self.get_tag_categories(entry_tags) for cat, tags in sorted(categories.items(), key=lambda kv: (kv[0] is None, kv)): self.write_tag_container( container_index, tags=tags, category_tag=cat, is_mixed=False @@ -128,10 +133,10 @@ class FieldContainers(QWidget): container_index += 1 container_len += 1 if update_badges: - self.emit_badge_signals({t.id for t in entry_.tags}) + self.emit_badge_signals({t.id for t in entry_tags}) # Write field container(s) - for index, field in enumerate(entry_.fields, start=container_index): + for index, field in enumerate(entry_fields, start=container_index): self.write_container(index, field, is_mixed=False) # Hide leftover container(s) @@ -140,6 +145,17 @@ class FieldContainers(QWidget): if i > (container_len - 1): c.setHidden(True) + def update_toggled_tag(self, tag_id: int, toggle_value: bool): + """Visually add or remove a tag from the item preview without needing to query the db.""" + entry = self.cached_entries[0] + tag = self.lib.get_tag(tag_id) + if not tag: + return + new_tags = ( + entry.tags.union({tag}) if toggle_value else {t for t in entry.tags if t.id != tag_id} + ) + self.update_granular(entry_tags=new_tags, entry_fields=entry.fields, update_badges=False) + def hide_containers(self): """Hide all field and tag containers.""" for c in self.containers: @@ -262,7 +278,7 @@ class FieldContainers(QWidget): tags=tags, ) for entry_id in self.driver.selected: - self.lib.add_tags_to_entry( + self.lib.add_tags_to_entries( entry_id, tag_ids=tags, ) diff --git a/tagstudio/src/qt/widgets/tag_box.py b/tagstudio/src/qt/widgets/tag_box.py index 5f4fb848..953f0759 100644 --- a/tagstudio/src/qt/widgets/tag_box.py +++ b/tagstudio/src/qt/widgets/tag_box.py @@ -101,6 +101,6 @@ class TagBoxWidget(FieldWidget): ) for entry_id in self.driver.selected: - self.driver.lib.remove_tags_from_entry(entry_id, tag_id) + self.driver.lib.remove_tags_from_entries(entry_id, tag_id) self.updated.emit() diff --git a/tagstudio/tests/conftest.py b/tagstudio/tests/conftest.py index 7ad4a36a..4093eee3 100644 --- a/tagstudio/tests/conftest.py +++ b/tagstudio/tests/conftest.py @@ -95,7 +95,7 @@ def library(request): path=pathlib.Path("foo.txt"), fields=lib.default_fields, ) - assert lib.add_tags_to_entry(entry.id, tag.id) + assert lib.add_tags_to_entries(entry.id, tag.id) entry2 = Entry( id=2, @@ -103,7 +103,7 @@ def library(request): path=pathlib.Path("one/two/bar.md"), fields=lib.default_fields, ) - assert lib.add_tags_to_entry(entry2.id, tag2.id) + assert lib.add_tags_to_entries(entry2.id, tag2.id) assert lib.add_entries([entry, entry2]) assert len(lib.tags) == 6 diff --git a/tagstudio/tests/qt/test_field_containers.py b/tagstudio/tests/qt/test_field_containers.py index e501866d..12bffd67 100644 --- a/tagstudio/tests/qt/test_field_containers.py +++ b/tagstudio/tests/qt/test_field_containers.py @@ -119,7 +119,7 @@ def test_meta_tag_category(qt_driver, library, entry_full): panel = PreviewPanel(library, qt_driver) # Ensure the Favorite tag is on entry_full - library.add_tags_to_entry(1, entry_full.id) + library.add_tags_to_entries(1, entry_full.id) # Select the single entry qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False) @@ -151,7 +151,7 @@ def test_custom_tag_category(qt_driver, library, entry_full): ) # Ensure the Favorite tag is on entry_full - library.add_tags_to_entry(1, entry_full.id) + library.add_tags_to_entries(1, entry_full.id) # Select the single entry qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False) diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index eb8c4401..5ce7242d 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -330,8 +330,8 @@ def test_merge_entries(library: Library): tag_0 = library.add_tag(Tag(id=1000, name="tag_0")) tag_1 = library.add_tag(Tag(id=1001, name="tag_1")) tag_2 = library.add_tag(Tag(id=1002, name="tag_2")) - library.add_tags_to_entry(ids[0], [tag_0.id, tag_2.id]) - library.add_tags_to_entry(ids[1], [tag_1.id]) + library.add_tags_to_entries(ids[0], [tag_0.id, tag_2.id]) + library.add_tags_to_entries(ids[1], [tag_1.id]) library.merge_entries(entry_a, entry_b) assert library.has_path_entry(Path("b")) assert not library.has_path_entry(Path("a")) @@ -344,11 +344,11 @@ def test_merge_entries(library: Library): AssertionError() -def test_remove_tag_from_entry(library, entry_full): +def test_remove_tags_from_entries(library, entry_full): removed_tag_id = -1 for tag in entry_full.tags: removed_tag_id = tag.id - library.remove_tags_from_entry(entry_full.id, tag.id) + library.remove_tags_from_entries(entry_full.id, tag.id) entry = next(library.get_entries(with_joins=True)) assert removed_tag_id not in [t.id for t in entry.tags]