perf: optimize page loading by referencing entries by ID (#954)

* refactor: search_library now returns ids instead of entries

* perf: optimize update_thumbs
This commit is contained in:
TheBobBobs
2025-08-04 23:04:07 +00:00
committed by GitHub
parent c2261d5b83
commit e115443811
14 changed files with 105 additions and 93 deletions

View File

@@ -7,7 +7,7 @@ import re
import shutil
import time
import unicodedata
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from datetime import UTC, datetime
from os import makedirs
@@ -170,26 +170,26 @@ class SearchResult:
Attributes:
total_count(int): total number of items for given query, might be different than len(items).
items(list[Entry]): for current page (size matches filter.page_size).
ids(list[int]): for current page (size matches filter.page_size).
"""
total_count: int
items: list[Entry]
ids: list[int]
def __bool__(self) -> bool:
"""Boolean evaluation for the wrapper.
:return: True if there are items in the result.
:return: True if there are ids in the result.
"""
return self.total_count > 0
def __len__(self) -> int:
"""Return the total number of items in the result."""
return len(self.items)
"""Return the total number of ids in the result."""
return len(self.ids)
def __getitem__(self, index: int) -> Entry:
"""Allow to access items via index directly on the wrapper."""
return self.items[index]
def __getitem__(self, index: int) -> int:
"""Allow to access ids via index directly on the wrapper."""
return self.ids[index]
@dataclass
@@ -611,7 +611,7 @@ class Library:
def apply_db9_filename_population(self, session: Session):
"""Populate the filename column introduced in DB_VERSION 9."""
for entry in self.get_entries():
for entry in self.all_entries():
session.merge(entry).filename = entry.path.name
session.commit()
logger.info("[Library][Migration] Populated filename column in entries table")
@@ -692,6 +692,12 @@ class Library:
entry.tags = tags
return entry
def get_entries(self, entry_ids: Iterable[int]) -> list[Entry]:
    """Load the entries with the given IDs, in the order the IDs were given.

    Args:
        entry_ids: IDs of the entries to load. Any iterable is accepted,
            including one-shot iterators; it is consumed exactly once.

    Returns:
        One Entry per element of ``entry_ids``, in the same order. An ID that
        appears more than once yields the same Entry object more than once.

    Raises:
        KeyError: if an ID in ``entry_ids`` has no matching row.
    """
    # The IDs are needed twice (IN clause and result ordering), so materialize
    # them up front; a generator would otherwise be exhausted by the first use.
    ordered_ids = list(entry_ids)
    if not ordered_ids:
        # Nothing to look up; skip the database round trip.
        return []

    with Session(self.engine) as session:
        statement = select(Entry).where(Entry.id.in_(ordered_ids))
        entries_by_id = {entry.id: entry for entry in session.scalars(statement)}

    # The query returns rows in arbitrary order; re-establish the caller's order.
    return [entries_by_id[entry_id] for entry_id in ordered_ids]
def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
"""Load entry and join with all joins and all tags."""
with Session(self.engine) as session:
@@ -746,12 +752,25 @@ class Library:
make_transient(entry)
return entry
def get_tag_entries(
    self, tag_ids: Iterable[int], entry_ids: Iterable[int]
) -> dict[int, set[int]]:
    """Return a dict of tag_id -> set of entry_ids that carry that tag.

    Args:
        tag_ids: Tags to look up. Every tag ID given here is a key in the
            result, mapped to an empty set if no matching entry has it.
        entry_ids: Entries to restrict the lookup to.

    Both arguments may be any iterable, including one-shot iterators; each is
    consumed exactly once.
    """
    # tag_ids is needed both to seed the result and for the IN clause, so
    # materialize it (and entry_ids, for symmetry) before first use.
    wanted_tag_ids = list(tag_ids)
    wanted_entry_ids = list(entry_ids)

    tag_entries: dict[int, set[int]] = {tag_id: set() for tag_id in wanted_tag_ids}
    if not wanted_tag_ids or not wanted_entry_ids:
        # No possible matches; skip the database round trip.
        return tag_entries

    with Session(self.engine) as session:
        statement = select(TagEntry).where(
            and_(
                TagEntry.tag_id.in_(wanted_tag_ids),
                TagEntry.entry_id.in_(wanted_entry_ids),
            )
        )
        for tag_entry in session.scalars(statement):
            tag_entries[tag_entry.tag_id].add(tag_entry.entry_id)
    return tag_entries
@property
def entries_count(self) -> int:
    """Number of entries in the library, obtained with a COUNT over Entry.id."""
    count_statement = select(func.count(Entry.id))
    with Session(self.engine) as session:
        total = session.scalar(count_statement)
        return total
def get_entries(self, with_joins: bool = False) -> Iterator[Entry]:
def all_entries(self, with_joins: bool = False) -> Iterator[Entry]:
"""Load entries without joins."""
with Session(self.engine) as session:
stmt = select(Entry)
@@ -868,7 +887,7 @@ class Library:
assert self.engine
with Session(self.engine, expire_on_commit=False) as session:
statement = select(Entry)
statement = select(Entry.id, func.count().over())
if search.ast:
start_time = time.time()
@@ -886,13 +905,6 @@ class Library:
elif extensions:
statement = statement.where(Entry.suffix.in_(extensions))
statement = statement.distinct(Entry.id)
start_time = time.time()
query_count = select(func.count()).select_from(statement.alias("entries"))
count_all: int = session.execute(query_count).scalar() or 0
end_time = time.time()
logger.info(f"finished counting ({format_timespan(end_time - start_time)})")
sort_on: ColumnExpressionArgument = Entry.id
match search.sorting_mode:
case SortingModeEnum.DATE_ADDED:
@@ -913,13 +925,18 @@ class Library:
)
start_time = time.time()
items = session.scalars(statement).fetchall()
rows = session.execute(statement).fetchall()
ids = []
count = 0
for row in rows:
id, count = row._tuple()
ids.append(id)
end_time = time.time()
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")
res = SearchResult(
total_count=count_all,
items=list(items),
total_count=count,
ids=ids,
)
session.expunge_all()

View File

@@ -54,12 +54,13 @@ class DupeRegistry:
results = self.library.search_library(
BrowsingState.from_path(path_relative), 500
)
entries = self.library.get_entries(results.ids)
if not results:
# file not in library
continue
files.append(results[0])
files.append(entries[0])
if not len(files) > 1:
# only one file in the group, nothing to do
@@ -79,7 +80,7 @@ class DupeRegistry:
)
for i, entries in enumerate(self.groups):
remove_ids = [x.id for x in entries[1:]]
remove_ids = entries[1:]
logger.info("Removing entries group", ids=remove_ids)
self.library.remove_entries(remove_ids)
yield i - 1 # The -1 waits for the next step to finish

View File

@@ -27,7 +27,7 @@ class MissingRegistry:
"""Track the number of entries that point to an invalid filepath."""
logger.info("[refresh_missing_files] Refreshing missing files...")
self.missing_file_entries = []
for i, entry in enumerate(self.library.get_entries()):
for i, entry in enumerate(self.library.all_entries()):
full_path = self.library.library_dir / entry.path
if not full_path.exists() or not full_path.is_file():
self.missing_file_entries.append(entry)

View File

@@ -70,7 +70,7 @@ def folders_to_tags(library: Library):
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)
for entry in library.get_entries():
for entry in library.all_entries():
folders = entry.path.parts[0:-1]
if not folders:
continue
@@ -125,7 +125,7 @@ def generate_preview_data(library: Library) -> BranchData:
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)
for entry in library.get_entries():
for entry in library.all_entries():
folders = entry.path.parts[0:-1]
if not folders:
continue

View File

@@ -1437,7 +1437,8 @@ class QtDriver(DriverMixin, QObject):
logger.info("[QtDriver] Loading Entries...")
# TODO: The full entries with joins don't need to be grabbed here.
# Use a method that only selects the frame content but doesn't include the joins.
entries: list[Entry] = list(self.lib.get_entries_full(self.frame_content))
entries = self.lib.get_entries(self.frame_content)
tag_entries = self.lib.get_tag_entries([TAG_ARCHIVED, TAG_FAVORITE], self.frame_content)
logger.info("[QtDriver] Building Filenames...")
filenames: list[Path] = [self.lib.library_dir / e.path for e in entries]
logger.info("[QtDriver] Done! Processing ItemThumbs...")
@@ -1483,27 +1484,8 @@ class QtDriver(DriverMixin, QObject):
(time.time(), filenames[index], base_size, ratio, is_loading, is_grid_thumb),
)
)
item_thumb.assign_badge(BadgeType.ARCHIVED, entry.is_archived)
item_thumb.assign_badge(BadgeType.FAVORITE, entry.is_favorite)
item_thumb.update_clickable(
clickable=(
lambda checked=False, item_id=entry.id: self.toggle_item_selection(
item_id,
append=(
QGuiApplication.keyboardModifiers()
== Qt.KeyboardModifier.ControlModifier
),
bridge=(
QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ShiftModifier
),
)
)
)
item_thumb.delete_action.triggered.connect(
lambda checked=False, f=filenames[index], e_id=entry.id: self.delete_files_callback(
f, e_id
)
)
item_thumb.assign_badge(BadgeType.ARCHIVED, entry.id in tag_entries[TAG_ARCHIVED])
item_thumb.assign_badge(BadgeType.FAVORITE, entry.id in tag_entries[TAG_FAVORITE])
# Restore Selected Borders
is_selected = item_thumb.item_id in self.selected
@@ -1588,7 +1570,7 @@ class QtDriver(DriverMixin, QObject):
)
# update page content
self.frame_content = [item.id for item in results.items]
self.frame_content = results.ids
self.update_thumbs()
# update pagination

View File

@@ -14,7 +14,7 @@ from warnings import catch_warnings
import structlog
from PIL import Image, ImageQt
from PySide6.QtCore import QEvent, QMimeData, QSize, Qt, QUrl
from PySide6.QtGui import QAction, QDrag, QEnterEvent, QMouseEvent, QPixmap
from PySide6.QtGui import QAction, QDrag, QEnterEvent, QGuiApplication, QMouseEvent, QPixmap
from PySide6.QtWidgets import (
QBoxLayout,
QCheckBox,
@@ -321,7 +321,16 @@ class ItemThumb(FlowWidget):
self.base_layout.addWidget(self.thumb_container)
self.base_layout.addWidget(self.file_label)
self.thumb_button.clicked.connect(
lambda: self.driver.toggle_item_selection(
self.item_id,
append=(QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ControlModifier),
bridge=(QGuiApplication.keyboardModifiers() == Qt.KeyboardModifier.ShiftModifier),
)
)
self.delete_action.triggered.connect(
lambda: self.driver.delete_files_callback(self.opener.filepath, self.item_id)
)
self.set_mode(mode)
@property

View File

@@ -136,5 +136,6 @@ class ThumbButton(QPushButtonWrapper):
return super().leaveEvent(event)
def set_selected(self, value: bool) -> None: # noqa: N802
self.selected = value
self.repaint()
if value != self.selected:
self.selected = value
self.repaint()

View File

@@ -134,12 +134,12 @@ def search_library() -> Library:
@pytest.fixture
def entry_min(library):
yield next(library.get_entries())
yield next(library.all_entries())
@pytest.fixture
def entry_full(library: Library):
yield next(library.get_entries(with_joins=True))
yield next(library.all_entries(with_joins=True))
@pytest.fixture
@@ -168,7 +168,7 @@ def qt_driver(qtbot, library, library_dir: Path):
driver.lib = library
# TODO - downsize this method and use it
# driver.start()
driver.frame_content = list(library.get_entries())
driver.frame_content = list(library.all_entries())
yield driver

View File

@@ -3,5 +3,5 @@ from tagstudio.qt.modals.folders_to_tags import folders_to_tags
def test_folders_to_tags(library):
folders_to_tags(library)
entry = [x for x in library.get_entries(with_joins=True) if "bar.md" in str(x.path)][0]
entry = [x for x in library.all_entries(with_joins=True) if "bar.md" in str(x.path)][0]
assert {x.name for x in entry.tags} == {"two", "bar"}

View File

@@ -29,4 +29,5 @@ def test_refresh_missing_files(library: Library):
# `bar.md` should be relinked to new correct path
results = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
assert results[0].path == Path("bar.md")
entries = library.get_entries(results.ids)
assert entries[0].path == Path("bar.md")

View File

@@ -54,7 +54,7 @@ def test_add_tag_to_selection_single(qt_driver, library, entry_full):
panel.field_containers_widget.add_tags_to_selected(2000)
# Then reload entry
refreshed_entry = next(library.get_entries(with_joins=True))
refreshed_entry = next(library.all_entries(with_joins=True))
assert {t.id for t in refreshed_entry.tags} == {1000, 2000}
@@ -71,13 +71,13 @@ def test_add_same_tag_to_selection_single(qt_driver, library, entry_full):
panel.field_containers_widget.add_tags_to_selected(1000)
# Then reload entry
refreshed_entry = next(library.get_entries(with_joins=True))
refreshed_entry = next(library.all_entries(with_joins=True))
assert {t.id for t in refreshed_entry.tags} == {1000}
def test_add_tag_to_selection_multiple(qt_driver, library):
panel = PreviewPanel(library, qt_driver)
all_entries = library.get_entries(with_joins=True)
all_entries = library.all_entries(with_joins=True)
# We want to verify that tag 1000 is on some, but not all entries already.
tag_present_on_some: bool = False
@@ -93,7 +93,7 @@ def test_add_tag_to_selection_multiple(qt_driver, library):
assert tag_absent_on_some
# Select the multiple entries
for i, e in enumerate(library.get_entries(with_joins=True), start=0):
for i, e in enumerate(library.all_entries(with_joins=True), start=0):
qt_driver.toggle_item_selection(e.id, append=(True if i == 0 else False), bridge=False) # noqa: SIM210
panel.set_selection(qt_driver.selected)
@@ -101,7 +101,7 @@ def test_add_tag_to_selection_multiple(qt_driver, library):
panel.field_containers_widget.add_tags_to_selected(1000)
# Then reload all entries and recheck the presence of tag 1000
refreshed_entries = library.get_entries(with_joins=True)
refreshed_entries = library.all_entries(with_joins=True)
tag_present_on_some: bool = False
tag_absent_on_some: bool = False

View File

@@ -68,7 +68,7 @@ if TYPE_CHECKING:
def test_browsing_state_update(qt_driver: "QtDriver"):
# Given
for entry in qt_driver.lib.get_entries(with_joins=True):
for entry in qt_driver.lib.all_entries(with_joins=True):
thumb = ItemThumb(ItemType.ENTRY, qt_driver.lib, qt_driver, (100, 100))
qt_driver.item_thumbs.append(thumb)
qt_driver.frame_content.append(entry)

View File

@@ -193,7 +193,7 @@ def test_remove_tag(library: Library, generate_tag):
@pytest.mark.parametrize("is_exclude", [True, False])
def test_search_filter_extensions(library: Library, is_exclude: bool):
# Given
entries = list(library.get_entries())
entries = list(library.all_entries())
assert len(entries) == 2, entries
library.set_prefs(LibraryPrefs.IS_EXCLUDE_LIST, is_exclude)
@@ -201,18 +201,19 @@ def test_search_filter_extensions(library: Library, is_exclude: bool):
# When
results = library.search_library(BrowsingState.show_all(), page_size=500)
entries = library.get_entries(results.ids)
# Then
assert results.total_count == 1
assert len(results) == 1
entry = results[0]
entry = entries[0]
assert (entry.path.suffix == ".txt") == is_exclude
def test_search_library_case_insensitive(library: Library):
# Given
entries = list(library.get_entries(with_joins=True))
entries = list(library.all_entries(with_joins=True))
assert len(entries) == 2, entries
entry = entries[0]
@@ -228,7 +229,7 @@ def test_search_library_case_insensitive(library: Library):
assert results.total_count == 1
assert len(results) == 1
assert results[0].id == entry.id
assert results[0] == entry.id
def test_preferences(library: Library):
@@ -241,7 +242,7 @@ def test_remove_entry_field(library: Library, entry_full):
library.remove_entry_field(title_field, [entry_full.id])
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert not entry.text_fields
@@ -257,7 +258,7 @@ def test_remove_field_entry_with_multiple_field(library: Library, entry_full):
library.remove_entry_field(title_field, [entry_full.id])
# Then one field should remain
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert len(entry.text_fields) == 1
@@ -270,7 +271,7 @@ def test_update_entry_field(library: Library, entry_full):
"new value",
)
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert entry.text_fields[0].value == "new value"
@@ -290,7 +291,7 @@ def test_update_entry_with_multiple_identical_fields(library: Library, entry_ful
)
# Then only one should be updated
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert entry.text_fields[0].value == ""
assert entry.text_fields[1].value == "new value"
@@ -378,7 +379,7 @@ def test_remove_tags_from_entries(library: Library, entry_full):
removed_tag_id = tag.id
library.remove_tags_from_entries(entry_full.id, tag.id)
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert removed_tag_id not in [t.id for t in entry.tags]
@@ -417,7 +418,7 @@ def test_update_field_order(library: Library, entry_full):
)
# Then
entry = next(library.get_entries(with_joins=True))
entry = next(library.all_entries(with_joins=True))
assert entry.text_fields[0].position == 0
assert entry.text_fields[0].value == "first"
assert entry.text_fields[1].position == 1
@@ -445,59 +446,59 @@ def test_library_prefs_multiple_identical_vals():
def test_path_search_ilike(library: Library):
results = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
assert results.total_count == 1
assert len(results.items) == 1
assert len(results.ids) == 1
def test_path_search_like(library: Library):
results = library.search_library(BrowsingState.from_path("BAR.MD"), page_size=500)
assert results.total_count == 0
assert len(results.items) == 0
assert len(results.ids) == 0
def test_path_search_default_with_sep(library: Library):
results = library.search_library(BrowsingState.from_path("one/two"), page_size=500)
assert results.total_count == 1
assert len(results.items) == 1
assert len(results.ids) == 1
def test_path_search_glob_after(library: Library):
results = library.search_library(BrowsingState.from_path("foo*"), page_size=500)
assert results.total_count == 1
assert len(results.items) == 1
assert len(results.ids) == 1
def test_path_search_glob_in_front(library: Library):
results = library.search_library(BrowsingState.from_path("*bar.md"), page_size=500)
assert results.total_count == 1
assert len(results.items) == 1
assert len(results.ids) == 1
def test_path_search_glob_both_sides(library: Library):
results = library.search_library(BrowsingState.from_path("*one/two*"), page_size=500)
assert results.total_count == 1
assert len(results.items) == 1
assert len(results.ids) == 1
# TODO: deduplicate this code with pytest parametrisation or a for loop
def test_path_search_ilike_glob_equality(library: Library):
results_ilike = library.search_library(BrowsingState.from_path("one/two"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*one/two*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*bar.md*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("bar"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*bar*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*bar.md*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
@@ -505,29 +506,29 @@ def test_path_search_ilike_glob_equality(library: Library):
def test_path_search_like_glob_equality(library: Library):
results_ilike = library.search_library(BrowsingState.from_path("ONE/two"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*ONE/two*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("BAR.MD"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*BAR.MD*"), page_size=500)
assert [e.id for e in results_ilike.items] == [e.id for e in results_glob.items]
assert results_ilike.ids == results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("BAR.MD"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*bar.md*"), page_size=500)
assert [e.id for e in results_ilike.items] != [e.id for e in results_glob.items]
assert results_ilike.ids != results_glob.ids
results_ilike, results_glob = None, None
results_ilike = library.search_library(BrowsingState.from_path("bar.md"), page_size=500)
results_glob = library.search_library(BrowsingState.from_path("*BAR.MD*"), page_size=500)
assert [e.id for e in results_ilike.items] != [e.id for e in results_glob.items]
assert results_ilike.ids != results_glob.ids
results_ilike, results_glob = None, None
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("md", 1), ("txt", 1), ("png", 0)])
def test_filetype_search(library: Library, filetype, num_of_filetype):
results = library.search_library(BrowsingState.from_filetype(filetype), page_size=500)
assert len(results.items) == num_of_filetype
assert len(results.ids) == num_of_filetype
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("png", 2), ("apng", 1), ("ng", 0)])
@@ -535,10 +536,10 @@ def test_filetype_return_one_filetype(file_mediatypes_library: Library, filetype
results = file_mediatypes_library.search_library(
BrowsingState.from_filetype(filetype), page_size=500
)
assert len(results.items) == num_of_filetype
assert len(results.ids) == num_of_filetype
@pytest.mark.parametrize(["mediatype", "num_of_mediatype"], [("plaintext", 2), ("image", 0)])
def test_mediatype_search(library: Library, mediatype, num_of_mediatype):
results = library.search_library(BrowsingState.from_mediatype(mediatype), page_size=500)
assert len(results.items) == num_of_mediatype
assert len(results.ids) == num_of_mediatype

View File

@@ -8,7 +8,7 @@ from tagstudio.core.query_lang.util import ParsingError
def verify_count(lib: Library, query: str, count: int):
results = lib.search_library(BrowsingState.from_search_query(query), page_size=500)
assert results.total_count == count
assert len(results.items) == count
assert len(results.ids) == count
@pytest.mark.parametrize(