fix: fix tag aliases (#1399)

* fix: fix tag aliases

* chore: remove logger statements from tag panel test

* chore: very important fix

* chore: remove commented-out code
This commit is contained in:
Travis Abendshien
2026-06-28 23:06:40 -07:00
committed by GitHub
parent 4da6037cbd
commit e509e247d5
9 changed files with 166 additions and 206 deletions

View File

@@ -1710,8 +1710,7 @@ class Library:
self,
tag: Tag,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
aliases: Iterable[TagAlias] | None = None,
) -> Tag | None:
with Session(self.engine, expire_on_commit=False) as session:
try:
@@ -1721,8 +1720,8 @@ class Library:
if parent_ids is not None:
self.update_parent_tags(tag, parent_ids, session)
if alias_ids is not None and alias_names is not None:
self.update_aliases(tag, alias_ids, alias_names, session)
if aliases is not None:
self.update_aliases(tag, aliases)
session.commit()
session.expunge(tag)
@@ -2000,11 +1999,10 @@ class Library:
self,
tag: Tag,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
aliases: Iterable[TagAlias] | None = None,
) -> None:
"""Edit a Tag in the Library."""
self.add_tag(tag, parent_ids, alias_names, alias_ids)
self.add_tag(tag, parent_ids, aliases)
def update_color(self, old_color_group: TagColorGroup, new_color_group: TagColorGroup) -> None:
"""Update a TagColorGroup in the Library. If it doesn't already exist, create it."""
@@ -2055,25 +2053,50 @@ class Library:
else:
self.add_color(new_color_group)
def update_aliases(
self,
tag: Tag,
alias_ids: list[int] | set[int],
alias_names: list[str] | set[str],
session: Session,
):
prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all()
def update_aliases(self, tag: Tag, aliases: Iterable[TagAlias]) -> bool:
"""Update TagAliases for a given Tag."""
with Session(self.engine) as session:
# Remove aliases that are no longer on the Tag
try:
old_aliases = session.scalars(
select(TagAlias).where(TagAlias.tag_id == tag.id)
).all()
old_alias_ids: list[int] = [a.id for a in old_aliases]
for old_alias in old_aliases:
if old_alias.id not in [a.id for a in aliases] or not old_alias.name:
logger.warning(
"[Library] Deleting removed alias", id=old_alias.id, name=old_alias.name
)
session.delete(old_alias)
session.commit()
except IntegrityError as e:
session.rollback()
logger.error("[Library] Could not update aliases", error=e)
return False
for alias in prev_aliases:
if alias.id not in alias_ids or alias.name not in alias_names:
session.delete(alias)
else:
alias_ids.remove(alias.id)
alias_names.remove(alias.name)
# Update or Add aliases
for alias in aliases:
# Sanitize alias names
alias.name = alias.name.strip()
if not alias.name:
continue
for alias_name in alias_names:
alias = TagAlias(alias_name, tag.id)
session.add(alias)
try:
if alias.id in old_alias_ids:
stmt = (
update(TagAlias).where(TagAlias.id == alias.id).values(name=alias.name)
)
session.execute(stmt)
else:
session.add(alias)
except IntegrityError as e:
session.rollback()
logger.error("[Library] Could not update or add alias", error=e)
return False
session.commit()
return True
def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session: Session):
if tag.id in parent_ids:

View File

@@ -163,6 +163,7 @@ class Tag(Base):
def __hash__(self) -> int:
return hash(self.id)
@override
def __eq__(self, value: object) -> bool:
if not isinstance(value, Tag):
return False

View File

@@ -87,8 +87,7 @@ class TagBoxWidget(TagBoxWidgetView):
self.__driver.lib.update_tag(
build_tag_panel.build_tag(),
parent_ids=set(build_tag_panel.parent_ids),
alias_names=set(build_tag_panel.alias_names),
alias_ids=set(build_tag_panel.alias_ids),
aliases=set(build_tag_panel.aliases),
)
self.on_update.emit()

View File

@@ -110,7 +110,6 @@ class TagSearchPanel(SearchPanel[Tag]):
Translations["tag.edit"],
is_savable=True,
)
edit_tag_modal.saved.connect(lambda: self.edit_item(edit_tag_panel))
edit_tag_modal.show()
@@ -191,10 +190,7 @@ class TagSearchPanel(SearchPanel[Tag]):
if isinstance(edit_item_panel, BuildTagPanel):
tag: Tag = edit_item_panel.build_tag()
self.__lib.add_tag(
tag,
parent_ids=edit_item_panel.parent_ids,
alias_names=edit_item_panel.alias_names,
alias_ids=edit_item_panel.alias_ids,
tag, parent_ids=edit_item_panel.parent_ids, aliases=edit_item_panel.aliases
)
if choose_item:
@@ -215,8 +211,7 @@ class TagSearchPanel(SearchPanel[Tag]):
self.__lib.update_tag(
tag=edit_item_panel.build_tag(),
parent_ids=edit_item_panel.parent_ids,
alias_names=edit_item_panel.alias_names,
alias_ids=edit_item_panel.alias_ids,
aliases=edit_item_panel.aliases,
)
self.update_items(self.search_field.text())

View File

@@ -2,8 +2,8 @@
# SPDX-License-Identifier: GPL-3.0-only
import sys
from collections.abc import Callable
from functools import partial
from typing import cast, override
import structlog
@@ -26,32 +26,30 @@ from PySide6.QtWidgets import (
)
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Tag, TagColorGroup
from tagstudio.core.library.alchemy.models import Tag, TagAlias, TagColorGroup
from tagstudio.core.utils.types import unwrap
from tagstudio.qt.controllers.tag_search_panel_controller import TagSearchModal, TagSearchPanel
from tagstudio.qt.controllers.tag_search_panel_controller import TagSearchModal
from tagstudio.qt.mixed.tag_color_preview import TagColorPreview
from tagstudio.qt.mixed.tag_color_selection import TagColorSelection
from tagstudio.qt.mixed.tag_widget import (
TagWidget,
get_tag_border_color,
get_tag_highlight_color,
get_tag_primary_color,
get_tag_text_color,
)
from tagstudio.qt.mixed.tag_widget import TagWidget
from tagstudio.qt.translations import Translations
from tagstudio.qt.views.panel_modal import PanelModal, PanelWidget
from tagstudio.qt.views.stylesheets.stylesheets import (
checkbox_style,
colored_radio_button_style,
get_tag_border_color,
get_tag_highlight_color,
get_tag_primary_color,
get_tag_text_color,
header,
line_edit_style,
)
from tagstudio.qt.views.tag_search_panel_view import TagSearchPanelView
logger = structlog.get_logger(__name__)
class CustomTableItem(QLineEdit):
# TODO: Look into using signals instead of callbacks
def __init__(
self,
text: str,
@@ -63,12 +61,10 @@ class CustomTableItem(QLineEdit):
self.setText(text)
self.on_return: Callable[..., None] = on_return
self.on_backspace: Callable[..., None] = on_backspace
def set_id(self, id: int):
self.id = id
self.alias: TagAlias
@override
def keyPressEvent(self, arg__1: QKeyEvent): # noqa: N802
def keyPressEvent(self, arg__1: QKeyEvent):
if arg__1.key() == Qt.Key.Key_Return or arg__1.key() == Qt.Key.Key_Enter:
self.on_return()
elif arg__1.key() == Qt.Key.Key_Backspace and self.text().strip() == "":
@@ -87,6 +83,8 @@ class BuildTagPanel(PanelWidget):
self.tag_color_namespace: str | None
self.tag_color_slug: str | None
self.disambiguation_id: int | None
self.parent_ids: set[int] = set()
self.aliases: list[TagAlias] = []
self.setMinimumSize(300, 460)
self.root_layout = QVBoxLayout(self)
@@ -104,7 +102,7 @@ class BuildTagPanel(PanelWidget):
self.name_layout.addWidget(self.name_title)
self.name_field = QLineEdit()
self.name_field.setFixedHeight(24)
self.name_field.textChanged.connect(self.on_name_changed)
self.name_field.textChanged.connect(self._on_name_change)
self.name_field.setPlaceholderText(Translations["tag.tag_name_required"])
self.name_layout.addWidget(self.name_field)
@@ -140,7 +138,7 @@ class BuildTagPanel(PanelWidget):
self.aliases_add_button = QPushButton()
self.aliases_add_button.setText("+")
self.aliases_add_button.clicked.connect(self.add_alias_callback)
self.aliases_add_button.clicked.connect(self._create_alias_callback)
# Parent Tags ----------------------------------------------------------
self.parent_tags_widget = QWidget()
@@ -179,7 +177,7 @@ class BuildTagPanel(PanelWidget):
exclude_ids.append(tag.id)
self.add_tag_modal = TagSearchModal(self.lib, exclude_ids)
self.add_tag_modal.tsp.item_chosen.connect(lambda x: self.add_parent_tag_callback(x))
self.add_tag_modal.tsp.item_chosen.connect(lambda x: self._add_parent_tag_callback(x))
self.parent_tags_add_button.clicked.connect(self.add_tag_modal.show)
# Color ----------------------------------------------------------------
@@ -202,9 +200,7 @@ class BuildTagPanel(PanelWidget):
self.tag_color_selection = TagColorSelection(self.lib)
chose_tag_color_title = Translations["tag.choose_color"]
self.choose_color_modal = PanelModal(
self.tag_color_selection,
chose_tag_color_title,
chose_tag_color_title,
self.tag_color_selection, chose_tag_color_title, chose_tag_color_title
)
self.choose_color_modal.done.connect(
lambda: self.choose_color_callback(self.tag_color_selection.selected_color)
@@ -252,12 +248,6 @@ class BuildTagPanel(PanelWidget):
self.root_layout.addWidget(self.cat_widget)
self.root_layout.addWidget(self.hidden_widget)
self.parent_ids: set[int] = set()
self.alias_ids: list[int] = []
self.alias_names: list[str] = []
self.new_alias_names: dict[int, str] = {}
self.new_item_id = sys.maxsize
self.set_tag(tag or Tag(name=Translations["tag.new"]))
def backspace(self):
@@ -269,10 +259,7 @@ class BuildTagPanel(PanelWidget):
remove_row = 0
for i in range(0, row):
item = self.aliases_table.cellWidget(i, 1)
if (
isinstance(item, CustomTableItem)
and item.id == cast(CustomTableItem, focused_widget).id
):
if isinstance(item, CustomTableItem) and item == cast(CustomTableItem, focused_widget):
cast(QPushButton, self.aliases_table.cellWidget(i, 0)).click()
remove_row = i
break
@@ -286,41 +273,36 @@ class BuildTagPanel(PanelWidget):
self.aliases_table.cellWidget(remove_row - 1, 1).setFocus()
def enter(self):
"""When the Enter/Return key has been pressed."""
focused_widget = QApplication.focusWidget()
if isinstance(focused_widget, CustomTableItem):
self.add_alias_callback()
self._create_alias_callback()
def add_parent_tag_callback(self, tag_id: int):
logger.info("add_parent_tag_callback", tag_id=tag_id)
def _add_parent_tag_callback(self, tag_id: int):
self.parent_ids.add(tag_id)
self.set_parent_tags()
def remove_parent_tag_callback(self, tag_id: int):
logger.info("remove_parent_tag_callback", tag_id=tag_id)
def _remove_parent_tag_callback(self, tag_id: int):
self.parent_ids.remove(tag_id)
self.set_parent_tags()
def add_alias_callback(self):
logger.info("add_alias_callback")
def _create_alias_callback(self):
alias = TagAlias("", tag_id=self.tag.id)
self.aliases.append(alias)
id = self.new_item_id
self.alias_ids.append(id)
self.new_alias_names[id] = ""
self.new_item_id -= 1
self._set_aliases()
row = self.aliases_table.rowCount() - 1
item = self.aliases_table.cellWidget(row, 1)
item.setFocus()
def remove_alias_callback(self, alias_id: int):
logger.info("remove_alias_callback")
self.alias_ids.remove(alias_id)
def remove_alias_callback(self, alias: TagAlias):
for i, a in enumerate(self.aliases):
if a.name == alias.name and a.id == alias.id:
del self.aliases[i]
continue
self._set_aliases()
def choose_color_callback(self, tag_color_group: TagColorGroup | None):
logger.info("choose_color_callback", tag_color_group=tag_color_group)
if tag_color_group:
self.tag_color_namespace = tag_color_group.namespace
self.tag_color_slug = tag_color_group.slug
@@ -378,19 +360,30 @@ class BuildTagPanel(PanelWidget):
else:
text_color = get_tag_text_color(primary_color, highlight_color)
def update_parent_tag_callback(build_tag_panel: BuildTagPanel):
self.lib.update_tag(
build_tag_panel.build_tag(),
parent_ids=set(build_tag_panel.parent_ids),
aliases=set(build_tag_panel.aliases),
)
self.set_parent_tags()
def on_parent_tag_edit(tag: Tag) -> None:
build_tag_panel = BuildTagPanel(self.lib, tag=tag)
edit_modal = PanelModal(
build_tag_panel,
self.lib.tag_display_name(tag),
"Edit Tag",
is_savable=True,
)
edit_modal.saved.connect(partial(update_parent_tag_callback, build_tag_panel))
edit_modal.show()
# Add Tag Widget
tag_widget = TagWidget(
tag,
library=self.lib,
has_edit=True,
has_remove=True,
)
tag_widget.on_remove.connect(lambda t=parent_id: self.remove_parent_tag_callback(t))
tag_widget.on_edit.connect(
lambda t=tag: TagSearchPanel(
library=self.lib, view=TagSearchPanelView(is_tag_chooser=True)
).on_item_edit(t)
)
tag_widget = TagWidget(tag, library=self.lib, has_edit=True, has_remove=True)
tag_widget.on_remove.connect(lambda t=parent_id: self._remove_parent_tag_callback(t))
tag_widget.on_edit.connect(partial(on_parent_tag_edit, tag))
row.addWidget(tag_widget)
# Add Disambiguation Tag Button
@@ -423,59 +416,34 @@ class BuildTagPanel(PanelWidget):
else:
button.setChecked(False)
def add_aliases(self):
names: set[str] = set()
for i in range(0, self.aliases_table.rowCount()):
widget = self.aliases_table.cellWidget(i, 1)
names.add(cast(CustomTableItem, widget).text())
remove: set[str] = set(self.alias_names) - names
self.alias_names = list(set(self.alias_names) - remove)
for name in names:
# add new aliases
if name.strip() != "" and name not in set(self.alias_names):
self.alias_names.append(name)
elif name.strip() == "" and name in set(self.alias_names):
self.alias_names.remove(name)
def _update_new_alias_name_dict(self):
for i in range(0, self.aliases_table.rowCount()):
widget = cast(CustomTableItem, self.aliases_table.cellWidget(i, 1))
self.new_alias_names[widget.id] = widget.text()
def _set_aliases(self):
self._update_new_alias_name_dict()
while self.aliases_table.rowCount() > 0:
self.aliases_table.removeRow(0)
self.alias_names.clear()
last: QWidget | None = self.panel_save_button
for alias_id in self.alias_ids:
alias = self.lib.get_alias(self.tag.id, alias_id)
aliases = list(self.aliases)
alias_names = [a.name for a in aliases]
sorted_aliases = sorted(aliases, key=lambda x: alias_names[aliases.index(x)])
alias_name: str = alias.name if alias else self.new_alias_names[alias_id]
# Sort the TagAlias objects while keeping in-progress empty ones at the bottom
empty_aliases: list[TagAlias] = []
while sorted_aliases and sorted_aliases[0].name == "":
empty_aliases.append(sorted_aliases.pop(0))
for alias in empty_aliases:
sorted_aliases.append(alias)
# handel when an alias name changes
if alias_id in self.new_alias_names:
alias_name = self.new_alias_names[alias_id]
self.alias_names.append(alias_name)
remove_btn = QPushButton("-")
remove_btn.clicked.connect(lambda id=alias_id: self.remove_alias_callback(id))
for alias in sorted_aliases:
remove_button = QPushButton("-")
remove_button.clicked.connect(partial(self.remove_alias_callback, alias))
row = self.aliases_table.rowCount()
new_item = CustomTableItem(alias_name, self.enter, self.backspace)
new_item.set_id(alias_id)
new_item.editingFinished.connect(lambda item=new_item: self._alias_name_change(item))
new_item = CustomTableItem(alias.name, self.enter, self.backspace)
new_item.alias = alias
new_item.editingFinished.connect(partial(self._on_alias_change, new_item))
self.aliases_table.insertRow(row)
self.aliases_table.setCellWidget(row, 1, new_item)
self.aliases_table.setCellWidget(row, 0, remove_btn)
self.aliases_table.setCellWidget(row, 0, remove_button)
if last is not None:
self.setTabOrder(last, self.aliases_table.cellWidget(row, 1))
@@ -484,22 +452,25 @@ class BuildTagPanel(PanelWidget):
)
last = self.aliases_table.cellWidget(row, 0)
def _alias_name_change(self, item: CustomTableItem):
self.new_alias_names[item.id] = item.text()
def _on_alias_change(self, item: CustomTableItem):
for alias in self.aliases:
if item.alias == alias:
alias.name = item.text()
item.alias.name = item.text()
continue
def set_tag(self, tag: Tag):
logger.info("[BuildTagPanel] Setting Tag", tag=tag)
logger.info("[BuildTagPanel] Setting Tag", tag_id=tag.id)
self.tag = tag
self.name_field.setText(tag.name)
self.shorthand_field.setText(tag.shorthand or "")
for alias_id in tag.alias_ids:
self.alias_ids.append(alias_id)
for alias in tag.aliases:
self.aliases.append(alias)
self._set_aliases()
self.disambiguation_id = tag.disambiguation_id
for parent_id in tag.parent_ids:
for parent_id in self.tag.parent_ids:
self.parent_ids.add(parent_id)
self.set_parent_tags()
@@ -516,9 +487,8 @@ class BuildTagPanel(PanelWidget):
self.cat_checkbox.setChecked(tag.is_category)
self.hidden_checkbox.setChecked(tag.is_hidden)
def on_name_changed(self):
def _on_name_change(self):
is_empty = not self.name_field.text().strip()
self.name_field.setStyleSheet(line_edit_style() if is_empty else "")
if self.panel_save_button is not None:
@@ -526,8 +496,6 @@ class BuildTagPanel(PanelWidget):
def build_tag(self) -> Tag:
tag = self.tag
self.add_aliases()
tag.name = self.name_field.text()
tag.shorthand = self.shorthand_field.text()
tag.disambiguation_id = self.disambiguation_id
@@ -536,7 +504,7 @@ class BuildTagPanel(PanelWidget):
tag.is_category = self.cat_checkbox.isChecked()
tag.is_hidden = self.hidden_checkbox.isChecked()
logger.info("built tag", tag=tag)
logger.info("[BuildTag] Build Tag", tag_id=tag.id, tag_name=tag.name)
return tag
@override
@@ -550,4 +518,3 @@ class BuildTagPanel(PanelWidget):
self.setTabOrder(unwrap(self.panel_save_button), self.aliases_table.cellWidget(0, 1))
self.name_field.selectAll()
self.name_field.setFocus()
self._set_aliases()

View File

@@ -110,14 +110,7 @@ class TagWidget(QWidget):
tag: Tag | None
def __init__(
self,
tag: Tag | None,
has_edit: bool,
has_remove: bool,
library: "Library | None" = None,
on_remove_callback: Callable[[], None] | None = None,
on_click_callback: Callable[[], None] | None = None,
on_edit_callback: Callable[[], None] | None = None,
self, tag: Tag | None, has_edit: bool, has_remove: bool, library: "Library | None" = None
) -> None:
super().__init__()
self.tag = tag
@@ -134,14 +127,6 @@ class TagWidget(QWidget):
self.bg_button = QPushButton(self)
self.bg_button.setFlat(True)
# add callbacks
if on_remove_callback is not None:
self.on_remove.connect(on_remove_callback)
if on_click_callback is not None:
self.on_click.connect(on_click_callback)
if on_edit_callback is not None:
self.on_edit.connect(on_edit_callback)
# add edit action
if has_edit:
edit_action = QAction(self)

View File

@@ -890,8 +890,7 @@ class QtDriver(DriverMixin, QObject):
self.lib.add_tag(
panel.build_tag(),
set(panel.parent_ids),
set(panel.alias_names),
set(panel.alias_ids),
set(panel.aliases),
),
self.modal.hide(),
)

View File

@@ -22,7 +22,7 @@ def test_build_tag_panel_add_sub_tag_callback(
panel: BuildTagPanel = BuildTagPanel(library, child)
qtbot.addWidget(panel)
panel.add_parent_tag_callback(parent.id)
panel._add_parent_tag_callback(parent.id) # pyright: ignore[reportPrivateUsage]
assert len(panel.parent_ids) == 1
@@ -33,14 +33,14 @@ def test_build_tag_panel_remove_subtag_callback(
parent = unwrap(library.add_tag(generate_tag("xxx", id=123)))
child = unwrap(library.add_tag(generate_tag("xx", id=124)))
library.update_tag(child, {parent.id}, [], [])
library.update_tag(child, {parent.id}, [])
child = unwrap(library.get_tag(child.id))
panel: BuildTagPanel = BuildTagPanel(library, child)
qtbot.addWidget(panel)
panel.remove_parent_tag_callback(parent.id)
panel._remove_parent_tag_callback(parent.id) # pyright: ignore[reportPrivateUsage]
assert len(panel.parent_ids) == 0
@@ -58,7 +58,7 @@ def test_build_tag_panel_add_alias_callback(
panel: BuildTagPanel = BuildTagPanel(library, tag)
qtbot.addWidget(panel)
panel.add_alias_callback()
panel._create_alias_callback() # pyright: ignore[reportPrivateUsage]
assert panel.aliases_table.rowCount() == 1
@@ -68,7 +68,9 @@ def test_build_tag_panel_remove_alias_callback(
):
tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124})
alias_1 = TagAlias("alias", tag.id)
alias_2 = TagAlias("alias_2", tag.id)
library.update_tag(tag, [], {alias_1, alias_2})
tag = unwrap(library.get_tag(tag.id))
@@ -79,12 +81,11 @@ def test_build_tag_panel_remove_alias_callback(
qtbot.addWidget(panel)
alias: TagAlias = unwrap(library.get_alias(tag.id, tag.alias_ids[0]))
panel.remove_alias_callback(alias)
panel.remove_alias_callback(alias.id)
assert len(panel.alias_ids) == 1
assert len(panel.alias_names) == 1
assert alias.name not in panel.alias_names
assert len(panel.aliases) == 1
assert alias not in panel.aliases
assert (alias.id, alias.name) not in [(a.id, a.name) for a in panel.aliases]
def test_build_tag_panel_set_parent_tags(
@@ -109,7 +110,9 @@ def test_build_tag_panel_add_aliases(
):
tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124})
alias_1 = TagAlias("alias", tag.id)
alias_2 = TagAlias("alias_2", tag.id)
library.update_tag(tag, [], {alias_1, alias_2})
tag = unwrap(library.get_tag(tag.id))
@@ -132,22 +135,13 @@ def test_build_tag_panel_add_aliases(
assert "alias" in alias_names
assert "alias_2" in alias_names
old_text = widget.text()
widget.setText("alias_update")
panel.add_aliases()
assert old_text not in panel.alias_names
assert "alias_update" in panel.alias_names
assert len(panel.alias_names) == 2
def test_build_tag_panel_set_aliases(
qtbot: QtBot, library: Library, generate_tag: Callable[..., Tag]
):
tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
library.update_tag(tag, [], {"alias"}, {123})
alias_1 = TagAlias("Alias 1", tag.id)
library.update_tag(tag, [], [alias_1])
tag = unwrap(library.get_tag(tag.id))
@@ -157,8 +151,7 @@ def test_build_tag_panel_set_aliases(
qtbot.addWidget(panel)
assert panel.aliases_table.rowCount() == 1
assert len(panel.alias_names) == 1
assert len(panel.alias_ids) == 1
assert len(panel.aliases) == 1
def test_build_tag_panel_set_tag(qtbot: QtBot, library: Library, generate_tag: Callable[..., Tag]):

View File

@@ -15,7 +15,7 @@ from tagstudio.core.library.alchemy.fields import (
TextField,
)
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Entry, Tag
from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias
from tagstudio.core.utils.types import unwrap
logger = structlog.get_logger()
@@ -25,10 +25,9 @@ def test_library_add_alias(library: Library, generate_tag: Callable[..., Tag]):
tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
parent_ids: set[int] = set()
alias_ids: set[int] = set()
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, parent_ids, alias_names, alias_ids)
aliases: set[TagAlias] = set()
aliases.add(TagAlias("test_alias", tag.id))
library.update_tag(tag, parent_ids, aliases)
tag = unwrap(library.get_tag(tag.id))
alias_ids = set(tag.alias_ids)
@@ -39,10 +38,9 @@ def test_library_get_alias(library: Library, generate_tag: Callable[..., Tag]):
tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
parent_ids: set[int] = set()
alias_ids: list[int] = []
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, parent_ids, alias_names, alias_ids)
aliases: set[TagAlias] = set()
aliases.add(TagAlias("test_alias", tag.id))
library.update_tag(tag, parent_ids, aliases)
tag = unwrap(library.get_tag(tag.id))
alias_ids = tag.alias_ids
@@ -54,19 +52,19 @@ def test_library_update_alias(library: Library, generate_tag: Callable[..., Tag]
tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
parent_ids: set[int] = set()
alias_ids: list[int] = []
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, parent_ids, alias_names, alias_ids)
aliases: set[TagAlias] = set()
test_alias = TagAlias("test_alias", tag.id)
aliases.add(test_alias)
library.update_tag(tag, parent_ids, aliases)
tag = unwrap(library.get_tag(tag.id))
alias_ids = tag.alias_ids
alias = unwrap(library.get_alias(tag.id, alias_ids[0]))
assert alias.name == "test_alias"
alias_names.remove("test_alias")
alias_names.add("alias_update")
library.update_tag(tag, parent_ids, alias_names, alias_ids)
aliases.remove(test_alias)
aliases.add(TagAlias("alias_update", tag.id))
library.update_tag(tag, parent_ids, aliases)
tag = unwrap(library.get_tag(tag.id))
assert len(tag.alias_ids) == 1
@@ -108,7 +106,7 @@ def test_tag_self_parent(library: Library, generate_tag: Callable[..., Tag]):
tag = unwrap(library.add_tag(generate_tag("xxx", id=123)))
assert tag.id == 123
library.update_tag(tag, {tag.id}, [], [])
library.update_tag(tag, {tag.id}, [])
tag = unwrap(library.get_tag(tag.id))
assert len(tag.parent_ids) == 0