mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-05-25 02:02:38 +00:00
perf: Optimize searching tags with DB indexes (#1129)
* perf: create sqlite indexes for common columns
* perf: optimize Library.search_tags
* fix(tag_search): do ordering before applying limit
* tag_search: order shorter tag names first
* update tests
* cleanup
* tag_search: use same sorting order when returning all tags
* use dict for deduplicating tags
* fix(tag_search): return descendants instead of ancestors
* perf(tag_search): remove slow calls to method `Library.tags`
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import unicodedata
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
@@ -571,6 +572,20 @@ class Library:
|
||||
if loaded_db_version < 200:
|
||||
self.__apply_db200_migrations(session)
|
||||
|
||||
session.execute(
|
||||
text("CREATE INDEX IF NOT EXISTS idx_tags_name_shorthand ON tags (name, shorthand)")
|
||||
)
|
||||
session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_tag_parents_child_id ON tag_parents (child_id)"
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_tag_entries_entry_id ON tag_entries (entry_id)"
|
||||
)
|
||||
)
|
||||
|
||||
# Update DB_VERSION
|
||||
if loaded_db_version < DB_VERSION:
|
||||
self.set_version(DB_VERSION_CURRENT_KEY, DB_VERSION)
|
||||
@@ -1178,55 +1193,71 @@ class Library:
|
||||
|
||||
return res
|
||||
|
||||
def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]:
|
||||
def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]:
|
||||
"""Return a list of Tag records matching the query."""
|
||||
if limit <= 0:
|
||||
limit = sys.maxsize
|
||||
|
||||
name = name or ""
|
||||
name = name.lower()
|
||||
|
||||
def sort_key(text: str):
|
||||
priority = text.startswith(name)
|
||||
p_ordering = len(text) if priority else sys.maxsize
|
||||
return (not priority, p_ordering, text)
|
||||
|
||||
with Session(self.engine) as session:
|
||||
query = select(Tag).outerjoin(TagAlias).order_by(func.lower(Tag.name))
|
||||
query = query.options(
|
||||
selectinload(Tag.parent_tags),
|
||||
selectinload(Tag.aliases),
|
||||
)
|
||||
if limit > 0:
|
||||
query = query.limit(limit)
|
||||
query = select(Tag.id, Tag.name)
|
||||
|
||||
if limit > 0 and not name:
|
||||
query = query.order_by(Tag.name).limit(limit)
|
||||
|
||||
if name:
|
||||
query = query.where(
|
||||
or_(
|
||||
Tag.name.icontains(name),
|
||||
Tag.shorthand.icontains(name),
|
||||
TagAlias.name.icontains(name),
|
||||
)
|
||||
)
|
||||
|
||||
direct_tags = set(session.scalars(query))
|
||||
ancestor_tag_ids: list[Tag] = []
|
||||
for tag in direct_tags:
|
||||
ancestor_tag_ids.extend(
|
||||
list(session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag.id}))
|
||||
)
|
||||
tags = list(session.execute(query))
|
||||
|
||||
ancestor_tags = session.scalars(
|
||||
select(Tag)
|
||||
.where(Tag.id.in_(ancestor_tag_ids))
|
||||
.options(selectinload(Tag.parent_tags), selectinload(Tag.aliases))
|
||||
)
|
||||
if name:
|
||||
query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name))
|
||||
tags.extend(session.execute(query))
|
||||
|
||||
res = [
|
||||
direct_tags,
|
||||
{at for at in ancestor_tags if at not in direct_tags},
|
||||
]
|
||||
tags.sort(key=lambda t: sort_key(t[1]))
|
||||
# Use order from Tag.name or TagAlias.name depending on which comes first for each tag.
|
||||
# Value=0 to avoid unnecessary copying of tag names.
|
||||
tag_ids = list(dict((id, 0) for id, _ in tags).keys())
|
||||
|
||||
logger.info(
|
||||
"searching tags",
|
||||
search=name,
|
||||
limit=limit,
|
||||
statement=str(query),
|
||||
results=len(res),
|
||||
results=len(tag_ids),
|
||||
)
|
||||
tag_ids = tag_ids[:limit]
|
||||
|
||||
session.expunge_all()
|
||||
all_ids = set(tag_ids)
|
||||
for tag_id in tag_ids:
|
||||
if len(all_ids) >= limit:
|
||||
break
|
||||
for id in session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag_id}):
|
||||
all_ids.add(id)
|
||||
if len(all_ids) >= limit:
|
||||
break
|
||||
|
||||
return res
|
||||
hierarchy = self.get_tag_hierarchy(all_ids)
|
||||
|
||||
direct_tags = [hierarchy.pop(id) for id in tag_ids]
|
||||
|
||||
all_ids.difference_update(tag_ids)
|
||||
descendant_tags = [hierarchy.pop(id) for id in all_ids]
|
||||
descendant_tags.sort(key=lambda t: sort_key(t.name))
|
||||
|
||||
return direct_tags, descendant_tags
|
||||
|
||||
def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool:
|
||||
"""Set the path field of an entry.
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from warnings import catch_warnings
|
||||
|
||||
@@ -107,9 +106,6 @@ class TagSearchPanel(PanelWidget):
|
||||
self.limit_combobox.addItems([str(x) for x in TagSearchPanel._limit_items])
|
||||
self.limit_combobox.setCurrentIndex(TagSearchPanel._default_limit_idx)
|
||||
self.limit_combobox.currentIndexChanged.connect(self.update_limit)
|
||||
self.previous_limit: int = (
|
||||
TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
|
||||
)
|
||||
self.limit_layout.addWidget(self.limit_combobox)
|
||||
self.limit_layout.addStretch(1)
|
||||
|
||||
@@ -218,32 +214,13 @@ class TagSearchPanel(PanelWidget):
|
||||
self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater()
|
||||
self.create_button_in_layout = False
|
||||
|
||||
# Get results for the search query
|
||||
query_lower = "" if not query else query.lower()
|
||||
# Only use the tag limit if it's an actual number (aka not "All Tags")
|
||||
tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
|
||||
tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit)
|
||||
if self.exclude:
|
||||
tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude}
|
||||
tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude}
|
||||
direct_tags, descendant_tags = self.lib.search_tags(name=query, limit=tag_limit)
|
||||
|
||||
# Sort and prioritize the results
|
||||
results_0 = list(tag_results[0])
|
||||
results_0.sort(key=lambda tag: tag.name.lower())
|
||||
results_1 = list(tag_results[1])
|
||||
results_1.sort(key=lambda tag: tag.name.lower())
|
||||
raw_results = list(results_0 + results_1)
|
||||
priority_results: set[Tag] = set()
|
||||
all_results: list[Tag] = []
|
||||
all_results = [t for t in direct_tags if t.id not in self.exclude]
|
||||
all_results.extend(t for t in descendant_tags if t.id not in self.exclude)
|
||||
|
||||
if query and query.strip():
|
||||
for tag in raw_results:
|
||||
if tag.name.lower().startswith(query_lower):
|
||||
priority_results.add(tag)
|
||||
|
||||
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
|
||||
r for r in raw_results if r not in priority_results
|
||||
]
|
||||
if tag_limit > 0:
|
||||
all_results = all_results[:tag_limit]
|
||||
|
||||
@@ -255,15 +232,11 @@ class TagSearchPanel(PanelWidget):
|
||||
self.first_tag_id = None
|
||||
|
||||
# Update every tag widget with the new search result data
|
||||
norm_previous = self.previous_limit if self.previous_limit > 0 else len(self.lib.tags)
|
||||
norm_limit = tag_limit if tag_limit > 0 else len(self.lib.tags)
|
||||
range_limit = max(norm_previous, norm_limit)
|
||||
for i in range(0, range_limit):
|
||||
tag = None
|
||||
with contextlib.suppress(IndexError):
|
||||
tag = all_results[i]
|
||||
self.set_tag_widget(tag=tag, index=i)
|
||||
self.previous_limit = tag_limit
|
||||
for i in range(0, len(all_results)):
|
||||
tag = all_results[i]
|
||||
self.set_tag_widget(tag, i)
|
||||
for i in range(len(all_results), self.scroll_layout.count()):
|
||||
self.set_tag_widget(None, i)
|
||||
|
||||
# Add back the "Create & Add" button
|
||||
if query and query.strip():
|
||||
@@ -326,6 +299,9 @@ class TagSearchPanel(PanelWidget):
|
||||
|
||||
def update_limit(self, index: int):
|
||||
logger.info("[TagSearchPanel] Updating tag limit")
|
||||
if TagSearchPanel.cur_limit_idx == index:
|
||||
return
|
||||
|
||||
TagSearchPanel.cur_limit_idx = index
|
||||
|
||||
if index < len(self._limit_items) - 1:
|
||||
@@ -337,9 +313,6 @@ class TagSearchPanel(PanelWidget):
|
||||
if index != self.limit_combobox.currentIndex():
|
||||
self.limit_combobox.setCurrentIndex(index)
|
||||
|
||||
if self.previous_limit == TagSearchPanel.tag_limit:
|
||||
return
|
||||
|
||||
self.update_tags(self.search_field.text())
|
||||
|
||||
def on_return(self, text: str):
|
||||
|
||||
@@ -130,10 +130,10 @@ def test_library_search(library: Library, entry_full: Entry):
|
||||
def test_tag_search(library: Library):
|
||||
tag = library.tags[0]
|
||||
|
||||
assert library.search_tags(tag.name.lower())
|
||||
assert library.search_tags(tag.name.upper())
|
||||
assert library.search_tags(tag.name[2:-2])
|
||||
assert library.search_tags(tag.name * 2) == [set(), set()]
|
||||
assert library.search_tags(tag.name.lower())[0]
|
||||
assert library.search_tags(tag.name.upper())[0]
|
||||
assert library.search_tags(tag.name[2:-2])[0]
|
||||
assert library.search_tags(tag.name * 2) == ([], [])
|
||||
|
||||
|
||||
def test_get_entry(library: Library, entry_min: Entry):
|
||||
|
||||
Reference in New Issue
Block a user