mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-28 22:01:24 +00:00
perf: optimize query methods and reduce preview panel updates (#794)
* fix(ui): don't update preview for non-selected badge toggles * ui: optimize badge updating NOTE: In theory this approach opens the doors for a visual state mismatch between the tags shown and the tags on the actual entry. The performance increase comes from the fact that `get_entry_full()` isn't called, which is the real bottleneck here. * perf: optimize query methods
This commit is contained in:
committed by
GitHub
parent
a2b9237be4
commit
6b646f8955
@@ -11,6 +11,7 @@ from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from os import makedirs
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
from warnings import catch_warnings
|
||||
|
||||
@@ -69,6 +70,10 @@ from .joins import TagEntry, TagParent
|
||||
from .models import Entry, Folder, Namespace, Preferences, Tag, TagAlias, TagColorGroup, ValueType
|
||||
from .visitors import SQLBoolExpressionBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy import Select
|
||||
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
TAG_CHILDREN_QUERY = text("""
|
||||
@@ -259,7 +264,7 @@ class Library:
|
||||
for k, v in field.items():
|
||||
# Old tag fields get added as tags
|
||||
if k in LEGACY_TAG_FIELD_IDS:
|
||||
self.add_tags_to_entry(entry_id=entry.id + 1, tag_ids=v)
|
||||
self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v)
|
||||
else:
|
||||
self.add_field_to_entry(
|
||||
entry_id=(entry.id + 1), # JSON IDs start at 0 instead of 1
|
||||
@@ -513,30 +518,49 @@ class Library:
|
||||
self, entry_id: int, with_fields: bool = True, with_tags: bool = True
|
||||
) -> Entry | None:
|
||||
"""Load entry and join with all joins and all tags."""
|
||||
# NOTE: TODO: Currently this method makes multiple separate queries to the db and combines
|
||||
# those into a final Entry object (if using "with" args). This was done due to it being
|
||||
# much more efficient than the existing join query, however there likely exists a single
|
||||
# query that can accomplish the same task without exhibiting the same slowdown.
|
||||
with Session(self.engine) as session:
|
||||
statement = select(Entry).where(Entry.id == entry_id)
|
||||
tags: set[Tag] | None = None
|
||||
tag_stmt: Select[tuple[Tag]]
|
||||
entry_stmt = select(Entry).where(Entry.id == entry_id).limit(1)
|
||||
if with_fields:
|
||||
statement = (
|
||||
statement.outerjoin(Entry.text_fields)
|
||||
entry_stmt = (
|
||||
entry_stmt.outerjoin(Entry.text_fields)
|
||||
.outerjoin(Entry.datetime_fields)
|
||||
.options(selectinload(Entry.text_fields), selectinload(Entry.datetime_fields))
|
||||
)
|
||||
# if with_tags:
|
||||
# entry_stmt = entry_stmt.outerjoin(Entry.tags).options(selectinload(Entry.tags))
|
||||
if with_tags:
|
||||
statement = (
|
||||
statement.outerjoin(Entry.tags)
|
||||
.outerjoin(TagAlias)
|
||||
.options(
|
||||
selectinload(Entry.tags).options(
|
||||
joinedload(Tag.aliases),
|
||||
joinedload(Tag.parent_tags),
|
||||
)
|
||||
tag_stmt = select(Tag).where(
|
||||
and_(
|
||||
TagEntry.tag_id == Tag.id,
|
||||
TagEntry.entry_id == entry_id,
|
||||
)
|
||||
)
|
||||
entry = session.scalar(statement)
|
||||
|
||||
start_time = time.time()
|
||||
entry = session.scalar(entry_stmt)
|
||||
if with_tags:
|
||||
tags = set(session.scalars(tag_stmt)) # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"[Library] Time it took to get entry: "
|
||||
f"{format_timespan(end_time-start_time, max_units=5)}",
|
||||
with_fields=with_fields,
|
||||
with_tags=with_tags,
|
||||
)
|
||||
if not entry:
|
||||
return None
|
||||
session.expunge(entry)
|
||||
make_transient(entry)
|
||||
|
||||
# Recombine the separately queried tags with the base entry object.
|
||||
if with_tags and tags:
|
||||
entry.tags = tags
|
||||
return entry
|
||||
|
||||
def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
|
||||
@@ -1089,41 +1113,49 @@ class Library:
|
||||
session.rollback()
|
||||
return None
|
||||
|
||||
def add_tags_to_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool:
|
||||
"""Add one or more tags to an entry."""
|
||||
tag_ids = [tag_ids] if isinstance(tag_ids, int) else tag_ids
|
||||
def add_tags_to_entries(
|
||||
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
|
||||
) -> bool:
|
||||
"""Add one or more tags to one or more entries."""
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
|
||||
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
for tag_id in tag_ids:
|
||||
try:
|
||||
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
session.rollback()
|
||||
for tag_id in tag_ids_:
|
||||
for entry_id in entry_ids_:
|
||||
try:
|
||||
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
session.rollback()
|
||||
try:
|
||||
session.commit()
|
||||
except IntegrityError as e:
|
||||
logger.warning("[add_tags_to_entry]", warning=e)
|
||||
logger.warning("[Library][add_tags_to_entries]", warning=e)
|
||||
session.rollback()
|
||||
return False
|
||||
return True
|
||||
|
||||
def remove_tags_from_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool:
|
||||
"""Remove one or more tags from an entry."""
|
||||
def remove_tags_from_entries(
|
||||
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
|
||||
) -> bool:
|
||||
"""Remove one or more tags from one or more entries."""
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
|
||||
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
for tag_id in tag_ids_:
|
||||
tag_entry = session.scalars(
|
||||
select(TagEntry).where(
|
||||
and_(
|
||||
TagEntry.tag_id == tag_id,
|
||||
TagEntry.entry_id == entry_id,
|
||||
for entry_id in entry_ids_:
|
||||
tag_entry = session.scalars(
|
||||
select(TagEntry).where(
|
||||
and_(
|
||||
TagEntry.tag_id == tag_id,
|
||||
TagEntry.entry_id == entry_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if tag_entry:
|
||||
session.delete(tag_entry)
|
||||
session.commit()
|
||||
).first()
|
||||
if tag_entry:
|
||||
session.delete(tag_entry)
|
||||
session.flush()
|
||||
session.commit()
|
||||
return True
|
||||
except IntegrityError as e:
|
||||
@@ -1331,7 +1363,7 @@ class Library:
|
||||
value=field.value,
|
||||
)
|
||||
tag_ids = [tag.id for tag in from_entry.tags]
|
||||
self.add_tags_to_entry(into_entry.id, tag_ids)
|
||||
self.add_tags_to_entries(into_entry.id, tag_ids)
|
||||
self.remove_entries([from_entry.id])
|
||||
|
||||
@property
|
||||
|
||||
@@ -75,7 +75,7 @@ def folders_to_tags(library: Library):
|
||||
|
||||
tag = add_folders_to_tree(library, tree, folders).tag
|
||||
if tag and not entry.has_tag(tag):
|
||||
library.add_tags_to_entry(entry.id, tag.id)
|
||||
library.add_tags_to_entries(entry.id, tag.id)
|
||||
|
||||
logger.info("Done")
|
||||
|
||||
|
||||
@@ -952,8 +952,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
self.preview_panel.update_widgets()
|
||||
|
||||
def add_tags_to_selected_callback(self, tag_ids: list[int]):
|
||||
for entry_id in self.selected:
|
||||
self.lib.add_tags_to_entry(entry_id, tag_ids)
|
||||
self.lib.add_tags_to_entries(self.selected, tag_ids)
|
||||
|
||||
def delete_files_callback(self, origin_path: str | Path, origin_id: int | None = None):
|
||||
"""Callback to send on or more files to the system trash.
|
||||
@@ -1359,7 +1358,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
exists = True
|
||||
if not exists:
|
||||
self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value)
|
||||
self.lib.add_tags_to_entry(id, self.copy_buffer["tags"])
|
||||
self.lib.add_tags_to_entries(id, self.copy_buffer["tags"])
|
||||
if len(self.selected) > 1:
|
||||
if TAG_ARCHIVED in self.copy_buffer["tags"]:
|
||||
self.update_badges({BadgeType.ARCHIVED: True}, origin_id=0, add_tags=False)
|
||||
@@ -1650,14 +1649,41 @@ class QtDriver(DriverMixin, QObject):
|
||||
the items. Defaults to True.
|
||||
"""
|
||||
item_ids = self.selected if (not origin_id or origin_id in self.selected) else [origin_id]
|
||||
pending_entries: dict[BadgeType, list[int]] = {}
|
||||
|
||||
logger.info(
|
||||
"[QtDriver][update_badges] Updating ItemThumb badges",
|
||||
badge_values=badge_values,
|
||||
origin_id=origin_id,
|
||||
add_tags=add_tags,
|
||||
)
|
||||
for it in self.item_thumbs:
|
||||
if it.item_id in item_ids:
|
||||
for badge_type, value in badge_values.items():
|
||||
if add_tags:
|
||||
if not pending_entries.get(badge_type):
|
||||
pending_entries[badge_type] = []
|
||||
pending_entries[badge_type].append(it.item_id)
|
||||
it.toggle_item_tag(it.item_id, value, BADGE_TAGS[badge_type])
|
||||
it.assign_badge(badge_type, value)
|
||||
|
||||
if not add_tags:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"[QtDriver][update_badges] Adding tags to updated entries",
|
||||
pending_entries=pending_entries,
|
||||
)
|
||||
for badge_type, value in badge_values.items():
|
||||
if value:
|
||||
self.lib.add_tags_to_entries(
|
||||
pending_entries.get(badge_type, []), BADGE_TAGS[badge_type]
|
||||
)
|
||||
else:
|
||||
self.lib.remove_tags_from_entries(
|
||||
pending_entries.get(badge_type, []), BADGE_TAGS[badge_type]
|
||||
)
|
||||
|
||||
def filter_items(self, filter: FilterState | None = None) -> None:
|
||||
if not self.lib.library_dir:
|
||||
logger.info("Library not loaded")
|
||||
|
||||
@@ -499,15 +499,11 @@ class ItemThumb(FlowWidget):
|
||||
toggle_value: bool,
|
||||
tag_id: int,
|
||||
):
|
||||
logger.info("toggle_item_tag", entry_id=entry_id, toggle_value=toggle_value, tag_id=tag_id)
|
||||
|
||||
if toggle_value:
|
||||
self.lib.add_tags_to_entry(entry_id, tag_id)
|
||||
else:
|
||||
self.lib.remove_tags_from_entry(entry_id, tag_id)
|
||||
|
||||
if self.driver.preview_panel.is_open:
|
||||
self.driver.preview_panel.update_widgets(update_preview=False)
|
||||
if entry_id in self.driver.selected and self.driver.preview_panel.is_open:
|
||||
if len(self.driver.selected) == 1:
|
||||
self.driver.preview_panel.fields.update_toggled_tag(tag_id, toggle_value)
|
||||
else:
|
||||
pass
|
||||
|
||||
def mouseMoveEvent(self, event): # noqa: N802
|
||||
if event.buttons() is not Qt.MouseButton.LeftButton:
|
||||
|
||||
@@ -114,13 +114,18 @@ class FieldContainers(QWidget):
|
||||
logger.warning("[FieldContainers] Updating Selection", entry_id=entry_id)
|
||||
|
||||
self.cached_entries = [self.lib.get_entry_full(entry_id)]
|
||||
entry_ = self.cached_entries[0]
|
||||
container_len: int = len(entry_.fields)
|
||||
container_index = 0
|
||||
entry = self.cached_entries[0]
|
||||
self.update_granular(entry.tags, entry.fields, update_badges)
|
||||
|
||||
def update_granular(
|
||||
self, entry_tags: set[Tag], entry_fields: list[BaseField], update_badges: bool = True
|
||||
):
|
||||
"""Individually update elements of the item preview."""
|
||||
container_len: int = len(entry_fields)
|
||||
container_index = 0
|
||||
# Write tag container(s)
|
||||
if entry_.tags:
|
||||
categories = self.get_tag_categories(entry_.tags)
|
||||
if entry_tags:
|
||||
categories = self.get_tag_categories(entry_tags)
|
||||
for cat, tags in sorted(categories.items(), key=lambda kv: (kv[0] is None, kv)):
|
||||
self.write_tag_container(
|
||||
container_index, tags=tags, category_tag=cat, is_mixed=False
|
||||
@@ -128,10 +133,10 @@ class FieldContainers(QWidget):
|
||||
container_index += 1
|
||||
container_len += 1
|
||||
if update_badges:
|
||||
self.emit_badge_signals({t.id for t in entry_.tags})
|
||||
self.emit_badge_signals({t.id for t in entry_tags})
|
||||
|
||||
# Write field container(s)
|
||||
for index, field in enumerate(entry_.fields, start=container_index):
|
||||
for index, field in enumerate(entry_fields, start=container_index):
|
||||
self.write_container(index, field, is_mixed=False)
|
||||
|
||||
# Hide leftover container(s)
|
||||
@@ -140,6 +145,17 @@ class FieldContainers(QWidget):
|
||||
if i > (container_len - 1):
|
||||
c.setHidden(True)
|
||||
|
||||
def update_toggled_tag(self, tag_id: int, toggle_value: bool):
|
||||
"""Visually add or remove a tag from the item preview without needing to query the db."""
|
||||
entry = self.cached_entries[0]
|
||||
tag = self.lib.get_tag(tag_id)
|
||||
if not tag:
|
||||
return
|
||||
new_tags = (
|
||||
entry.tags.union({tag}) if toggle_value else {t for t in entry.tags if t.id != tag_id}
|
||||
)
|
||||
self.update_granular(entry_tags=new_tags, entry_fields=entry.fields, update_badges=False)
|
||||
|
||||
def hide_containers(self):
|
||||
"""Hide all field and tag containers."""
|
||||
for c in self.containers:
|
||||
@@ -262,7 +278,7 @@ class FieldContainers(QWidget):
|
||||
tags=tags,
|
||||
)
|
||||
for entry_id in self.driver.selected:
|
||||
self.lib.add_tags_to_entry(
|
||||
self.lib.add_tags_to_entries(
|
||||
entry_id,
|
||||
tag_ids=tags,
|
||||
)
|
||||
|
||||
@@ -101,6 +101,6 @@ class TagBoxWidget(FieldWidget):
|
||||
)
|
||||
|
||||
for entry_id in self.driver.selected:
|
||||
self.driver.lib.remove_tags_from_entry(entry_id, tag_id)
|
||||
self.driver.lib.remove_tags_from_entries(entry_id, tag_id)
|
||||
|
||||
self.updated.emit()
|
||||
|
||||
@@ -95,7 +95,7 @@ def library(request):
|
||||
path=pathlib.Path("foo.txt"),
|
||||
fields=lib.default_fields,
|
||||
)
|
||||
assert lib.add_tags_to_entry(entry.id, tag.id)
|
||||
assert lib.add_tags_to_entries(entry.id, tag.id)
|
||||
|
||||
entry2 = Entry(
|
||||
id=2,
|
||||
@@ -103,7 +103,7 @@ def library(request):
|
||||
path=pathlib.Path("one/two/bar.md"),
|
||||
fields=lib.default_fields,
|
||||
)
|
||||
assert lib.add_tags_to_entry(entry2.id, tag2.id)
|
||||
assert lib.add_tags_to_entries(entry2.id, tag2.id)
|
||||
|
||||
assert lib.add_entries([entry, entry2])
|
||||
assert len(lib.tags) == 6
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_meta_tag_category(qt_driver, library, entry_full):
|
||||
panel = PreviewPanel(library, qt_driver)
|
||||
|
||||
# Ensure the Favorite tag is on entry_full
|
||||
library.add_tags_to_entry(1, entry_full.id)
|
||||
library.add_tags_to_entries(1, entry_full.id)
|
||||
|
||||
# Select the single entry
|
||||
qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False)
|
||||
@@ -151,7 +151,7 @@ def test_custom_tag_category(qt_driver, library, entry_full):
|
||||
)
|
||||
|
||||
# Ensure the Favorite tag is on entry_full
|
||||
library.add_tags_to_entry(1, entry_full.id)
|
||||
library.add_tags_to_entries(1, entry_full.id)
|
||||
|
||||
# Select the single entry
|
||||
qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False)
|
||||
|
||||
@@ -330,8 +330,8 @@ def test_merge_entries(library: Library):
|
||||
tag_0 = library.add_tag(Tag(id=1000, name="tag_0"))
|
||||
tag_1 = library.add_tag(Tag(id=1001, name="tag_1"))
|
||||
tag_2 = library.add_tag(Tag(id=1002, name="tag_2"))
|
||||
library.add_tags_to_entry(ids[0], [tag_0.id, tag_2.id])
|
||||
library.add_tags_to_entry(ids[1], [tag_1.id])
|
||||
library.add_tags_to_entries(ids[0], [tag_0.id, tag_2.id])
|
||||
library.add_tags_to_entries(ids[1], [tag_1.id])
|
||||
library.merge_entries(entry_a, entry_b)
|
||||
assert library.has_path_entry(Path("b"))
|
||||
assert not library.has_path_entry(Path("a"))
|
||||
@@ -344,11 +344,11 @@ def test_merge_entries(library: Library):
|
||||
AssertionError()
|
||||
|
||||
|
||||
def test_remove_tag_from_entry(library, entry_full):
|
||||
def test_remove_tags_from_entries(library, entry_full):
|
||||
removed_tag_id = -1
|
||||
for tag in entry_full.tags:
|
||||
removed_tag_id = tag.id
|
||||
library.remove_tags_from_entry(entry_full.id, tag.id)
|
||||
library.remove_tags_from_entries(entry_full.id, tag.id)
|
||||
|
||||
entry = next(library.get_entries(with_joins=True))
|
||||
assert removed_tag_id not in [t.id for t in entry.tags]
|
||||
|
||||
Reference in New Issue
Block a user