mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-01 07:39:10 +00:00
perf: optimize db queries for preview panel (#942)
* perf: optimize mapping of category->tags * perf: one less db call for Library.tag_display_name * fix: include joins in Library.get_tag_hierarchy * fix: remove category if empty in preview panel * fix: add missing imports and remove unneeded dict * fix: add tags that are categories to their own category * fix: flip parent_id/child_id in get_tag_hierarchy * fix: prevent trying to save duplicate TagParents
This commit is contained in:
@@ -32,6 +32,7 @@ from sqlalchemy import (
|
||||
desc,
|
||||
exists,
|
||||
func,
|
||||
inspect,
|
||||
or_,
|
||||
select,
|
||||
text,
|
||||
@@ -43,6 +44,7 @@ from sqlalchemy.orm import (
|
||||
contains_eager,
|
||||
joinedload,
|
||||
make_transient,
|
||||
noload,
|
||||
selectinload,
|
||||
)
|
||||
|
||||
@@ -312,13 +314,12 @@ class Library:
|
||||
return f
|
||||
return None
|
||||
|
||||
def tag_display_name(self, tag_id: int) -> str:
|
||||
with Session(self.engine) as session:
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
|
||||
if not tag:
|
||||
return "<NO TAG>"
|
||||
def tag_display_name(self, tag: Tag | None) -> str:
|
||||
if not tag:
|
||||
return "<NO TAG>"
|
||||
|
||||
if tag.disambiguation_id:
|
||||
if tag.disambiguation_id:
|
||||
with Session(self.engine) as session:
|
||||
disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id))
|
||||
if not disam_tag:
|
||||
return "<NO DISAM TAG>"
|
||||
@@ -326,8 +327,8 @@ class Library:
|
||||
if not disam_name:
|
||||
disam_name = disam_tag.name
|
||||
return f"{tag.name} ({disam_name})"
|
||||
else:
|
||||
return tag.name
|
||||
else:
|
||||
return tag.name
|
||||
|
||||
def open_library(
|
||||
self, library_dir: Path, storage_path: Path | str | None = None
|
||||
@@ -1544,6 +1545,45 @@ class Library:
|
||||
|
||||
return session.scalar(statement)
|
||||
|
||||
def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]:
    """Get a dictionary containing tags in `tag_ids` and all of their ancestor tags.

    The parent links are walked one generation per query through the
    ``TagParent`` rows, then every collected tag is loaded in a single query
    and its ``parent_tags`` collection is filled in from the collected links.

    Args:
        tag_ids: IDs of the tags to start from.

    Returns:
        A mapping of tag ID to ``Tag`` for the requested tags and every ancestor
        reached through ``TagParent`` rows. IDs with no matching ``Tag`` row are
        absent from the mapping.
    """
    # Function-scope import so this method does not depend on the module's
    # import block being edited.
    from sqlalchemy.orm.attributes import set_committed_value

    pending_ids: set[int] = set(tag_ids)
    all_tag_ids: set[int] = set()
    parent_ids_by_child: dict[int, list[int]] = {}

    with Session(self.engine) as session:
        # Breadth-first walk up the hierarchy. IDs that were already queried are
        # dropped from the next round, so cyclic parent links terminate.
        while pending_ids:
            all_tag_ids.update(pending_ids)
            statement = select(TagParent).where(TagParent.child_id.in_(pending_ids))
            next_ids: set[int] = set()
            for tag_parent in session.scalars(statement).all():
                parent_ids_by_child.setdefault(tag_parent.child_id, []).append(
                    tag_parent.parent_id
                )
                next_ids.add(tag_parent.parent_id)
            pending_ids = next_ids - all_tag_ids

        statement = (
            select(Tag)
            .where(Tag.id.in_(all_tag_ids))
            .options(noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color))
        )
        all_tags: dict[int, Tag] = {tag.id: tag for tag in session.scalars(statement).all()}

        for tag in all_tags.values():
            # Skip parent IDs without a loaded Tag (dangling TagParent rows)
            # instead of failing the whole lookup with a KeyError.
            parents = {
                all_tags[parent_id]
                for parent_id in parent_ids_by_child.get(tag.id, ())
                if parent_id in all_tags
            }
            # Install the collection as already-committed state. A plain
            # `tag.parent_tags = ...` is tracked as a change, and a later
            # session.add() of this instance would then try to create
            # TagParents that already exist.
            set_committed_value(tag, "parent_tags", parents)

    return all_tags
|
||||
|
||||
def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
|
||||
if parent_id == child_id:
|
||||
return False
|
||||
|
||||
@@ -156,6 +156,9 @@ class Tag(Base):
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def __hash__(self) -> int:
    """Hash the tag by its `id` so it can be used in sets and as a dict key."""
    return hash(self.id)

def __lt__(self, other) -> bool:
    """Order tags by `name`.

    Returns `NotImplemented` for operands that are not `Tag` instances so that
    Python can try the reflected comparison or raise a `TypeError`, instead of
    this method failing on a missing `name` attribute.
    """
    if not isinstance(other, Tag):
        return NotImplemented
    return self.name < other.name
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class TagBoxWidget(TagBoxWidgetView):
|
||||
|
||||
edit_modal = PanelModal(
|
||||
build_tag_panel,
|
||||
self.__driver.lib.tag_display_name(tag.id),
|
||||
self.__driver.lib.tag_display_name(tag),
|
||||
"Edit Tag",
|
||||
done_callback=self.on_update.emit,
|
||||
has_save=True,
|
||||
|
||||
@@ -62,7 +62,7 @@ class TagDatabasePanel(TagSearchPanel):
|
||||
message_box = QMessageBox(
|
||||
QMessageBox.Question, # type: ignore
|
||||
Translations["tag.remove"],
|
||||
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)),
|
||||
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)),
|
||||
QMessageBox.Ok | QMessageBox.Cancel, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -387,7 +387,7 @@ class TagSearchPanel(PanelWidget):
|
||||
|
||||
self.edit_modal = PanelModal(
|
||||
build_tag_panel,
|
||||
self.lib.tag_display_name(tag.id),
|
||||
self.lib.tag_display_name(tag),
|
||||
Translations["tag.edit"],
|
||||
done_callback=(self.update_tags(self.search_field.text())),
|
||||
has_save=True,
|
||||
|
||||
@@ -32,7 +32,7 @@ class TagBoxWidgetView(FieldWidget):
|
||||
self.setLayout(self.__root_layout)
|
||||
|
||||
def set_tags(self, tags: Iterable[Tag]) -> None:
|
||||
tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag.id))
|
||||
tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag))
|
||||
logger.info("[TagBoxWidget] Tags:", tags=tags)
|
||||
while self.__root_layout.itemAt(0):
|
||||
self.__root_layout.takeAt(0).widget().deleteLater() # pyright: ignore[reportOptionalMemberAccess]
|
||||
|
||||
@@ -160,96 +160,38 @@ class FieldContainers(QWidget):
|
||||
c.setHidden(True)
|
||||
|
||||
def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]:
    """Get a dictionary of category tags mapped to their respective tags.

    Every tag in `tags` is placed under each of its ancestors that is a
    category. A tag that is itself a category is also placed under itself. A
    tag that is not a category and has no category ancestor goes under the
    `None` key. Categories that end up empty are left out of the result.

    Example:
        Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to:
        "Cartoon Network" -> "Johnny Bravo",
        "Character" -> "Johnny Bravo",
        "TV" -> "Johnny Bravo"

    Args:
        tags: The tags to group.

    Returns:
        A mapping of category tag (or `None`) to the non-empty set of tags
        grouped under it. The tag objects are the instances returned by
        `self.lib.get_tag_hierarchy`.
    """
    hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags)

    categories: dict[Tag | None, set[Tag]] = {None: set()}
    for candidate in hierarchy_tags.values():
        if candidate.is_category:
            categories[candidate] = set()

    for base_tag in tags:
        # Use the hierarchy instance: its parent_tags are populated.
        tag = hierarchy_tags[base_tag.id]
        has_category_parent = False

        # Walk the ancestors one generation at a time. Visited IDs are tracked
        # so that a cyclic parent chain (A -> B -> A) cannot loop forever.
        visited_ids: set[int] = {tag.id}
        generation = set(tag.parent_tags)
        while generation:
            next_generation: set[Tag] = set()
            for parent_tag in generation:
                if parent_tag.id in visited_ids:
                    continue
                visited_ids.add(parent_tag.id)
                if parent_tag in categories:
                    categories[parent_tag].add(tag)
                    has_category_parent = True
                next_generation.update(parent_tag.parent_tags)
            generation = next_generation

        if tag.is_category:
            categories[tag].add(tag)
        elif not has_category_parent:
            categories[None].add(tag)

    # Drop categories that received no tags.
    return {category: members for category, members in categories.items() if members}
|
||||
|
||||
def remove_field_prompt(self, name: str) -> str:
|
||||
return Translations.format("library.field.confirm_remove", name=name)
|
||||
|
||||
@@ -266,7 +266,7 @@ class TagWidget(QWidget):
|
||||
)
|
||||
|
||||
if self.lib:
|
||||
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id)))
|
||||
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag)))
|
||||
else:
|
||||
self.bg_button.setText(escape_text(tag.name))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user