diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index 9badbe7a..f9ba256a 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -24,6 +24,7 @@ from sqlalchemy import ( from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import ( Session, + aliased, contains_eager, make_transient, selectinload, @@ -417,13 +418,18 @@ class Library: statement = select(Entry) if search.tag: + SubtagAlias = aliased(Tag) # noqa: N806 statement = ( statement.join(Entry.tag_box_fields) .join(TagBoxField.tags) + .outerjoin(Tag.aliases) + .outerjoin(SubtagAlias, Tag.subtags) .where( or_( Tag.name.ilike(search.tag), Tag.shorthand.ilike(search.tag), + TagAlias.name.ilike(search.tag), + SubtagAlias.name.ilike(search.tag), ) ) ) @@ -752,18 +758,23 @@ class Library: ) return True - def add_tag(self, tag: Tag, subtag_ids: list[int] | None = None) -> Tag | None: + def add_tag( + self, + tag: Tag, + subtag_ids: set[int] | None = None, + alias_names: set[str] | None = None, + alias_ids: set[int] | None = None, + ) -> Tag | None: with Session(self.engine, expire_on_commit=False) as session: try: session.add(tag) session.flush() - for subtag_id in subtag_ids or []: - subtag = TagSubtag( - parent_id=tag.id, - child_id=subtag_id, - ) - session.add(subtag) + if subtag_ids is not None: + self.update_subtags(tag, subtag_ids, session) + + if alias_ids is not None and alias_names is not None: + self.update_aliases(tag, alias_ids, alias_names, session) session.commit() @@ -847,25 +858,38 @@ class Library: def get_tag(self, tag_id: int) -> Tag: with Session(self.engine) as session: - tags_query = select(Tag).options(selectinload(Tag.subtags)) + tags_query = select(Tag).options(selectinload(Tag.subtags), selectinload(Tag.aliases)) tag = session.scalar(tags_query.where(Tag.id == tag_id)) session.expunge(tag) for subtag in tag.subtags: session.expunge(subtag) + for alias in tag.aliases: + session.expunge(alias) + return tag + def get_alias(self, tag_id: int, alias_id: int) -> TagAlias: + with Session(self.engine) as session: + alias_query = select(TagAlias).where(TagAlias.id == alias_id, TagAlias.tag_id == tag_id) + alias = session.scalar(alias_query.where(TagAlias.id == alias_id)) + + return alias + def add_subtag(self, base_id: int, new_tag_id: int) -> bool: + if base_id == new_tag_id: + return False + # open session and save as parent tag with Session(self.engine) as session: - tag = TagSubtag( + subtag = TagSubtag( parent_id=base_id, child_id=new_tag_id, ) try: - session.add(tag) + session.add(subtag) session.commit() return True except IntegrityError: @@ -873,49 +897,62 @@ class Library: logger.exception("IntegrityError") return False - def update_tag(self, tag: Tag, subtag_ids: list[int]) -> None: - """Edit a Tag in the Library.""" - # TODO - maybe merge this with add_tag? - - if tag.shorthand: - tag.shorthand = slugify(tag.shorthand) - - if tag.aliases: - # TODO - ... - - # save the tag + def remove_subtag(self, base_id: int, remove_tag_id: int) -> bool: with Session(self.engine) as session: - try: - # update the existing tag - session.add(tag) - session.flush() + p_id = base_id + r_id = remove_tag_id + remove = session.query(TagSubtag).filter_by(parent_id=p_id, child_id=r_id).one() + session.delete(remove) + session.commit() - # load all tag's subtag to know which to remove - prev_subtags = session.scalars( - select(TagSubtag).where(TagSubtag.parent_id == tag.id) - ).all() + return True - for subtag in prev_subtags: - if subtag.child_id not in subtag_ids: - session.delete(subtag) - else: - # no change, remove from list - subtag_ids.remove(subtag.child_id) + def update_tag( + self, + tag: Tag, + subtag_ids: set[int] | None = None, + alias_names: set[str] | None = None, + alias_ids: set[int] | None = None, + ) -> None: + """Edit a Tag in the Library.""" + self.add_tag(tag, subtag_ids, alias_names, alias_ids) + + def update_aliases(self, tag, alias_ids, alias_names, session): + prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all() + + for alias in prev_aliases: + if alias.id not in alias_ids or alias.name not in alias_names: + session.delete(alias) + else: + alias_ids.remove(alias.id) + alias_names.remove(alias.name) + + for alias_name in alias_names: + alias = TagAlias(alias_name, tag.id) + session.add(alias) + + def update_subtags(self, tag, subtag_ids, session): + if tag.id in subtag_ids: + subtag_ids.remove(tag.id) + + # load all tag's subtag to know which to remove + prev_subtags = session.scalars(select(TagSubtag).where(TagSubtag.parent_id == tag.id)).all() + + for subtag in prev_subtags: + if subtag.child_id not in subtag_ids: + session.delete(subtag) + else: + # no change, remove from list + subtag_ids.remove(subtag.child_id) # create remaining items - for subtag_id in subtag_ids: - # add new subtag - subtag = TagSubtag( - parent_id=tag.id, - child_id=subtag_id, - ) - session.add(subtag) - - session.commit() - except IntegrityError: - session.rollback() - logger.exception("IntegrityError") + for subtag_id in subtag_ids: + # add new subtag + subtag = TagSubtag( + parent_id=tag.id, + child_id=subtag_id, + ) + session.add(subtag) def prefs(self, key: LibraryPrefs) -> Any: # load given item from Preferences table diff --git a/tagstudio/src/core/library/alchemy/models.py b/tagstudio/src/core/library/alchemy/models.py index 09b54e3e..734c6823 100644 --- a/tagstudio/src/core/library/alchemy/models.py +++ b/tagstudio/src/core/library/alchemy/models.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional from sqlalchemy import JSON, ForeignKey, Integer, event from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -29,11 +28,11 @@ class TagAlias(Base): tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id")) tag: Mapped["Tag"] = relationship(back_populates="aliases") - def __init__(self, name: str, tag: Optional["Tag"] = None): + def __init__(self, name: str, tag_id: int | None = None): self.name = name - if tag: - self.tag = tag + if tag_id is not None: + self.tag_id = tag_id super().__init__() @@ -73,6 +72,10 @@ class Tag(Base): def alias_strings(self) -> list[str]: return [alias.name for alias in self.aliases] + @property + def alias_ids(self) -> list[int]: + return [tag.id for tag in self.aliases] + def __init__( self, name: str, diff --git a/tagstudio/src/qt/flowlayout.py b/tagstudio/src/qt/flowlayout.py index 6334cf95..046b93c9 100644 --- a/tagstudio/src/qt/flowlayout.py +++ b/tagstudio/src/qt/flowlayout.py @@ -140,4 +140,7 @@ class FlowLayout(QLayout): x = next_x line_height = max(line_height, item.sizeHint().height()) + if len(self._item_list) == 0: + return 0 + return y + line_height - rect.y() * ((len(self._item_list)) / len(self._item_list)) diff --git a/tagstudio/src/qt/modals/build_tag.py b/tagstudio/src/qt/modals/build_tag.py old mode 100755 new mode 100644 index af779819..1ed8821d --- a/tagstudio/src/qt/modals/build_tag.py +++ b/tagstudio/src/qt/modals/build_tag.py @@ -3,25 +3,31 @@ # Created for TagStudio: https://github.com/CyanVoxel/TagStudio +import math +from typing import cast + import structlog +from PySide6 import QtCore from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import ( + QAction, +) from PySide6.QtWidgets import ( + QApplication, QComboBox, - QFrame, QLabel, QLineEdit, QPushButton, - QScrollArea, - QTextEdit, QVBoxLayout, QWidget, ) from src.core.library import Library, Tag from src.core.library.alchemy.enums import TagColor from src.core.palette import ColorType, UiColor, get_tag_color, get_ui_color +from src.qt.flowlayout import FlowLayout from src.qt.modals.tag_search import TagSearchPanel from src.qt.widgets.panel import PanelModal, PanelWidget -from src.qt.widgets.tag import TagWidget +from src.qt.widgets.tag import TagAliasWidget, TagWidget logger = structlog.get_logger(__name__) @@ -79,12 +85,47 @@ class BuildTagPanel(PanelWidget): self.aliases_title = QLabel() self.aliases_title.setText("Aliases") self.aliases_layout.addWidget(self.aliases_title) - self.aliases_field = QTextEdit() - self.aliases_field.setAcceptRichText(False) - self.aliases_field.setMinimumHeight(40) - self.aliases_layout.addWidget(self.aliases_field) + + self.aliases_flow_widget = QWidget() + self.aliases_flow_layout = FlowLayout(self.aliases_flow_widget) + self.aliases_flow_layout.setContentsMargins(0, 0, 0, 0) + self.aliases_flow_layout.enable_grid_optimizations(value=False) + + self.alias_add_button = QPushButton() + self.alias_add_button.setMinimumSize(23, 23) + self.alias_add_button.setMaximumSize(23, 23) + self.alias_add_button.setText("+") + self.alias_add_button.setToolTip("CTRL + A") + self.alias_add_button.setShortcut( + QtCore.QKeyCombination( + QtCore.Qt.KeyboardModifier(QtCore.Qt.KeyboardModifier.ControlModifier), + QtCore.Qt.Key.Key_A, + ) + ) + self.alias_add_button.setStyleSheet( + f"QPushButton{{" + f"background: #1e1e1e;" + f"color: #FFFFFF;" + f"font-weight: bold;" + f"border-color: #333333;" + f"border-radius: 6px;" + f"border-style:solid;" + f"border-width:{math.ceil(self.devicePixelRatio())}px;" + f"padding-bottom: 5px;" + f"font-size: 20px;" + f"}}" + f"QPushButton::hover" + f"{{" + f"border-color: #CCCCCC;" + f"background: #555555;" + f"}}" + ) + + self.alias_add_button.clicked.connect(lambda: self.add_alias_callback()) + self.aliases_flow_layout.addWidget(self.alias_add_button) # Subtags ------------------------------------------------------------ + self.subtags_widget = QWidget() self.subtags_layout = QVBoxLayout(self.subtags_widget) self.subtags_layout.setStretch(1, 1) @@ -96,28 +137,52 @@ class BuildTagPanel(PanelWidget): self.subtags_title.setText("Parent Tags") self.subtags_layout.addWidget(self.subtags_title) - self.scroll_contents = QWidget() - self.scroll_layout = QVBoxLayout(self.scroll_contents) - self.scroll_layout.setContentsMargins(6, 0, 6, 0) - self.scroll_layout.setAlignment(Qt.AlignmentFlag.AlignTop) - - self.scroll_area = QScrollArea() - # self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn) - self.scroll_area.setWidgetResizable(True) - self.scroll_area.setFrameShadow(QFrame.Shadow.Plain) - self.scroll_area.setFrameShape(QFrame.Shape.NoFrame) - self.scroll_area.setWidget(self.scroll_contents) - # self.scroll_area.setMinimumHeight(60) - - self.subtags_layout.addWidget(self.scroll_area) + self.subtag_flow_widget = QWidget() + self.subtag_flow_layout = FlowLayout(self.subtag_flow_widget) + self.subtag_flow_layout.setContentsMargins(0, 0, 0, 0) + self.subtag_flow_layout.enable_grid_optimizations(value=False) self.subtags_add_button = QPushButton() + self.subtags_add_button.setCursor(Qt.CursorShape.PointingHandCursor) self.subtags_add_button.setText("+") - tsp = TagSearchPanel(self.lib) + self.subtags_add_button.setToolTip("CTRL + P") + self.subtags_add_button.setMinimumSize(23, 23) + self.subtags_add_button.setMaximumSize(23, 23) + self.subtags_add_button.setShortcut( + QtCore.QKeyCombination( + QtCore.Qt.KeyboardModifier(QtCore.Qt.KeyboardModifier.ControlModifier), + QtCore.Qt.Key.Key_P, + ) + ) + self.subtags_add_button.setStyleSheet( + f"QPushButton{{" + f"background: #1e1e1e;" + f"color: #FFFFFF;" + f"font-weight: bold;" + f"border-color: #333333;" + f"border-radius: 6px;" + f"border-style:solid;" + f"border-width:{math.ceil(self.devicePixelRatio())}px;" + f"padding-bottom: 5px;" + f"font-size: 20px;" + f"}}" + f"QPushButton::hover" + f"{{" + f"border-color: #CCCCCC;" + f"background: #555555;" + f"}}" + ) + self.subtag_flow_layout.addWidget(self.subtags_add_button) + + exclude_ids: list[int] = list() + if tag is not None: + exclude_ids.append(tag.id) + + tsp = TagSearchPanel(self.lib, exclude_ids) tsp.tag_chosen.connect(lambda x: self.add_subtag_callback(x)) self.add_tag_modal = PanelModal(tsp, "Add Parent Tags", "Add Parent Tags") self.subtags_add_button.clicked.connect(self.add_tag_modal.show) - self.subtags_layout.addWidget(self.subtags_add_button) + # self.subtags_layout.addWidget(self.subtags_add_button) # self.subtags_field = TagBoxWidget() # self.subtags_field.setMinimumHeight(60) @@ -147,68 +212,204 @@ class BuildTagPanel(PanelWidget): "font-weight:600;" f"color:{get_tag_color(ColorType.TEXT, self.color_field.currentData())};" f"background-color:{get_tag_color( - ColorType.PRIMARY, + ColorType.PRIMARY, self.color_field.currentData())};" ) ) ) self.color_layout.addWidget(self.color_field) + remove_selected_alias_action = QAction("remove selected alias", self) + remove_selected_alias_action.triggered.connect(self.remove_selected_alias) + remove_selected_alias_action.setShortcut( + QtCore.QKeyCombination( + QtCore.Qt.KeyboardModifier(QtCore.Qt.KeyboardModifier.ControlModifier), + QtCore.Qt.Key.Key_D, + ) + ) + self.addAction(remove_selected_alias_action) # Add Widgets to Layout ================================================ self.root_layout.addWidget(self.name_widget) self.root_layout.addWidget(self.shorthand_widget) self.root_layout.addWidget(self.aliases_widget) + self.root_layout.addWidget(self.aliases_flow_widget) self.root_layout.addWidget(self.subtags_widget) + self.root_layout.addWidget(self.subtag_flow_widget) self.root_layout.addWidget(self.color_widget) # self.parent().done.connect(self.update_tag) - # TODO - fill subtags - self.subtags: set[int] = set() + self.subtag_ids: set[int] = set() + self.alias_ids: set[int] = set() + self.alias_names: set[str] = set() + self.new_alias_names: dict = dict() + self.set_tag(tag or Tag(name="New Tag")) if tag is None: self.name_field.selectAll() + def keyPressEvent(self, event): # noqa: N802 + if event.key() == Qt.Key_Return or event.key() == Qt.Key_Enter: # type: ignore + focused_widget = QApplication.focusWidget() + if isinstance(focused_widget.parent(), TagAliasWidget): + self.add_alias_callback() + + def remove_selected_alias(self): + count = self.aliases_flow_layout.count() + if count <= 0: + return + + focused_widget = QApplication.focusWidget() + + if focused_widget is None: + return + + if isinstance(focused_widget.parent(), TagAliasWidget): + cast(TagAliasWidget, focused_widget.parent()).on_remove.emit() + + count = self.aliases_flow_layout.count() + if count > 1: + cast( + TagAliasWidget, self.aliases_flow_layout.itemAt(count - 2).widget() + ).text_field.setFocus() + else: + self.alias_add_button.setFocus() + def add_subtag_callback(self, tag_id: int): logger.info("add_subtag_callback", tag_id=tag_id) - self.subtags.add(tag_id) + self.subtag_ids.add(tag_id) self.set_subtags() def remove_subtag_callback(self, tag_id: int): logger.info("removing subtag", tag_id=tag_id) - self.subtags.remove(tag_id) + self.subtag_ids.remove(tag_id) self.set_subtags() - def set_subtags(self): - while self.scroll_layout.itemAt(0): - self.scroll_layout.takeAt(0).widget().deleteLater() + def add_alias_callback(self): + logger.info("add_alias_callback") + # bug passing in the text for a here means when the text changes + # the remove callback uses what a whas initialy assigned + new_field = TagAliasWidget() + id = new_field.__hash__() + new_field.id = id - c = QWidget() - layout = QVBoxLayout(c) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(3) - for tag_id in self.subtags: + new_field.on_remove.connect(lambda a="": self.remove_alias_callback(a, id)) + new_field.setMaximumHeight(25) + new_field.setMinimumHeight(25) + + self.alias_ids.add(id) + self.new_alias_names[id] = "" + self.aliases_flow_layout.addWidget(new_field) + new_field.text_field.setFocus() + self.aliases_flow_layout.addWidget(self.alias_add_button) + + def remove_alias_callback(self, alias_name: str, alias_id: int | None = None): + logger.info("remove_alias_callback") + self.alias_ids.remove(alias_id) + self._set_aliases() + + def set_subtags(self): + while self.subtag_flow_layout.itemAt(1): + self.subtag_flow_layout.takeAt(0).widget().deleteLater() + + for tag_id in self.subtag_ids: tag = self.lib.get_tag(tag_id) tw = TagWidget(tag, has_edit=False, has_remove=True) tw.on_remove.connect(lambda t=tag_id: self.remove_subtag_callback(t)) - layout.addWidget(tw) - self.scroll_layout.addWidget(c) + self.subtag_flow_layout.addWidget(tw) + + self.subtag_flow_layout.addWidget(self.subtags_add_button) + + def add_aliases(self): + fields: set[TagAliasWidget] = set() + for i in range(0, self.aliases_flow_layout.count() - 1): + widget = self.aliases_flow_layout.itemAt(i).widget() + + if not isinstance(widget, TagAliasWidget): + return + + field: TagAliasWidget = cast(TagAliasWidget, widget) + fields.add(field) + + remove: set[str] = self.alias_names - set([a.text_field.text() for a in fields]) + + self.alias_names = self.alias_names - remove + + for field in fields: + # add new aliases + if field.text_field.text() != "": + self.alias_names.add(field.text_field.text()) + + def _update_new_alias_name_dict(self): + for i in range(0, self.aliases_flow_layout.count() - 1): + widget = self.aliases_flow_layout.itemAt(i).widget() + + if not isinstance(widget, TagAliasWidget): + return + + field: TagAliasWidget = cast(TagAliasWidget, widget) + text_field_text = field.text_field.text() + + self.new_alias_names[field.id] = text_field_text + + def _set_aliases(self): + self._update_new_alias_name_dict() + + while self.aliases_flow_layout.itemAt(1): + self.aliases_flow_layout.takeAt(0).widget().deleteLater() + + self.alias_names.clear() + + for alias_id in self.alias_ids: + alias = self.lib.get_alias(self.tag.id, alias_id) + + alias_name = alias.name if alias else self.new_alias_names[alias_id] + + new_field = TagAliasWidget( + alias_id, + alias_name, + lambda a=alias_name, id=alias_id: self.remove_alias_callback(a, id), + ) + new_field.setMaximumHeight(25) + new_field.setMinimumHeight(25) + self.aliases_flow_layout.addWidget(new_field) + self.alias_names.add(alias_name) + + self.aliases_flow_layout.addWidget(self.alias_add_button) def set_tag(self, tag: Tag): + self.tag = tag + + self.tag = tag + logger.info("setting tag", tag=tag) self.name_field.setText(tag.name) self.shorthand_field.setText(tag.shorthand or "") - # TODO: Implement aliases - # self.aliases_field.setText("\n".join(tag.aliases)) + + for alias_id in tag.alias_ids: + self.alias_ids.add(alias_id) + + self._set_aliases() + + for subtag in tag.subtag_ids: + self.subtag_ids.add(subtag) + + for alias_id in tag.alias_ids: + self.alias_ids.add(alias_id) + + self._set_aliases() + + for subtag in tag.subtag_ids: + self.subtag_ids.add(subtag) + self.set_subtags() + # select item in self.color_field where the userData value matched tag.color for i in range(self.color_field.count()): if self.color_field.itemData(i) == tag.color: self.color_field.setCurrentIndex(i) break - self.tag = tag - def on_name_changed(self): is_empty = not self.name_field.text().strip() @@ -226,6 +427,8 @@ class BuildTagPanel(PanelWidget): tag = self.tag + self.add_aliases() + tag.name = self.name_field.text() tag.shorthand = self.shorthand_field.text() tag.color = color diff --git a/tagstudio/src/qt/modals/tag_database.py b/tagstudio/src/qt/modals/tag_database.py index f1aebad6..9375c841 100644 --- a/tagstudio/src/qt/modals/tag_database.py +++ b/tagstudio/src/qt/modals/tag_database.py @@ -99,5 +99,5 @@ class TagDatabasePanel(PanelWidget): self.edit_modal.show() def edit_tag_callback(self, btp: BuildTagPanel): - self.lib.add_tag(btp.build_tag()) + self.lib.update_tag(btp.build_tag(), btp.subtag_ids, btp.alias_names, btp.alias_ids) self.update_tags(self.search_field.text()) diff --git a/tagstudio/src/qt/modals/tag_search.py b/tagstudio/src/qt/modals/tag_search.py index c44278fd..1bb25731 100644 --- a/tagstudio/src/qt/modals/tag_search.py +++ b/tagstudio/src/qt/modals/tag_search.py @@ -28,9 +28,10 @@ logger = structlog.get_logger(__name__) class TagSearchPanel(PanelWidget): tag_chosen = Signal(int) - def __init__(self, library: Library): + def __init__(self, library: Library, exclude: list[int] | None = None): super().__init__() self.lib = library + self.exclude = exclude self.first_tag_id = None self.tag_limit = 100 self.setMinimumSize(300, 400) @@ -84,6 +85,8 @@ class TagSearchPanel(PanelWidget): ) for tag in found_tags: + if self.exclude is not None and tag.id in self.exclude: + continue c = QWidget() layout = QHBoxLayout(c) layout.setContentsMargins(0, 0, 0, 0) diff --git a/tagstudio/src/qt/ts_qt.py b/tagstudio/src/qt/ts_qt.py index c780a6a6..c7dfa99c 100644 --- a/tagstudio/src/qt/ts_qt.py +++ b/tagstudio/src/qt/ts_qt.py @@ -632,7 +632,9 @@ class QtDriver(DriverMixin, QObject): panel: BuildTagPanel = self.modal.widget self.modal.saved.connect( lambda: ( - self.lib.add_tag(panel.build_tag(), panel.subtags), + self.lib.add_tag( + panel.build_tag(), panel.subtag_ids, panel.alias_names, panel.alias_ids + ), self.modal.hide(), ) ) diff --git a/tagstudio/src/qt/widgets/tag.py b/tagstudio/src/qt/widgets/tag.py index ae26b342..2d4cc7ce 100644 --- a/tagstudio/src/qt/widgets/tag.py +++ b/tagstudio/src/qt/widgets/tag.py @@ -9,12 +9,88 @@ from types import FunctionType from PIL import Image from PySide6.QtCore import QEvent, Qt, Signal -from PySide6.QtGui import QAction, QEnterEvent -from PySide6.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget +from PySide6.QtGui import QAction, QEnterEvent, QFontMetrics +from PySide6.QtWidgets import ( + QHBoxLayout, + QLineEdit, + QPushButton, + QVBoxLayout, + QWidget, +) from src.core.library import Tag +from src.core.library.alchemy.enums import TagColor from src.core.palette import ColorType, get_tag_color +class TagAliasWidget(QWidget): + on_remove = Signal() + + def __init__( + self, + id: int | None = 0, + alias: str | None = None, + on_remove_callback=None, + ) -> None: + super().__init__() + + self.id = id + + # if on_click_callback: + self.setCursor(Qt.CursorShape.PointingHandCursor) + self.base_layout = QHBoxLayout(self) + self.base_layout.setObjectName("baseLayout") + self.base_layout.setContentsMargins(0, 0, 0, 0) + + self.on_remove.connect(on_remove_callback) + + self.text_field = QLineEdit(self) + self.text_field.textChanged.connect(self._adjust_width) + + if alias is not None: + self.text_field.setText(alias) + else: + self.text_field.setText("") + + self._adjust_width() + + self.remove_button = QPushButton(self) + self.remove_button.setFlat(True) + self.remove_button.setText("–") + self.remove_button.setHidden(False) + self.remove_button.setStyleSheet( + f"color: {get_tag_color(ColorType.PRIMARY, TagColor.DEFAULT)};" + f"background: {get_tag_color(ColorType.TEXT, TagColor.DEFAULT)};" + f"font-weight: 800;" + f"border-radius: 4px;" + f"border-width:0;" + f"padding-bottom: 4px;" + f"font-size: 14px" + ) + self.remove_button.setMinimumSize(19, 19) + self.remove_button.setMaximumSize(19, 19) + self.remove_button.clicked.connect(self.on_remove.emit) + + self.base_layout.addWidget(self.remove_button) + self.base_layout.addWidget(self.text_field) + + def _adjust_width(self): + text = self.text_field.text() or self.text_field.placeholderText() + font_metrics = QFontMetrics(self.text_field.font()) + text_width = font_metrics.horizontalAdvance(text) + 10 # Add padding + + # Set the minimum width of the QLineEdit + self.text_field.setMinimumWidth(text_width) + self.text_field.adjustSize() + + def enterEvent(self, event: QEnterEvent) -> None: # noqa: N802 + self.update() + return super().enterEvent(event) + + def leaveEvent(self, event: QEvent) -> None: # noqa: N802 + self.update() + return super().leaveEvent(event) + + class TagWidget(QWidget): edit_icon_128: Image.Image = Image.open( str(Path(__file__).parents[3] / "resources/qt/images/edit_icon_128.png") diff --git a/tagstudio/src/qt/widgets/tag_box.py b/tagstudio/src/qt/widgets/tag_box.py index c24a3519..653d3c0e 100755 --- a/tagstudio/src/qt/widgets/tag_box.py +++ b/tagstudio/src/qt/widgets/tag_box.py @@ -138,7 +138,7 @@ class TagBoxWidget(FieldWidget): self.edit_modal.saved.connect( lambda: self.driver.lib.update_tag( build_tag_panel.build_tag(), - subtag_ids=build_tag_panel.subtags, + subtag_ids=build_tag_panel.subtag_ids, ) ) self.edit_modal.show() diff --git a/tagstudio/tests/qt/test_build_tag_panel.py b/tagstudio/tests/qt/test_build_tag_panel.py new file mode 100644 index 00000000..9026b475 --- /dev/null +++ b/tagstudio/tests/qt/test_build_tag_panel.py @@ -0,0 +1,185 @@ +from typing import cast + +from PySide6.QtWidgets import QApplication, QMainWindow +from src.core.library.alchemy.models import Tag +from src.qt.modals.build_tag import BuildTagPanel +from src.qt.widgets.tag import TagAliasWidget + + +def test_build_tag_panel_add_sub_tag_callback(library, generate_tag): + parent = library.add_tag(generate_tag("xxx", id=123)) + child = library.add_tag(generate_tag("xx", id=124)) + assert child + assert parent + + panel: BuildTagPanel = BuildTagPanel(library, child) + + panel.add_subtag_callback(parent.id) + + assert len(panel.subtag_ids) == 1 + + +def test_build_tag_panel_remove_subtag_callback(library, generate_tag): + parent = library.add_tag(generate_tag("xxx", id=123)) + child = library.add_tag(generate_tag("xx", id=124)) + assert child + assert parent + + library.update_tag(child, {parent.id}, [], []) + + child = library.get_tag(child.id) + + assert child + + panel: BuildTagPanel = BuildTagPanel(library, child) + + panel.remove_subtag_callback(parent.id) + + assert len(panel.subtag_ids) == 0 + + +import os + +os.environ["QT_QPA_PLATFORM"] = "offscreen" + + +def test_build_tag_panel_remove_selected_alias(library, generate_tag): + app = QApplication.instance() or QApplication([]) + + window = QMainWindow() + parent_tag = library.add_tag(generate_tag("xxx", id=123)) + panel = BuildTagPanel(library, parent_tag) + panel.setParent(window) + + panel.add_alias_callback() + window.show() + + assert panel.aliases_flow_layout.count() == 2 + + alias_widget = panel.aliases_flow_layout.itemAt(0).widget() + alias_widget.text_field.setFocus() + + app.processEvents() + + panel.remove_selected_alias() + + assert panel.aliases_flow_layout.count() == 1 + + +def test_build_tag_panel_add_alias_callback(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + panel: BuildTagPanel = BuildTagPanel(library, tag) + + panel.add_alias_callback() + + assert panel.aliases_flow_layout.count() == 2 + + +def test_build_tag_panel_remove_alias_callback(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124}) + + tag = library.get_tag(tag.id) + + assert "alias" in tag.alias_strings + assert "alias_2" in tag.alias_strings + + panel: BuildTagPanel = BuildTagPanel(library, tag) + + alias = library.get_alias(tag.id, tag.alias_ids[0]) + + panel.remove_alias_callback(alias.name, alias.id) + + assert len(panel.alias_ids) == 1 + assert len(panel.alias_names) == 1 + assert alias.name not in panel.alias_names + + +def test_build_tag_panel_set_subtags(library, generate_tag): + parent = library.add_tag(generate_tag("parent", id=123)) + child = library.add_tag(generate_tag("child", id=124)) + assert parent + assert child + + library.add_subtag(child.id, parent.id) + + child = library.get_tag(child.id) + + panel: BuildTagPanel = BuildTagPanel(library, child) + + assert len(panel.subtag_ids) == 1 + assert panel.subtag_flow_layout.count() == 2 + + +def test_build_tag_panel_add_aliases(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + library.update_tag(tag, [], {"alias", "alias_2"}, {123, 124}) + + tag = library.get_tag(tag.id) + + assert "alias" in tag.alias_strings + assert "alias_2" in tag.alias_strings + + panel: BuildTagPanel = BuildTagPanel(library, tag) + + widget = panel.aliases_flow_layout.itemAt(0).widget() + + alias_names: set[str] = set() + alias_names.add(cast(TagAliasWidget, widget).text_field.text()) + + widget = panel.aliases_flow_layout.itemAt(1).widget() + alias_names.add(cast(TagAliasWidget, widget).text_field.text()) + + assert "alias" in alias_names + assert "alias_2" in alias_names + + old_text = cast(TagAliasWidget, widget).text_field.text() + cast(TagAliasWidget, widget).text_field.setText("alias_update") + + panel.add_aliases() + + assert old_text not in panel.alias_names + assert "alias_update" in panel.alias_names + assert len(panel.alias_names) == 2 + + +def test_build_tag_panel_set_aliases(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + library.update_tag(tag, [], {"alias"}, {123}) + + tag = library.get_tag(tag.id) + + assert len(tag.alias_ids) == 1 + + panel: BuildTagPanel = BuildTagPanel(library, tag) + + assert panel.aliases_flow_layout.count() == 2 + assert len(panel.alias_names) == 1 + assert len(panel.alias_ids) == 1 + + +def test_build_tag_panel_set_tag(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + panel: BuildTagPanel = BuildTagPanel(library, tag) + + assert panel.tag + assert panel.tag.name == "xxx" + + +def test_build_tag_panel_build_tag(library): + panel: BuildTagPanel = BuildTagPanel(library) + + tag: Tag = panel.build_tag() + + assert tag + assert tag.name == "New Tag" diff --git a/tagstudio/tests/qt/test_tag_widget.py b/tagstudio/tests/qt/test_tag_widget.py index 9d10691a..86158bf4 100644 --- a/tagstudio/tests/qt/test_tag_widget.py +++ b/tagstudio/tests/qt/test_tag_widget.py @@ -84,7 +84,9 @@ def test_tag_widget_remove(qtbot, qt_driver, library, entry_full): def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): # Given - tag = list(entry_full.tags)[0] + entry = next(library.get_entries(with_joins=True)) + library.add_tag(list(entry.tags)[0]) + tag = library.get_tag(list(entry.tags)[0].id) assert tag assert entry_full.tag_box_fields @@ -99,9 +101,7 @@ def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): assert isinstance(tag_widget, TagWidget) # When - actions = tag_widget.bg_button.actions() - edit_action = [a for a in actions if a.text() == "Edit"][0] - edit_action.triggered.emit() + tag_box_widget.edit_tag(tag) # Then panel = tag_box_widget.edit_modal.widget diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index 24e4d9b9..8175c496 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -6,6 +6,64 @@ from src.core.enums import DefaultEnum, LibraryPrefs from src.core.library.alchemy import Entry, Library from src.core.library.alchemy.enums import FilterState from src.core.library.alchemy.fields import TextField, _FieldID +from src.core.library.alchemy.models import Tag + + +def test_library_add_alias(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + subtag_ids: set[int] = set() + alias_ids: set[int] = set() + alias_names: set[str] = set() + alias_names.add("test_alias") + library.update_tag(tag, subtag_ids, alias_names, alias_ids) + + # Note: ask if it is expected behaviour that you need to re-request + # for the tag. Or if the tag in memory should be updated + alias_ids = library.get_tag(tag.id).alias_ids + + assert len(alias_ids) == 1 + + +def test_library_get_alias(library, generate_tag): + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + subtag_ids: set[int] = set() + alias_ids: set[int] = set() + alias_names: set[str] = set() + alias_names.add("test_alias") + library.update_tag(tag, subtag_ids, alias_names, alias_ids) + + alias_ids = library.get_tag(tag.id).alias_ids + + assert library.get_alias(tag.id, alias_ids[0]).name == "test_alias" + + +def test_library_update_alias(library, generate_tag): + tag: Tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + + subtag_ids: set[int] = set() + alias_ids: set[int] = set() + alias_names: set[str] = set() + alias_names.add("test_alias") + library.update_tag(tag, subtag_ids, alias_names, alias_ids) + + tag = library.get_tag(tag.id) + alias_ids = tag.alias_ids + + assert library.get_alias(tag.id, alias_ids[0]).name == "test_alias" + + alias_names.remove("test_alias") + alias_names.add("alias_update") + library.update_tag(tag, subtag_ids, alias_names, alias_ids) + + tag = library.get_tag(tag.id) + + assert len(tag.alias_ids) == 1 + assert library.get_alias(tag.id, tag.alias_ids[0]).name == "alias_update" @pytest.mark.parametrize("library", [TemporaryDirectory()], indirect=True) @@ -38,14 +96,26 @@ def test_create_tag(library, generate_tag): assert tag_inc.id > 1000 +def test_tag_subtag_itself(library, generate_tag): + # tag already exists + assert not library.add_tag(generate_tag("foo")) + + # new tag name + tag = library.add_tag(generate_tag("xxx", id=123)) + assert tag + assert tag.id == 123 + + library.update_tag(tag, {tag.id}, {}, {}) + tag = library.get_tag(tag.id) + assert len(tag.subtag_ids) == 0 + + def test_library_search(library, generate_tag, entry_full): assert library.entries_count == 2 tag = list(entry_full.tags)[0] results = library.search_library( - FilterState( - tag=tag.name, - ), + FilterState(tag=tag.name), ) assert results.total_count == 1