perf: optimize db queries for preview panel (#942)

* perf: optimize mapping of category->tags

* perf: one less db call for Library.tag_display_name

* fix: include joins in Library.get_tag_hierarchy

* fix: remove category if empty in preview panel

* fix: add missing imports and remove unneeded dict

* fix: add tags that are categories to their own category

* fix: flip parent_id/child_id in get_tag_hierarchy

* fix: prevent trying to save duplicate TagParents
This commit is contained in:
TheBobBobs
2025-08-29 20:42:15 +00:00
committed by GitHub
parent 08d0ba4eee
commit d8919ab283
8 changed files with 84 additions and 99 deletions

View File

@@ -32,6 +32,7 @@ from sqlalchemy import (
desc,
exists,
func,
inspect,
or_,
select,
text,
@@ -43,6 +44,7 @@ from sqlalchemy.orm import (
contains_eager,
joinedload,
make_transient,
noload,
selectinload,
)
@@ -312,13 +314,12 @@ class Library:
return f
return None
def tag_display_name(self, tag_id: int) -> str:
with Session(self.engine) as session:
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
if not tag:
return "<NO TAG>"
def tag_display_name(self, tag: Tag | None) -> str:
if not tag:
return "<NO TAG>"
if tag.disambiguation_id:
if tag.disambiguation_id:
with Session(self.engine) as session:
disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id))
if not disam_tag:
return "<NO DISAM TAG>"
@@ -326,8 +327,8 @@ class Library:
if not disam_name:
disam_name = disam_tag.name
return f"{tag.name} ({disam_name})"
else:
return tag.name
else:
return tag.name
def open_library(
self, library_dir: Path, storage_path: Path | str | None = None
@@ -1544,6 +1545,45 @@ class Library:
return session.scalar(statement)
def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]:
"""Get a dictionary containing tags in `tag_ids` and all of their ancestor tags."""
current_tag_ids: set[int] = set(tag_ids)
all_tag_ids: set[int] = set()
all_tags: dict[int, Tag] = {}
all_tag_parents: dict[int, list[int]] = {}
with Session(self.engine) as session:
while len(current_tag_ids) > 0:
all_tag_ids.update(current_tag_ids)
statement = select(TagParent).where(TagParent.child_id.in_(current_tag_ids))
tag_parents = session.scalars(statement).fetchall()
current_tag_ids.clear()
for tag_parent in tag_parents:
all_tag_parents.setdefault(tag_parent.child_id, []).append(tag_parent.parent_id)
current_tag_ids.add(tag_parent.parent_id)
current_tag_ids = current_tag_ids.difference(all_tag_ids)
statement = select(Tag).where(Tag.id.in_(all_tag_ids))
statement = statement.options(
noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color)
)
tags = session.scalars(statement).fetchall()
for tag in tags:
all_tags[tag.id] = tag
for tag in all_tags.values():
# Sqlalchemy tracks this as a change to the parent_tags field
tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])}
# When calling session.add with this tag instance sqlalchemy will
# attempt to create TagParents that already exist.
state = inspect(tag)
# Prevent sqlalchemy from thinking any fields are different from what's commited
# commited_state contains original values for fields that have changed.
# empty when no fields have changed
state.committed_state.clear()
return all_tags
def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
if parent_id == child_id:
return False

View File

@@ -156,6 +156,9 @@ class Tag(Base):
def __repr__(self) -> str:
return self.__str__()
def __hash__(self) -> int:
return hash(self.id)
def __lt__(self, other) -> bool:
return self.name < other.name

View File

@@ -75,7 +75,7 @@ class TagBoxWidget(TagBoxWidgetView):
edit_modal = PanelModal(
build_tag_panel,
self.__driver.lib.tag_display_name(tag.id),
self.__driver.lib.tag_display_name(tag),
"Edit Tag",
done_callback=self.on_update.emit,
has_save=True,

View File

@@ -62,7 +62,7 @@ class TagDatabasePanel(TagSearchPanel):
message_box = QMessageBox(
QMessageBox.Question, # type: ignore
Translations["tag.remove"],
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)),
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)),
QMessageBox.Ok | QMessageBox.Cancel, # type: ignore
)

View File

@@ -387,7 +387,7 @@ class TagSearchPanel(PanelWidget):
self.edit_modal = PanelModal(
build_tag_panel,
self.lib.tag_display_name(tag.id),
self.lib.tag_display_name(tag),
Translations["tag.edit"],
done_callback=(self.update_tags(self.search_field.text())),
has_save=True,

View File

@@ -32,7 +32,7 @@ class TagBoxWidgetView(FieldWidget):
self.setLayout(self.__root_layout)
def set_tags(self, tags: Iterable[Tag]) -> None:
tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag.id))
tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag))
logger.info("[TagBoxWidget] Tags:", tags=tags)
while self.__root_layout.itemAt(0):
self.__root_layout.takeAt(0).widget().deleteLater() # pyright: ignore[reportOptionalMemberAccess]

View File

@@ -160,96 +160,38 @@ class FieldContainers(QWidget):
c.setHidden(True)
def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]:
"""Get a dictionary of category tags mapped to their respective tags."""
cats: dict[Tag | None, set[Tag]] = {}
cats[None] = set()
"""Get a dictionary of category tags mapped to their respective tags.
base_tag_ids: set[int] = {x.id for x in tags}
exhausted: set[int] = set()
cluster_map: dict[int, set[int]] = {}
Example:
Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to:
"Cartoon Network" -> Johnny Bravo,
"Character" -> "Johnny Bravo",
"TV" -> Johnny Bravo"
"""
hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags)
def add_to_cluster(tag_id: int, p_ids: list[int] | None = None):
"""Maps a Tag's child tags' IDs back to it's parent tag's ID.
Example:
Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to:
"Cartoon Network" -> Johnny Bravo,
"Character" -> "Johnny Bravo",
"TV" -> Johnny Bravo"
"""
tag_obj = unwrap(self.lib.get_tag(tag_id)) # Get full object
if p_ids is None:
p_ids = tag_obj.parent_ids
for p_id in p_ids:
if cluster_map.get(p_id) is None:
cluster_map[p_id] = set()
# If the p_tag has p_tags of its own, recursively link those to the original Tag.
if tag_id not in cluster_map[p_id]:
cluster_map[p_id].add(tag_id)
p_tag = unwrap(self.lib.get_tag(p_id)) # Get full object
if p_tag.parent_ids:
add_to_cluster(
tag_id,
[sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id],
)
exhausted.add(p_id)
exhausted.add(tag_id)
for tag in tags:
add_to_cluster(tag.id)
logger.info("[FieldContainers] Entry Cluster", entry_cluster=exhausted)
logger.info("[FieldContainers] Cluster Map", cluster_map=cluster_map)
# Initialize all categories from parents.
tags_ = {t for tid in exhausted if (t := self.lib.get_tag(tid)) is not None}
for tag in tags_:
categories: dict[Tag | None, set[Tag]] = {None: set()}
for tag in hierarchy_tags.values():
if tag.is_category:
cats[tag] = set()
logger.info("[FieldContainers] Blank Tag Categories", cats=cats)
categories[tag] = set()
for tag in tags:
tag = hierarchy_tags[tag.id]
has_category_parent = False
parent_tags = tag.parent_tags
while len(parent_tags) > 0:
grandparent_tags: set[Tag] = set()
for parent_tag in parent_tags:
if parent_tag in categories:
categories[parent_tag].add(tag)
has_category_parent = True
grandparent_tags.update(parent_tag.parent_tags)
parent_tags = grandparent_tags
if tag.is_category:
categories[tag].add(tag)
elif not has_category_parent:
categories[None].add(tag)
# Add tags to any applicable categories.
added_ids: set[int] = set()
for key in cats:
logger.info("[FieldContainers] Checking category tag key", key=key)
if key:
logger.info(
"[FieldContainers] Key cluster:", key=key, cluster=cluster_map.get(key.id)
)
if final_tags := cluster_map.get(key.id, set()).union([key.id]):
cats[key] = {
t
for tid in final_tags
if tid in base_tag_ids and (t := self.lib.get_tag(tid)) is not None
}
added_ids = added_ids.union({tid for tid in final_tags if tid in base_tag_ids})
# Add remaining tags to None key (general case).
cats[None] = {
t
for tid in base_tag_ids
if tid not in added_ids and (t := self.lib.get_tag(tid)) is not None
}
logger.info(
"[FieldContainers] Key cluster: None, general case!",
general_tags=cats[None],
added=added_ids,
base_tag_ids=base_tag_ids,
)
# Remove unused categories
empty: list[Tag | None] = []
for k, v in list(cats.items()):
if not v:
empty.append(k)
for key in empty:
cats.pop(key, None)
logger.info("[FieldContainers] Tag Categories", categories=cats)
return cats
return dict((c, d) for c, d in categories.items() if len(d) > 0)
def remove_field_prompt(self, name: str) -> str:
return Translations.format("library.field.confirm_remove", name=name)

View File

@@ -266,7 +266,7 @@ class TagWidget(QWidget):
)
if self.lib:
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id)))
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag)))
else:
self.bg_button.setText(escape_text(tag.name))