diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 8c28134d..8801d617 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -1710,8 +1710,7 @@ class Library: self, tag: Tag, parent_ids: list[int] | set[int] | None = None, - alias_names: list[str] | set[str] | None = None, - alias_ids: list[int] | set[int] | None = None, + aliases: Iterable[TagAlias] | None = None, ) -> Tag | None: with Session(self.engine, expire_on_commit=False) as session: try: @@ -1721,8 +1720,8 @@ class Library: if parent_ids is not None: self.update_parent_tags(tag, parent_ids, session) - if alias_ids is not None and alias_names is not None: - self.update_aliases(tag, alias_ids, alias_names, session) + if aliases is not None: + self.update_aliases(tag, aliases) session.commit() session.expunge(tag) @@ -2000,11 +1999,10 @@ class Library: self, tag: Tag, parent_ids: list[int] | set[int] | None = None, - alias_names: list[str] | set[str] | None = None, - alias_ids: list[int] | set[int] | None = None, + aliases: Iterable[TagAlias] | None = None, ) -> None: """Edit a Tag in the Library.""" - self.add_tag(tag, parent_ids, alias_names, alias_ids) + self.add_tag(tag, parent_ids, aliases) def update_color(self, old_color_group: TagColorGroup, new_color_group: TagColorGroup) -> None: """Update a TagColorGroup in the Library. If it doesn't already exist, create it.""" @@ -2055,25 +2053,50 @@ class Library: else: self.add_color(new_color_group) - def update_aliases( - self, - tag: Tag, - alias_ids: list[int] | set[int], - alias_names: list[str] | set[str], - session: Session, - ): - prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all() + def update_aliases(self, tag: Tag, aliases: Iterable[TagAlias]) -> bool: + """Update TagAliases for a given Tag.""" + with Session(self.engine) as session: + # Remove aliases that are no longer on the Tag + try: + old_aliases = session.scalars( + select(TagAlias).where(TagAlias.tag_id == tag.id) + ).all() + old_alias_ids: list[int] = [a.id for a in old_aliases] + for old_alias in old_aliases: + if old_alias.id not in [a.id for a in aliases] or not old_alias.name: + logger.warning( + "[Library] Deleting removed alias", id=old_alias.id, name=old_alias.name + ) + session.delete(old_alias) + session.commit() + except IntegrityError as e: + session.rollback() + logger.error("[Library] Could not update aliases", error=e) + return False - for alias in prev_aliases: - if alias.id not in alias_ids or alias.name not in alias_names: - session.delete(alias) - else: - alias_ids.remove(alias.id) - alias_names.remove(alias.name) + # Update or Add aliases + for alias in aliases: + # Sanitize alias names + alias.name = alias.name.strip() + if not alias.name: + continue - for alias_name in alias_names: - alias = TagAlias(alias_name, tag.id) - session.add(alias) + try: + if alias.id in old_alias_ids: + stmt = ( + update(TagAlias).where(TagAlias.id == alias.id).values(name=alias.name) + ) + session.execute(stmt) + else: + session.add(alias) + except IntegrityError as e: + session.rollback() + logger.error("[Library] Could not update or add alias", error=e) + return False + + session.commit() + + return True def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session: Session): if tag.id in parent_ids: diff --git a/src/tagstudio/core/library/alchemy/models.py b/src/tagstudio/core/library/alchemy/models.py index c2b81f66..668c80f5 100644 --- a/src/tagstudio/core/library/alchemy/models.py +++ b/src/tagstudio/core/library/alchemy/models.py @@ -163,6 +163,7 @@ class Tag(Base): def __hash__(self) -> int: return hash(self.id) + @override def __eq__(self, value: object) -> bool: if not isinstance(value, Tag): return False diff --git a/src/tagstudio/qt/controllers/tag_box_controller.py b/src/tagstudio/qt/controllers/tag_box_controller.py index 58e7919a..313f19cb 100644 --- a/src/tagstudio/qt/controllers/tag_box_controller.py +++ b/src/tagstudio/qt/controllers/tag_box_controller.py @@ -87,8 +87,7 @@ class TagBoxWidget(TagBoxWidgetView): self.__driver.lib.update_tag( build_tag_panel.build_tag(), parent_ids=set(build_tag_panel.parent_ids), - alias_names=set(build_tag_panel.alias_names), - alias_ids=set(build_tag_panel.alias_ids), + aliases=set(build_tag_panel.aliases), ) self.on_update.emit() diff --git a/src/tagstudio/qt/controllers/tag_search_panel_controller.py b/src/tagstudio/qt/controllers/tag_search_panel_controller.py index ecd53496..5d1efd87 100644 --- a/src/tagstudio/qt/controllers/tag_search_panel_controller.py +++ b/src/tagstudio/qt/controllers/tag_search_panel_controller.py @@ -110,7 +110,6 @@ class TagSearchPanel(SearchPanel[Tag]): Translations["tag.edit"], is_savable=True, ) - edit_tag_modal.saved.connect(lambda: self.edit_item(edit_tag_panel)) edit_tag_modal.show() @@ -191,10 +190,7 @@ class TagSearchPanel(SearchPanel[Tag]): if isinstance(edit_item_panel, BuildTagPanel): tag: Tag = edit_item_panel.build_tag() self.__lib.add_tag( - tag, - parent_ids=edit_item_panel.parent_ids, - alias_names=edit_item_panel.alias_names, - alias_ids=edit_item_panel.alias_ids, + tag, parent_ids=edit_item_panel.parent_ids, aliases=edit_item_panel.aliases ) if choose_item: @@ -215,8 +211,7 @@ class TagSearchPanel(SearchPanel[Tag]): self.__lib.update_tag( tag=edit_item_panel.build_tag(), parent_ids=edit_item_panel.parent_ids, - alias_names=edit_item_panel.alias_names, - alias_ids=edit_item_panel.alias_ids, + aliases=edit_item_panel.aliases, ) self.update_items(self.search_field.text()) diff --git a/src/tagstudio/qt/mixed/build_tag.py b/src/tagstudio/qt/mixed/build_tag.py index 4c81a968..ae10536c 100644 --- a/src/tagstudio/qt/mixed/build_tag.py +++ b/src/tagstudio/qt/mixed/build_tag.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: GPL-3.0-only -import sys from collections.abc import Callable +from functools import partial from typing import cast, override import structlog @@ -26,32 +26,30 @@ from PySide6.QtWidgets import ( ) from tagstudio.core.library.alchemy.library import Library -from tagstudio.core.library.alchemy.models import Tag, TagColorGroup +from tagstudio.core.library.alchemy.models import Tag, TagAlias, TagColorGroup from tagstudio.core.utils.types import unwrap -from tagstudio.qt.controllers.tag_search_panel_controller import TagSearchModal, TagSearchPanel +from tagstudio.qt.controllers.tag_search_panel_controller import TagSearchModal from tagstudio.qt.mixed.tag_color_preview import TagColorPreview from tagstudio.qt.mixed.tag_color_selection import TagColorSelection -from tagstudio.qt.mixed.tag_widget import ( - TagWidget, - get_tag_border_color, - get_tag_highlight_color, - get_tag_primary_color, - get_tag_text_color, -) +from tagstudio.qt.mixed.tag_widget import TagWidget from tagstudio.qt.translations import Translations from tagstudio.qt.views.panel_modal import PanelModal, PanelWidget from tagstudio.qt.views.stylesheets.stylesheets import ( checkbox_style, colored_radio_button_style, + get_tag_border_color, + get_tag_highlight_color, + get_tag_primary_color, + get_tag_text_color, header, line_edit_style, ) -from tagstudio.qt.views.tag_search_panel_view import TagSearchPanelView logger = structlog.get_logger(__name__) class CustomTableItem(QLineEdit): + # TODO: Look into using signals instead of callbacks def __init__( self, text: str, @@ -63,12 +61,10 @@ class CustomTableItem(QLineEdit): self.setText(text) self.on_return: Callable[..., None] = on_return self.on_backspace: Callable[..., None] = on_backspace - - def set_id(self, id: int): - self.id = id + self.alias: TagAlias @override - def keyPressEvent(self, arg__1: QKeyEvent): # noqa: N802 + def keyPressEvent(self, arg__1: QKeyEvent): if arg__1.key() == Qt.Key.Key_Return or arg__1.key() == Qt.Key.Key_Enter: self.on_return() elif arg__1.key() == Qt.Key.Key_Backspace and self.text().strip() == "": @@ -87,6 +83,8 @@ class BuildTagPanel(PanelWidget): self.tag_color_namespace: str | None self.tag_color_slug: str | None self.disambiguation_id: int | None + self.parent_ids: set[int] = set() + self.aliases: list[TagAlias] = [] self.setMinimumSize(300, 460) self.root_layout = QVBoxLayout(self) @@ -104,7 +102,7 @@ class BuildTagPanel(PanelWidget): self.name_layout.addWidget(self.name_title) self.name_field = QLineEdit() self.name_field.setFixedHeight(24) - self.name_field.textChanged.connect(self.on_name_changed) + self.name_field.textChanged.connect(self._on_name_change) self.name_field.setPlaceholderText(Translations["tag.tag_name_required"]) self.name_layout.addWidget(self.name_field) @@ -140,7 +138,7 @@ class BuildTagPanel(PanelWidget): self.aliases_add_button = QPushButton() self.aliases_add_button.setText("+") - self.aliases_add_button.clicked.connect(self.add_alias_callback) + self.aliases_add_button.clicked.connect(self._create_alias_callback) # Parent Tags ---------------------------------------------------------- self.parent_tags_widget = QWidget() @@ -179,7 +177,7 @@ class BuildTagPanel(PanelWidget): exclude_ids.append(tag.id) self.add_tag_modal = TagSearchModal(self.lib, exclude_ids) - self.add_tag_modal.tsp.item_chosen.connect(lambda x: self.add_parent_tag_callback(x)) + self.add_tag_modal.tsp.item_chosen.connect(lambda x: self._add_parent_tag_callback(x)) self.parent_tags_add_button.clicked.connect(self.add_tag_modal.show) # Color ---------------------------------------------------------------- @@ -202,9 +200,7 @@ class BuildTagPanel(PanelWidget): self.tag_color_selection = TagColorSelection(self.lib) chose_tag_color_title = Translations["tag.choose_color"] self.choose_color_modal = PanelModal( - self.tag_color_selection, - chose_tag_color_title, - chose_tag_color_title, + self.tag_color_selection, chose_tag_color_title, chose_tag_color_title ) self.choose_color_modal.done.connect( lambda: self.choose_color_callback(self.tag_color_selection.selected_color) @@ -252,12 +248,6 @@ class BuildTagPanel(PanelWidget): self.root_layout.addWidget(self.cat_widget) self.root_layout.addWidget(self.hidden_widget) - self.parent_ids: set[int] = set() - self.alias_ids: list[int] = [] - self.alias_names: list[str] = [] - self.new_alias_names: dict[int, str] = {} - self.new_item_id = sys.maxsize - self.set_tag(tag or Tag(name=Translations["tag.new"])) def backspace(self): @@ -269,10 +259,7 @@ class BuildTagPanel(PanelWidget): remove_row = 0 for i in range(0, row): item = self.aliases_table.cellWidget(i, 1) - if ( - isinstance(item, CustomTableItem) - and item.id == cast(CustomTableItem, focused_widget).id - ): + if isinstance(item, CustomTableItem) and item == cast(CustomTableItem, focused_widget): cast(QPushButton, self.aliases_table.cellWidget(i, 0)).click() remove_row = i break @@ -286,41 +273,36 @@ class BuildTagPanel(PanelWidget): self.aliases_table.cellWidget(remove_row - 1, 1).setFocus() def enter(self): + """When the Enter/Return key has been pressed.""" focused_widget = QApplication.focusWidget() if isinstance(focused_widget, CustomTableItem): - self.add_alias_callback() + self._create_alias_callback() - def add_parent_tag_callback(self, tag_id: int): - logger.info("add_parent_tag_callback", tag_id=tag_id) + def _add_parent_tag_callback(self, tag_id: int): self.parent_ids.add(tag_id) self.set_parent_tags() - def remove_parent_tag_callback(self, tag_id: int): - logger.info("remove_parent_tag_callback", tag_id=tag_id) + def _remove_parent_tag_callback(self, tag_id: int): self.parent_ids.remove(tag_id) self.set_parent_tags() - def add_alias_callback(self): - logger.info("add_alias_callback") + def _create_alias_callback(self): + alias = TagAlias("", tag_id=self.tag.id) + self.aliases.append(alias) - id = self.new_item_id - self.alias_ids.append(id) - self.new_alias_names[id] = "" - self.new_item_id -= 1 self._set_aliases() - row = self.aliases_table.rowCount() - 1 item = self.aliases_table.cellWidget(row, 1) item.setFocus() - def remove_alias_callback(self, alias_id: int): - logger.info("remove_alias_callback") - - self.alias_ids.remove(alias_id) + def remove_alias_callback(self, alias: TagAlias): + for i, a in enumerate(self.aliases): + if a.name == alias.name and a.id == alias.id: + del self.aliases[i] + continue self._set_aliases() def choose_color_callback(self, tag_color_group: TagColorGroup | None): - logger.info("choose_color_callback", tag_color_group=tag_color_group) if tag_color_group: self.tag_color_namespace = tag_color_group.namespace self.tag_color_slug = tag_color_group.slug @@ -378,19 +360,30 @@ class BuildTagPanel(PanelWidget): else: text_color = get_tag_text_color(primary_color, highlight_color) + def update_parent_tag_callback(build_tag_panel: BuildTagPanel): + self.lib.update_tag( + build_tag_panel.build_tag(), + parent_ids=set(build_tag_panel.parent_ids), + aliases=set(build_tag_panel.aliases), + ) + self.set_parent_tags() + + def on_parent_tag_edit(tag: Tag) -> None: + build_tag_panel = BuildTagPanel(self.lib, tag=tag) + edit_modal = PanelModal( + build_tag_panel, + self.lib.tag_display_name(tag), + "Edit Tag", + is_savable=True, + ) + edit_modal.saved.connect(partial(update_parent_tag_callback, build_tag_panel)) + edit_modal.show() + # Add Tag Widget - tag_widget = TagWidget( - tag, - library=self.lib, - has_edit=True, - has_remove=True, - ) - tag_widget.on_remove.connect(lambda t=parent_id: self.remove_parent_tag_callback(t)) - tag_widget.on_edit.connect( - lambda t=tag: TagSearchPanel( - library=self.lib, view=TagSearchPanelView(is_tag_chooser=True) - ).on_item_edit(t) - ) + tag_widget = TagWidget(tag, library=self.lib, has_edit=True, has_remove=True) + tag_widget.on_remove.connect(lambda t=parent_id: self._remove_parent_tag_callback(t)) + tag_widget.on_edit.connect(partial(on_parent_tag_edit, tag)) + row.addWidget(tag_widget) # Add Disambiguation Tag Button @@ -423,59 +416,34 @@ class BuildTagPanel(PanelWidget): else: button.setChecked(False) - def add_aliases(self): - names: set[str] = set() - for i in range(0, self.aliases_table.rowCount()): - widget = self.aliases_table.cellWidget(i, 1) - names.add(cast(CustomTableItem, widget).text()) - - remove: set[str] = set(self.alias_names) - names - self.alias_names = list(set(self.alias_names) - remove) - - for name in names: - # add new aliases - if name.strip() != "" and name not in set(self.alias_names): - self.alias_names.append(name) - elif name.strip() == "" and name in set(self.alias_names): - self.alias_names.remove(name) - - def _update_new_alias_name_dict(self): - for i in range(0, self.aliases_table.rowCount()): - widget = cast(CustomTableItem, self.aliases_table.cellWidget(i, 1)) - self.new_alias_names[widget.id] = widget.text() - def _set_aliases(self): - self._update_new_alias_name_dict() - while self.aliases_table.rowCount() > 0: self.aliases_table.removeRow(0) - self.alias_names.clear() - last: QWidget | None = self.panel_save_button - for alias_id in self.alias_ids: - alias = self.lib.get_alias(self.tag.id, alias_id) + aliases = list(self.aliases) + alias_names = [a.name for a in aliases] + sorted_aliases = sorted(aliases, key=lambda x: alias_names[aliases.index(x)]) - alias_name: str = alias.name if alias else self.new_alias_names[alias_id] + # Sort the TagAlias objects while keeping in-progress empty ones at the bottom + empty_aliases: list[TagAlias] = [] + while sorted_aliases and sorted_aliases[0].name == "": + empty_aliases.append(sorted_aliases.pop(0)) + for alias in empty_aliases: + sorted_aliases.append(alias) - # handel when an alias name changes - if alias_id in self.new_alias_names: - alias_name = self.new_alias_names[alias_id] - - self.alias_names.append(alias_name) - - remove_btn = QPushButton("-") - remove_btn.clicked.connect(lambda id=alias_id: self.remove_alias_callback(id)) + for alias in sorted_aliases: + remove_button = QPushButton("-") + remove_button.clicked.connect(partial(self.remove_alias_callback, alias)) row = self.aliases_table.rowCount() - new_item = CustomTableItem(alias_name, self.enter, self.backspace) - new_item.set_id(alias_id) - - new_item.editingFinished.connect(lambda item=new_item: self._alias_name_change(item)) + new_item = CustomTableItem(alias.name, self.enter, self.backspace) + new_item.alias = alias + new_item.editingFinished.connect(partial(self._on_alias_change, new_item)) self.aliases_table.insertRow(row) self.aliases_table.setCellWidget(row, 1, new_item) - self.aliases_table.setCellWidget(row, 0, remove_btn) + self.aliases_table.setCellWidget(row, 0, remove_button) if last is not None: self.setTabOrder(last, self.aliases_table.cellWidget(row, 1)) @@ -484,22 +452,25 @@ class BuildTagPanel(PanelWidget): ) last = self.aliases_table.cellWidget(row, 0) - def _alias_name_change(self, item: CustomTableItem): - self.new_alias_names[item.id] = item.text() + def _on_alias_change(self, item: CustomTableItem): + for alias in self.aliases: + if item.alias == alias: + alias.name = item.text() + item.alias.name = item.text() + continue def set_tag(self, tag: Tag): - logger.info("[BuildTagPanel] Setting Tag", tag=tag) + logger.info("[BuildTagPanel] Setting Tag", tag_id=tag.id) self.tag = tag - self.name_field.setText(tag.name) self.shorthand_field.setText(tag.shorthand or "") - for alias_id in tag.alias_ids: - self.alias_ids.append(alias_id) + for alias in tag.aliases: + self.aliases.append(alias) self._set_aliases() self.disambiguation_id = tag.disambiguation_id - for parent_id in tag.parent_ids: + for parent_id in self.tag.parent_ids: self.parent_ids.add(parent_id) self.set_parent_tags() @@ -516,9 +487,8 @@ class BuildTagPanel(PanelWidget): self.cat_checkbox.setChecked(tag.is_category) self.hidden_checkbox.setChecked(tag.is_hidden) - def on_name_changed(self): + def _on_name_change(self): is_empty = not self.name_field.text().strip() - self.name_field.setStyleSheet(line_edit_style() if is_empty else "") if self.panel_save_button is not None: @@ -526,8 +496,6 @@ class BuildTagPanel(PanelWidget): def build_tag(self) -> Tag: tag = self.tag - self.add_aliases() - tag.name = self.name_field.text() tag.shorthand = self.shorthand_field.text() tag.disambiguation_id = self.disambiguation_id @@ -536,7 +504,7 @@ class BuildTagPanel(PanelWidget): tag.is_category = self.cat_checkbox.isChecked() tag.is_hidden = self.hidden_checkbox.isChecked() - logger.info("built tag", tag=tag) + logger.info("[BuildTag] Build Tag", tag_id=tag.id, tag_name=tag.name) return tag @override @@ -550,4 +518,3 @@ class BuildTagPanel(PanelWidget): self.setTabOrder(unwrap(self.panel_save_button), self.aliases_table.cellWidget(0, 1)) self.name_field.selectAll() self.name_field.setFocus() - self._set_aliases() diff --git a/src/tagstudio/qt/mixed/tag_widget.py b/src/tagstudio/qt/mixed/tag_widget.py index c17cf196..cf114f10 100644 --- a/src/tagstudio/qt/mixed/tag_widget.py +++ b/src/tagstudio/qt/mixed/tag_widget.py @@ -110,14 +110,7 @@ class TagWidget(QWidget): tag: Tag | None def __init__( - self, - tag: Tag | None, - has_edit: bool, - has_remove: bool, - library: "Library | None" = None, - on_remove_callback: Callable[[], None] | None = None, - on_click_callback: Callable[[], None] | None = None, - on_edit_callback: Callable[[], None] | None = None, + self, tag: Tag | None, has_edit: bool, has_remove: bool, library: "Library | None" = None ) -> None: super().__init__() self.tag = tag @@ -134,14 +127,6 @@ class TagWidget(QWidget): self.bg_button = QPushButton(self) self.bg_button.setFlat(True) - # add callbacks - if on_remove_callback is not None: - self.on_remove.connect(on_remove_callback) - if on_click_callback is not None: - self.on_click.connect(on_click_callback) - if on_edit_callback is not None: - self.on_edit.connect(on_edit_callback) - # add edit action if has_edit: edit_action = QAction(self) diff --git a/src/tagstudio/qt/ts_qt.py b/src/tagstudio/qt/ts_qt.py index ad9aaadc..a5a9c840 100644 --- a/src/tagstudio/qt/ts_qt.py +++ b/src/tagstudio/qt/ts_qt.py @@ -890,8 +890,7 @@ class QtDriver(DriverMixin, QObject): self.lib.add_tag( panel.build_tag(), set(panel.parent_ids), - set(panel.alias_names), - set(panel.alias_ids), + set(panel.aliases), ), self.modal.hide(), ) diff --git a/tests/qt/test_build_tag_panel.py b/tests/qt/test_build_tag_panel.py index 5206f36c..024d1270 100644 --- a/tests/qt/test_build_tag_panel.py +++ b/tests/qt/test_build_tag_panel.py @@ -22,7 +22,7 @@ def test_build_tag_panel_add_sub_tag_callback( panel: BuildTagPanel = BuildTagPanel(library, child) qtbot.addWidget(panel) - panel.add_parent_tag_callback(parent.id) + panel._add_parent_tag_callback(parent.id) # pyright: ignore[reportPrivateUsage] assert len(panel.parent_ids) == 1 @@ -33,14 +33,14 @@ def test_build_tag_panel_remove_subtag_callback( parent = unwrap(library.add_tag(generate_tag("xxx", id=123))) child = unwrap(library.add_tag(generate_tag("xx", id=124))) - library.update_tag(child, {parent.id}, [], []) + library.update_tag(child, {parent.id}, []) child = unwrap(library.get_tag(child.id)) panel: BuildTagPanel = BuildTagPanel(library, child) qtbot.addWidget(panel) - panel.remove_parent_tag_callback(parent.id) + panel._remove_parent_tag_callback(parent.id) # pyright: ignore[reportPrivateUsage] assert len(panel.parent_ids) == 0 @@ -58,7 +58,7 @@ def test_build_tag_panel_add_alias_callback( panel: BuildTagPanel = BuildTagPanel(library, tag) qtbot.addWidget(panel) - panel.add_alias_callback() + panel._create_alias_callback() # pyright: ignore[reportPrivateUsage] assert panel.aliases_table.rowCount() == 1 @@ -68,7 +68,9 @@ def test_build_tag_panel_remove_alias_callback( ): tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) - library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124}) + alias_1 = TagAlias("alias", tag.id) + alias_2 = TagAlias("alias_2", tag.id) + library.update_tag(tag, [], {alias_1, alias_2}) tag = unwrap(library.get_tag(tag.id)) @@ -79,12 +81,11 @@ def test_build_tag_panel_remove_alias_callback( qtbot.addWidget(panel) alias: TagAlias = unwrap(library.get_alias(tag.id, tag.alias_ids[0])) + panel.remove_alias_callback(alias) - panel.remove_alias_callback(alias.id) - - assert len(panel.alias_ids) == 1 - assert len(panel.alias_names) == 1 - assert alias.name not in panel.alias_names + assert len(panel.aliases) == 1 + assert alias not in panel.aliases + assert (alias.id, alias.name) not in [(a.id, a.name) for a in panel.aliases] def test_build_tag_panel_set_parent_tags( @@ -109,7 +110,9 @@ def test_build_tag_panel_add_aliases( ): tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) - library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124}) + alias_1 = TagAlias("alias", tag.id) + alias_2 = TagAlias("alias_2", tag.id) + library.update_tag(tag, [], {alias_1, alias_2}) tag = unwrap(library.get_tag(tag.id)) @@ -132,22 +135,13 @@ def test_build_tag_panel_add_aliases( assert "alias" in alias_names assert "alias_2" in alias_names - old_text = widget.text() - widget.setText("alias_update") - - panel.add_aliases() - - assert old_text not in panel.alias_names - assert "alias_update" in panel.alias_names - assert len(panel.alias_names) == 2 - def test_build_tag_panel_set_aliases( qtbot: QtBot, library: Library, generate_tag: Callable[..., Tag] ): tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) - - library.update_tag(tag, [], {"alias"}, {123}) + alias_1 = TagAlias("Alias 1", tag.id) + library.update_tag(tag, [], [alias_1]) tag = unwrap(library.get_tag(tag.id)) @@ -157,8 +151,7 @@ def test_build_tag_panel_set_aliases( qtbot.addWidget(panel) assert panel.aliases_table.rowCount() == 1 - assert len(panel.alias_names) == 1 - assert len(panel.alias_ids) == 1 + assert len(panel.aliases) == 1 def test_build_tag_panel_set_tag(qtbot: QtBot, library: Library, generate_tag: Callable[..., Tag]): diff --git a/tests/test_library.py b/tests/test_library.py index 4cca5d8f..13b99dd5 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -15,7 +15,7 @@ from tagstudio.core.library.alchemy.fields import ( TextField, ) from tagstudio.core.library.alchemy.library import Library -from tagstudio.core.library.alchemy.models import Entry, Tag +from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias from tagstudio.core.utils.types import unwrap logger = structlog.get_logger() @@ -25,10 +25,9 @@ def test_library_add_alias(library: Library, generate_tag: Callable[..., Tag]): tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) parent_ids: set[int] = set() - alias_ids: set[int] = set() - alias_names: set[str] = set() - alias_names.add("test_alias") - library.update_tag(tag, parent_ids, alias_names, alias_ids) + aliases: set[TagAlias] = set() + aliases.add(TagAlias("test_alias", tag.id)) + library.update_tag(tag, parent_ids, aliases) tag = unwrap(library.get_tag(tag.id)) alias_ids = set(tag.alias_ids) @@ -39,10 +38,9 @@ def test_library_get_alias(library: Library, generate_tag: Callable[..., Tag]): tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) parent_ids: set[int] = set() - alias_ids: list[int] = [] - alias_names: set[str] = set() - alias_names.add("test_alias") - library.update_tag(tag, parent_ids, alias_names, alias_ids) + aliases: set[TagAlias] = set() + aliases.add(TagAlias("test_alias", tag.id)) + library.update_tag(tag, parent_ids, aliases) tag = unwrap(library.get_tag(tag.id)) alias_ids = tag.alias_ids @@ -54,19 +52,19 @@ def test_library_update_alias(library: Library, generate_tag: Callable[..., Tag] tag: Tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) parent_ids: set[int] = set() - alias_ids: list[int] = [] - alias_names: set[str] = set() - alias_names.add("test_alias") - library.update_tag(tag, parent_ids, alias_names, alias_ids) + aliases: set[TagAlias] = set() + test_alias = TagAlias("test_alias", tag.id) + aliases.add(test_alias) + library.update_tag(tag, parent_ids, aliases) tag = unwrap(library.get_tag(tag.id)) alias_ids = tag.alias_ids alias = unwrap(library.get_alias(tag.id, alias_ids[0])) assert alias.name == "test_alias" - alias_names.remove("test_alias") - alias_names.add("alias_update") - library.update_tag(tag, parent_ids, alias_names, alias_ids) + aliases.remove(test_alias) + aliases.add(TagAlias("alias_update", tag.id)) + library.update_tag(tag, parent_ids, aliases) tag = unwrap(library.get_tag(tag.id)) assert len(tag.alias_ids) == 1 @@ -108,7 +106,7 @@ def test_tag_self_parent(library: Library, generate_tag: Callable[..., Tag]): tag = unwrap(library.add_tag(generate_tag("xxx", id=123))) assert tag.id == 123 - library.update_tag(tag, {tag.id}, [], []) + library.update_tag(tag, {tag.id}, []) tag = unwrap(library.get_tag(tag.id)) assert len(tag.parent_ids) == 0