diff --git a/tagstudio/src/core/constants.py b/tagstudio/src/core/constants.py index 0ca05377..07ed3814 100644 --- a/tagstudio/src/core/constants.py +++ b/tagstudio/src/core/constants.py @@ -13,3 +13,4 @@ FONT_SAMPLE_SIZES: list[int] = [10, 15, 20] TAG_FAVORITE = 1 TAG_ARCHIVED = 0 +RESERVED_TAG_IDS = range(0, 999) diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index a2734b1b..b12c93e8 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -657,6 +657,39 @@ class Library: session.execute(update_stmt) session.commit() + def remove_tag(self, tag: Tag): + with Session(self.engine, expire_on_commit=False) as session: + try: + subtags = session.scalars( + select(TagSubtag).where(TagSubtag.parent_id == tag.id) + ).all() + + tags_query = select(Tag).options( + selectinload(Tag.subtags), selectinload(Tag.aliases) + ) + tag = session.scalar(tags_query.where(Tag.id == tag.id)) + + aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)) + + for alias in aliases or []: + session.delete(alias) + + for subtag in subtags or []: + session.delete(subtag) + session.expunge(subtag) + + session.delete(tag) + + session.commit() + + session.expunge(tag) + return tag + + except IntegrityError as e: + logger.exception(e) + session.rollback() + return None + def remove_tag_from_field(self, tag: Tag, field: TagBoxField) -> None: with Session(self.engine) as session: field_ = session.scalars(select(TagBoxField).where(TagBoxField.id == field.id)).one() diff --git a/tagstudio/src/qt/modals/tag_database.py b/tagstudio/src/qt/modals/tag_database.py index 937b4fb2..6c5195f3 100644 --- a/tagstudio/src/qt/modals/tag_database.py +++ b/tagstudio/src/qt/modals/tag_database.py @@ -9,10 +9,13 @@ from PySide6.QtWidgets import ( QFrame, QHBoxLayout, QLineEdit, + QMessageBox, + QPushButton, QScrollArea, QVBoxLayout, QWidget, ) +from src.core.constants import RESERVED_TAG_IDS from src.core.library import Library, Tag from src.qt.modals.build_tag import BuildTagPanel from src.qt.widgets.panel import PanelModal, PanelWidget @@ -59,8 +62,32 @@ class TagDatabasePanel(PanelWidget): self.scroll_area.setFrameShape(QFrame.Shape.NoFrame) self.scroll_area.setWidget(self.scroll_contents) + self.create_tag_button = QPushButton() + self.create_tag_button.setText("Create Tag") + self.create_tag_button.clicked.connect(self.build_tag) + self.root_layout.addWidget(self.search_field) self.root_layout.addWidget(self.scroll_area) + self.root_layout.addWidget(self.create_tag_button) + self.update_tags() + + def build_tag(self): + self.modal = PanelModal( + BuildTagPanel(self.lib), + "New Tag", + "Add Tag", + has_save=True, + ) + + panel: BuildTagPanel = self.modal.widget + self.modal.saved.connect( + lambda: ( + self.lib.add_tag(panel.build_tag(), panel.subtag_ids), + self.modal.hide(), + self.update_tags(), + ) + ) + self.modal.show() def on_return(self, text: str): if text and self.first_tag_id >= 0: @@ -84,14 +111,41 @@ class TagDatabasePanel(PanelWidget): row = QHBoxLayout(container) row.setContentsMargins(0, 0, 0, 0) row.setSpacing(3) - tag_widget = TagWidget(tag, has_edit=True, has_remove=False) + + if tag.id in RESERVED_TAG_IDS: + tag_widget = TagWidget(tag, has_edit=False, has_remove=False) + else: + tag_widget = TagWidget(tag, has_edit=True, has_remove=True) + tag_widget.on_edit.connect(lambda checked=False, t=tag: self.edit_tag(t)) + tag_widget.on_remove.connect(lambda t=tag: self.remove_tag(t)) row.addWidget(tag_widget) self.scroll_layout.addWidget(container) self.search_field.setFocus() + def remove_tag(self, tag: Tag): + if tag.id in RESERVED_TAG_IDS: + return + + message_box = QMessageBox() + message_box.setWindowTitle("Remove tag") + message_box.setText("Are you sure you want to remove " + tag.name + "?") + message_box.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) # type: ignore + message_box.setIcon(QMessageBox.Question) # type: ignore + + result = message_box.exec() + + if result != QMessageBox.Ok: # type: ignore + return + + self.lib.remove_tag(tag) + self.update_tags() + def edit_tag(self, tag: Tag): + if tag.id in RESERVED_TAG_IDS: + return + build_tag_panel = BuildTagPanel(self.lib, tag=tag) self.edit_modal = PanelModal( diff --git a/tagstudio/src/qt/widgets/panel.py b/tagstudio/src/qt/widgets/panel.py index d16803c7..b164acc9 100755 --- a/tagstudio/src/qt/widgets/panel.py +++ b/tagstudio/src/qt/widgets/panel.py @@ -91,6 +91,10 @@ class PanelModal(QWidget): self.root_layout.setStretch(1, 2) self.root_layout.addWidget(self.button_container) + def closeEvent(self, event): # noqa: N802 + self.done_button.click() + event.accept() + class PanelWidget(QWidget): """Used for widgets that go in a modal panel, ex. for editing or searching.""" diff --git a/tagstudio/src/qt/widgets/thumb_renderer.py b/tagstudio/src/qt/widgets/thumb_renderer.py index b17f519a..45550a84 100644 --- a/tagstudio/src/qt/widgets/thumb_renderer.py +++ b/tagstudio/src/qt/widgets/thumb_renderer.py @@ -12,7 +12,6 @@ from pathlib import Path import cv2 import numpy as np -import pillow_jxl # noqa: F401 import rawpy import structlog from mutagen import MutagenError, flac, id3, mp4 diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index 04fa965e..82f9522e 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -213,6 +213,17 @@ def test_subtags_add(library, generate_tag): assert tag.subtag_ids +def test_remove_tag(library, generate_tag): + tag = library.add_tag(generate_tag("food", id=123)) + + assert tag + + tag_count = len(library.tags) + + library.remove_tag(tag) + assert len(library.tags) == tag_count - 1 + + @pytest.mark.parametrize("is_exclude", [True, False]) def test_search_filter_extensions(library, is_exclude): # Given