From ab59fc4a5041a6e8800bfabb4bc05e706701dc79 Mon Sep 17 00:00:00 2001 From: Travis Abendshien Date: Sun, 8 Dec 2024 14:28:47 -0800 Subject: [PATCH] refactor: remove tag field types --- tagstudio/src/core/constants.py | 3 +- tagstudio/src/core/library/alchemy/db.py | 4 +- tagstudio/src/core/library/alchemy/fields.py | 4 - tagstudio/src/core/library/alchemy/library.py | 145 ++++---- tagstudio/src/core/library/alchemy/models.py | 13 +- tagstudio/src/qt/widgets/item_thumb.py | 22 +- tagstudio/src/qt/widgets/migration_modal.py | 129 +++---- tagstudio/src/qt/widgets/preview_panel.py | 10 +- tagstudio/src/qt/widgets/tag_box.py | 321 +++++++++--------- tagstudio/tests/conftest.py | 29 +- tagstudio/tests/qt/test_tag_widget.py | 152 ++++----- tagstudio/tests/test_json_migration.py | 3 +- 12 files changed, 404 insertions(+), 431 deletions(-) mode change 100755 => 100644 tagstudio/src/qt/widgets/tag_box.py diff --git a/tagstudio/src/core/constants.py b/tagstudio/src/core/constants.py index 335df553..5b0f1730 100644 --- a/tagstudio/src/core/constants.py +++ b/tagstudio/src/core/constants.py @@ -11,7 +11,8 @@ FONT_SAMPLE_TEXT: str = ( ) FONT_SAMPLE_SIZES: list[int] = [10, 15, 20] -TAG_FAVORITE = 1 TAG_ARCHIVED = 0 +TAG_FAVORITE = 1 +TAG_META = 2 RESERVED_TAG_START = 0 RESERVED_TAG_END = 999 diff --git a/tagstudio/src/core/library/alchemy/db.py b/tagstudio/src/core/library/alchemy/db.py index f48ed2b5..938ae79f 100644 --- a/tagstudio/src/core/library/alchemy/db.py +++ b/tagstudio/src/core/library/alchemy/db.py @@ -44,7 +44,9 @@ def make_tables(engine: Engine) -> None: autoincrement_val = result.scalar() if not autoincrement_val or autoincrement_val <= RESERVED_TAG_END: conn.execute( - text(f"INSERT INTO tags (id, name, color) VALUES ({RESERVED_TAG_END}, 'temp', 1)") + text( + f"INSERT INTO tags (id, name, color, is_category) VALUES ({RESERVED_TAG_END}, 'temp', 1, false)" + ) ) conn.execute(text(f"DELETE FROM tags WHERE id = {RESERVED_TAG_END}")) conn.commit() diff --git a/tagstudio/src/core/library/alchemy/fields.py b/tagstudio/src/core/library/alchemy/fields.py index c4618b3e..9f180cfc 100644 --- a/tagstudio/src/core/library/alchemy/fields.py +++ b/tagstudio/src/core/library/alchemy/fields.py @@ -134,10 +134,6 @@ class _FieldID(Enum): URL = DefaultField(id=3, name="URL", type=FieldTypeEnum.TEXT_LINE) DESCRIPTION = DefaultField(id=4, name="Description", type=FieldTypeEnum.TEXT_LINE) NOTES = DefaultField(id=5, name="Notes", type=FieldTypeEnum.TEXT_BOX) - # TODO: Remove (i think) - # TAGS = DefaultField(id=6, name="Tags", type=FieldTypeEnum.TAGS) - # TAGS_CONTENT = DefaultField(id=7, name="Content Tags", type=FieldTypeEnum.TAGS, is_default=True) - # TAGS_META = DefaultField(id=8, name="Meta Tags", type=FieldTypeEnum.TAGS, is_default=True) COLLATION = DefaultField(id=9, name="Collation", type=FieldTypeEnum.TEXT_LINE) DATE = DefaultField(id=10, name="Date", type=FieldTypeEnum.DATETIME) DATE_CREATED = DefaultField(id=11, name="Date Created", type=FieldTypeEnum.DATETIME) diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index f1fb30af..2e21e261 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs from pathlib import Path -from typing import Any from uuid import uuid4 import structlog @@ -38,6 +37,7 @@ from ...constants import ( BACKUP_FOLDER_NAME, TAG_ARCHIVED, TAG_FAVORITE, + TAG_META, TS_FOLDER_NAME, ) from ...enums import LibraryPrefs @@ -70,13 +70,19 @@ def slugify(input_string: str) -> str: def get_default_tags() -> tuple[Tag, ...]: + meta_tag = Tag( + id=TAG_META, + name="Meta Tags", + aliases={TagAlias(name="Meta"), TagAlias(name="Meta Tag")}, + is_category=True, + ) archive_tag = Tag( id=TAG_ARCHIVED, name="Archived", aliases={TagAlias(name="Archive")}, + subtags={meta_tag}, color=TagColor.RED, ) - favorite_tag = Tag( id=TAG_FAVORITE, name="Favorite", @@ -84,10 +90,15 @@ def get_default_tags() -> tuple[Tag, ...]: TagAlias(name="Favorited"), TagAlias(name="Favorites"), }, + subtags={meta_tag}, color=TagColor.YELLOW, ) - return archive_tag, favorite_tag + return archive_tag, favorite_tag, meta_tag + + +# The difference in the number of default JSON tags vs default tags in the current version. +DEFAULT_TAG_DIFF: int = len(get_default_tags()) - 2 @dataclass(frozen=True) @@ -156,14 +167,18 @@ class Library: # Tags for tag in json_lib.tags: - self.add_tag( - Tag( - id=tag.id, - name=tag.name, - shorthand=tag.shorthand, - color=TagColor.get_color_from_str(tag.color), + if tag.id == TAG_ARCHIVED or tag.id == TAG_FAVORITE: + # Update built-in + pass + else: + self.add_tag( + Tag( + id=tag.id, + name=tag.name, + shorthand=tag.shorthand, + color=TagColor.get_color_from_str(tag.color), + ) ) - ) # Tag Aliases for tag in json_lib.tags: @@ -192,11 +207,15 @@ class Library: for entry in json_lib.entries: for field in entry.fields: for k, v in field.items(): - self.add_entry_field_type( - entry_ids=(entry.id + 1), # JSON IDs start at 0 instead of 1 - field_id=self.get_field_name_from_id(k), - value=v, - ) + # Old tag fields get added as tags + if k in {6, 7, 8}: + self.add_tags_to_entry(entry_id=entry.id + 1, tag_ids=v) + else: + self.add_entry_field_type( + entry_ids=(entry.id + 1), # JSON IDs start at 0 instead of 1 + field_id=self.get_field_name_from_id(k), + value=v, + ) # Preferences self.set_prefs(LibraryPrefs.EXTENSION_LIST, [x.strip(".") for x in json_lib.ext_list]) @@ -232,9 +251,7 @@ class Library: return self.open_sqlite_library(library_dir, is_new) - def open_sqlite_library( - self, library_dir: Path, is_new: bool, add_default_data: bool = True - ) -> LibraryStatus: + def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus: connection_string = URL.create( drivername="sqlite", database=str(self.storage_path), @@ -255,14 +272,13 @@ class Library: with Session(self.engine) as session: make_tables(self.engine) - if add_default_data: - tags = get_default_tags() - try: - session.add_all(tags) - session.commit() - except IntegrityError: - # default tags may exist already - session.rollback() + tags = get_default_tags() + try: + session.add_all(tags) + session.commit() + except IntegrityError: + # default tags may exist already + session.rollback() # dont check db version when creating new library if not is_new: @@ -565,9 +581,9 @@ class Library: statement = statement.options( selectinload(Entry.text_fields), selectinload(Entry.datetime_fields), - selectinload(Entry.tag_box_fields) - # .joinedload(TagBoxField.tags) - .options(selectinload(Tag.aliases), selectinload(Tag.subtags)), + selectinload(Entry.tags).options( + selectinload(Tag.aliases), selectinload(Tag.subtags) + ), ) statement = statement.distinct(Entry.id) @@ -923,61 +939,18 @@ class Library: session.rollback() return None - # TODO: Delete - # def add_field_tag( - # self, - # entry: Entry, - # tag: Tag, - # field_key: str = _FieldID.TAGS.name, - # create_field: bool = False, - # ) -> bool: - # assert isinstance(field_key, str), f"field_key is {type(field_key)}" - # - # with Session(self.engine) as session: - # # find field matching entry and field_type - # field = session.scalars( - # select(TagBoxField).where( - # and_( - # TagBoxField.entry_id == entry.id, - # TagBoxField.type_key == field_key, - # ) - # ) - # ).first() - # - # if not field and not create_field: - # logger.error("no field found", entry=entry, field_key=field_key) - # return False - # - # try: - # if not field: - # field = TagBoxField( - # type_key=field_key, - # entry_id=entry.id, - # position=0, - # ) - # session.add(field) - # session.flush() - # - # # create record for `TagField` table - # if not tag.id: - # session.add(tag) - # session.flush() - # - # tag_field = TagField( - # tag_id=tag.id, - # field_id=field.id, - # ) - # - # session.add(tag_field) - # session.commit() - # logger.info("tag added to field", tag=tag, field=field, entry_id=entry.id) - # - # return True - # except IntegrityError as e: - # logger.exception(e) - # session.rollback() - # - # return False + def add_tags_to_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]): + tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids + with Session(self.engine, expire_on_commit=False) as session: + try: + for tag_id in tag_ids_: + session.add(TagEntry(tag_id=tag_id, entry_id=entry_id)) + session.flush() + session.commit() + except IntegrityError as e: + logger.exception(e) + session.rollback() + return None def save_library_backup_to_disk(self) -> Path: assert isinstance(self.library_dir, Path) @@ -1120,12 +1093,12 @@ class Library: ) session.add(subtag) - def prefs(self, key: LibraryPrefs) -> Any: + def prefs(self, key: LibraryPrefs): # load given item from Preferences table with Session(self.engine) as session: return session.scalar(select(Preferences).where(Preferences.key == key.name)).value - def set_prefs(self, key: LibraryPrefs, value: Any) -> None: + def set_prefs(self, key: LibraryPrefs, value) -> None: # set given item in Preferences table with Session(self.engine) as session: # load existing preference and update value diff --git a/tagstudio/src/core/library/alchemy/models.py b/tagstudio/src/core/library/alchemy/models.py index fe60fb36..a56e8b4c 100644 --- a/tagstudio/src/core/library/alchemy/models.py +++ b/tagstudio/src/core/library/alchemy/models.py @@ -44,6 +44,7 @@ class Tag(Base): name: Mapped[str] shorthand: Mapped[str | None] color: Mapped[TagColor] + is_category: Mapped[bool] icon: Mapped[str | None] aliases: Mapped[set[TagAlias]] = relationship(back_populates="tag") @@ -84,6 +85,7 @@ class Tag(Base): subtags: set["Tag"] | None = None, icon: str | None = None, color: TagColor = TagColor.DEFAULT, + is_category: bool = False, ): self.name = name self.aliases = aliases or set() @@ -92,6 +94,7 @@ class Tag(Base): self.color = color self.icon = icon self.shorthand = shorthand + self.is_category = is_category assert not self.id self.id = id super().__init__() @@ -144,17 +147,11 @@ class Entry(Base): @property def is_favorited(self) -> bool: - for tag in self.tags: - if tag.id == TAG_FAVORITE: - return True - return False + return any(tag.id == TAG_FAVORITE for tag in self.tags) @property def is_archived(self) -> bool: - for tag in self.tags: - if tag.id == TAG_ARCHIVED: - return True - return False + return any(tag.id == TAG_ARCHIVED for tag in self.tags) def __init__( self, diff --git a/tagstudio/src/qt/widgets/item_thumb.py b/tagstudio/src/qt/widgets/item_thumb.py index f79708ad..1a74304b 100644 --- a/tagstudio/src/qt/widgets/item_thumb.py +++ b/tagstudio/src/qt/widgets/item_thumb.py @@ -25,7 +25,6 @@ from src.core.constants import ( TAG_FAVORITE, ) from src.core.library import Entry, ItemType, Library -from src.core.library.alchemy.fields import _FieldID from src.core.media_types import MediaCategories, MediaType from src.qt.flowlayout import FlowWidget from src.qt.helpers.file_opener import FileOpenerHelper @@ -508,9 +507,7 @@ class ItemThumb(FlowWidget): for idx in update_items: entry = self.driver.frame_content[idx] - self.toggle_item_tag( - entry, toggle_value, tag_id, _FieldID.TAGS_META.name, create_field=True - ) + self.toggle_item_tag(entry.id, toggle_value, tag_id) # update the entry self.driver.frame_content[idx] = self.lib.get_entry_full(entry.id) @@ -518,25 +515,16 @@ class ItemThumb(FlowWidget): def toggle_item_tag( self, - entry: Entry, + entry_id: int, toggle_value: bool, tag_id: int, - field_key: str, - create_field: bool = False, ): - logger.info( - "toggle_item_tag", - entry_id=entry.id, - toggle_value=toggle_value, - tag_id=tag_id, - field_key=field_key, - ) + logger.info("toggle_item_tag", entry_id=entry_id, toggle_value=toggle_value, tag_id=tag_id) - tag = self.lib.get_tag(tag_id) if toggle_value: - self.lib.add_field_tag(entry, tag, field_key, create_field) + self.lib.add_tags_to_entry(entry_id, tag_id) else: - self.lib.remove_field_tag(entry, tag.id, field_key) + self.lib.remove_tag_from_entry(entry_id, tag_id) if self.driver.preview_panel.is_open: self.driver.preview_panel.update_widgets() diff --git a/tagstudio/src/qt/widgets/migration_modal.py b/tagstudio/src/qt/widgets/migration_modal.py index e191e597..d141e5d3 100644 --- a/tagstudio/src/qt/widgets/migration_modal.py +++ b/tagstudio/src/qt/widgets/migration_modal.py @@ -24,10 +24,8 @@ from src.core.enums import LibraryPrefs from src.core.library.alchemy.enums import TagColor # from src.core.library.alchemy.fields import TagBoxField, _FieldID -from src.core.library.alchemy.fields import _FieldID - -# from src.core.library.alchemy.joins import TagField, TagSubtag from src.core.library.alchemy.joins import TagSubtag +from src.core.library.alchemy.library import DEFAULT_TAG_DIFF from src.core.library.alchemy.library import Library as SqliteLibrary from src.core.library.alchemy.models import Entry, TagAlias from src.core.library.json.library import Library as JsonLibrary # type: ignore @@ -395,7 +393,7 @@ class JsonMigrationModal(QObject): except Exception as e: yield f"Error: {type(e).__name__}" - self.done = True + self.done = True def update_parity_ui(self): """Update all parity values UI.""" @@ -416,7 +414,7 @@ class JsonMigrationModal(QObject): ) self.update_sql_value( self.tags_row, - len(self.sql_lib.tags), + (len(self.sql_lib.tags) - DEFAULT_TAG_DIFF), self.old_tag_count, ) self.update_sql_value( @@ -551,13 +549,16 @@ class JsonMigrationModal(QObject): return self.field_parity for sf in sql_entry.fields: - sql_fields.append( - ( - sql_entry.id, - sf.type.key, - sanitize_field(session, sql_entry, sf.value, sf.type.type, sf.type_key), + if sf.type.type.value not in {6, 7, 8}: + sql_fields.append( + ( + sql_entry.id, + sf.type.key, + sanitize_field( + session, sql_entry, sf.value, sf.type.type, sf.type_key + ), + ) ) - ) sql_fields.sort() # NOTE: The JSON database allowed for separate tag fields of the same type with @@ -565,56 +566,51 @@ class JsonMigrationModal(QObject): # across all instances of that field on an entry. # TODO: ROADMAP: "Tag Categories" will merge all field tags onto the entry. # All visual separation from there will be data-driven from the tag itself. - meta_tags_count: int = 0 - content_tags_count: int = 0 + # meta_tags_count: int = 0 + # content_tags_count: int = 0 tags_count: int = 0 - merged_meta_tags: set[int] = set() - merged_content_tags: set[int] = set() + # merged_meta_tags: set[int] = set() + # merged_content_tags: set[int] = set() merged_tags: set[int] = set() for jf in json_entry.fields: - key: str = self.sql_lib.get_field_name_from_id(list(jf.keys())[0]).name + int_key: int = list(jf.keys())[0] value = sanitize_json_field(list(jf.values())[0]) - - if key == _FieldID.TAGS_META.name: - meta_tags_count += 1 - merged_meta_tags = merged_meta_tags.union(value or []) - elif key == _FieldID.TAGS_CONTENT.name: - content_tags_count += 1 - merged_content_tags = merged_content_tags.union(value or []) - elif key == _FieldID.TAGS.name: + if int_key in {6, 7, 8}: tags_count += 1 merged_tags = merged_tags.union(value or []) + pass else: - # JSON IDs start at 0 instead of 1 + key: str = self.sql_lib.get_field_name_from_id(int_key).name json_fields.append((json_entry.id + 1, key, value)) - if meta_tags_count: - for _ in range(0, meta_tags_count): - json_fields.append( - ( - json_entry.id + 1, - _FieldID.TAGS_META.name, - merged_meta_tags if merged_meta_tags else None, - ) - ) - if content_tags_count: - for _ in range(0, content_tags_count): - json_fields.append( - ( - json_entry.id + 1, - _FieldID.TAGS_CONTENT.name, - merged_content_tags if merged_content_tags else None, - ) - ) - if tags_count: - for _ in range(0, tags_count): - json_fields.append( - ( - json_entry.id + 1, - _FieldID.TAGS.name, - merged_tags if merged_tags else None, - ) - ) + # TODO: DO NOT IGNORE TAGS + # if meta_tags_count: + # for _ in range(0, meta_tags_count): + # json_fields.append( + # ( + # json_entry.id + 1, + # _FieldID.TAGS_META.name, + # merged_meta_tags if merged_meta_tags else None, + # ) + # ) + # if content_tags_count: + # for _ in range(0, content_tags_count): + # json_fields.append( + # ( + # json_entry.id + 1, + # _FieldID.TAGS_CONTENT.name, + # merged_content_tags if merged_content_tags else None, + # ) + # ) + # if tags_count: + # for _ in range(0, tags_count): + # json_fields.append( + # ( + # json_entry.id + 1, + # "TAGS", + # merged_tags if merged_tags else None, + # ) + # ) json_fields.sort() if not ( @@ -653,14 +649,17 @@ class JsonMigrationModal(QObject): with Session(self.sql_lib.engine) as session: for tag in self.sql_lib.tags: + if tag.id in range(0, 1000): + break tag_id = tag.id # Tag IDs start at 0 sql_subtags = set( session.scalars(select(TagSubtag.child_id).where(TagSubtag.parent_id == tag.id)) ) + # sql_subtags = sql_subtags.difference([x for x in range(0, 1000)]) + # JSON tags allowed self-parenting; SQL tags no longer allow this. - json_subtags = set(self.json_lib.get_tag(tag_id).subtag_ids).difference( - set([self.json_lib.get_tag(tag_id).id]) - ) + json_subtags = set(self.json_lib.get_tag(tag_id).subtag_ids) + json_subtags.discard(tag_id) logger.info( "[Subtag Parity]", @@ -675,7 +674,8 @@ class JsonMigrationModal(QObject): and (sql_subtags == json_subtags) ): self.discrepancies.append( - f"[Subtag Parity]:\nOLD (JSON):{json_subtags}\nNEW (SQL):{sql_subtags}" + f"[Subtag Parity][Tag ID: {tag_id}]:" + f"\nOLD (JSON):{json_subtags}\nNEW (SQL):{sql_subtags}" ) self.subtag_parity = False return self.subtag_parity @@ -693,6 +693,8 @@ class JsonMigrationModal(QObject): with Session(self.sql_lib.engine) as session: for tag in self.sql_lib.tags: + if tag.id in range(0, 1000): + break tag_id = tag.id # Tag IDs start at 0 sql_aliases = set( session.scalars(select(TagAlias.name).where(TagAlias.tag_id == tag.id)) @@ -711,7 +713,8 @@ class JsonMigrationModal(QObject): and (sql_aliases == json_aliases) ): self.discrepancies.append( - f"[Alias Parity]:\nOLD (JSON):{json_aliases}\nNEW (SQL):{sql_aliases}" + f"[Alias Parity][Tag ID: {tag_id}]:" + f"\nOLD (JSON):{json_aliases}\nNEW (SQL):{sql_aliases}" ) self.alias_parity = False return self.alias_parity @@ -725,6 +728,8 @@ class JsonMigrationModal(QObject): json_shorthand: str = None for tag in self.sql_lib.tags: + if tag.id in range(0, 1000): + break tag_id = tag.id # Tag IDs start at 0 sql_shorthand = tag.shorthand json_shorthand = self.json_lib.get_tag(tag_id).shorthand @@ -742,7 +747,8 @@ class JsonMigrationModal(QObject): and (sql_shorthand == json_shorthand) ): self.discrepancies.append( - f"[Shorthand Parity]:\nOLD (JSON):{json_shorthand}\nNEW (SQL):{sql_shorthand}" + f"[Shorthand Parity][Tag ID: {tag_id}]:" + f"\nOLD (JSON):{json_shorthand}\nNEW (SQL):{sql_shorthand}" ) self.shorthand_parity = False return self.shorthand_parity @@ -756,11 +762,13 @@ class JsonMigrationModal(QObject): json_color: str = None for tag in self.sql_lib.tags: + if tag.id in range(0, 1000): + break tag_id = tag.id # Tag IDs start at 0 sql_color = tag.color.name json_color = ( TagColor.get_color_from_str(self.json_lib.get_tag(tag_id).color).name - if self.json_lib.get_tag(tag_id).color != "" + if (self.json_lib.get_tag(tag_id).color) != "" else TagColor.DEFAULT.name ) @@ -773,7 +781,8 @@ class JsonMigrationModal(QObject): if not (sql_color is not None and json_color is not None and (sql_color == json_color)): self.discrepancies.append( - f"[Color Parity]:\nOLD (JSON):{json_color}\nNEW (SQL):{sql_color}" + f"[Color Parity][Tag ID: {tag_id}]:" + f"\nOLD (JSON):{json_color}\nNEW (SQL):{sql_color}" ) self.color_parity = False return self.color_parity diff --git a/tagstudio/src/qt/widgets/preview_panel.py b/tagstudio/src/qt/widgets/preview_panel.py index 88798e9e..5bf232d6 100644 --- a/tagstudio/src/qt/widgets/preview_panel.py +++ b/tagstudio/src/qt/widgets/preview_panel.py @@ -42,7 +42,6 @@ from src.core.library.alchemy.fields import ( FieldTypeEnum, # TagBoxField, TextField, - _FieldID, ) from src.core.library.alchemy.library import Library from src.core.media_types import MediaCategories @@ -907,7 +906,8 @@ class PreviewPanel(QWidget): # self.update_widgets(), # ) # ) - # # NOTE: Tag Boxes have no Edit Button (But will when you can convert field types) + # # NOTE: Tag Boxes have no Edit Button + # (But will when you can convert field types) # container.set_remove_callback( # lambda: self.remove_message_box( # prompt=self.remove_field_prompt(field.type.name), @@ -1070,9 +1070,9 @@ class PreviewPanel(QWidget): self.lib.remove_entry_field(field, entry_ids) - # if the field is meta tags, update the badges - if field.type_key == _FieldID.TAGS_META.value: - self.driver.update_badges(self.selected) + # # if the field is meta tags, update the badges + # if field.type_key == _FieldID.TAGS_META.value: + # self.driver.update_badges(self.selected) def update_field(self, field: BaseField, content: str) -> None: """Update a field in all selected Entries, given a field object.""" diff --git a/tagstudio/src/qt/widgets/tag_box.py b/tagstudio/src/qt/widgets/tag_box.py old mode 100755 new mode 100644 index e3805232..d31cefbb --- a/tagstudio/src/qt/widgets/tag_box.py +++ b/tagstudio/src/qt/widgets/tag_box.py @@ -1,185 +1,192 @@ -# Copyright (C) 2024 Travis Abendshien (CyanVoxel). -# Licensed under the GPL-3.0 License. -# Created for TagStudio: https://github.com/CyanVoxel/TagStudio +# # Copyright (C) 2024 Travis Abendshien (CyanVoxel). +# # Licensed under the GPL-3.0 License. +# # Created for TagStudio: https://github.com/CyanVoxel/TagStudio -import math -import typing +# import math +# import typing -import structlog -from PySide6.QtCore import Qt, Signal -from PySide6.QtWidgets import QPushButton -from src.core.constants import TAG_ARCHIVED, TAG_FAVORITE -from src.core.library import Entry, Tag -from src.core.library.alchemy.enums import FilterState - -# from src.core.library.alchemy.fields import TagBoxField -from src.qt.flowlayout import FlowLayout -from src.qt.modals.build_tag import BuildTagPanel -from src.qt.modals.tag_search import TagSearchPanel -from src.qt.translations import Translations -from src.qt.widgets.fields import FieldWidget -from src.qt.widgets.panel import PanelModal -from src.qt.widgets.tag import TagWidget - -if typing.TYPE_CHECKING: - from src.qt.ts_qt import QtDriver - -logger = structlog.get_logger(__name__) +# import structlog +# from PySide6.QtCore import Qt, Signal +# from PySide6.QtWidgets import QPushButton +# from src.core.constants import TAG_ARCHIVED, TAG_FAVORITE +# from src.core.library import Entry, Tag +# from src.core.library.alchemy.enums import FilterState -class TagBoxWidget(FieldWidget): - updated = Signal() - error_occurred = Signal(Exception) +# # from src.core.library.alchemy.fields import TagBoxField +# from src.qt.flowlayout import FlowLayout +# from src.qt.modals.build_tag import BuildTagPanel +# from src.qt.modals.tag_search import TagSearchPanel +# from src.qt.translations import Translations +# from src.qt.widgets.fields import FieldWidget +# from src.qt.widgets.panel import PanelModal +# from src.qt.widgets.tag import TagWidget - def __init__( - self, - field: TagBoxField, - title: str, - driver: "QtDriver", - ) -> None: - super().__init__(title) - assert isinstance(field, TagBoxField), f"field is {type(field)}" +# if typing.TYPE_CHECKING: +# from src.qt.ts_qt import QtDriver - self.field = field - self.driver = ( - driver # Used for creating tag click callbacks that search entries for that tag. - ) - self.setObjectName("tagBox") - self.base_layout = FlowLayout() - self.base_layout.enable_grid_optimizations(value=False) - self.base_layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(self.base_layout) +# logger = structlog.get_logger(__name__) - self.add_button = QPushButton() - self.add_button.setCursor(Qt.CursorShape.PointingHandCursor) - self.add_button.setMinimumSize(23, 23) - self.add_button.setMaximumSize(23, 23) - self.add_button.setText("+") - self.add_button.setStyleSheet( - f"QPushButton{{" - f"background: #1e1e1e;" - f"color: #FFFFFF;" - f"font-weight: bold;" - f"border-color: #333333;" - f"border-radius: 6px;" - f"border-style:solid;" - f"border-width:{math.ceil(self.devicePixelRatio())}px;" - f"padding-bottom: 5px;" - f"font-size: 20px;" - f"}}" - f"QPushButton::hover" - f"{{" - f"border-color: #CCCCCC;" - f"background: #555555;" - f"}}" - ) - tsp = TagSearchPanel(self.driver.lib) - tsp.tag_chosen.connect(lambda x: self.add_tag_callback(x)) - self.add_modal = PanelModal(tsp, title) - Translations.translate_with_setter(self.add_modal.setWindowTitle, "tag.add.plural") - self.add_button.clicked.connect( - lambda: ( - tsp.update_tags(), - self.add_modal.show(), - ) - ) - self.set_tags(field.tags) +# class TagBoxWidget(FieldWidget): +# updated = Signal() +# error_occurred = Signal(Exception) - def set_field(self, field: TagBoxField): - self.field = field +# def __init__( +# self, +# field: TagBoxField, +# title: str, +# driver: "QtDriver", +# ) -> None: +# super().__init__(title) - def set_tags(self, tags: typing.Iterable[Tag]): - tags_ = sorted(list(tags), key=lambda tag: tag.name) - is_recycled = False - while self.base_layout.itemAt(0) and self.base_layout.itemAt(1): - self.base_layout.takeAt(0).widget().deleteLater() - is_recycled = True +# assert isinstance(field, TagBoxField), f"field is {type(field)}" - for tag in tags_: - tag_widget = TagWidget(tag, has_edit=True, has_remove=True) - tag_widget.on_click.connect( - lambda tag_id=tag.id: ( - self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"), - self.driver.filter_items(FilterState.from_tag_id(tag_id)), - ) - ) +# self.field = field +# self.driver = ( +# driver # Used for creating tag click callbacks that search entries for that tag. +# ) +# self.setObjectName("tagBox") +# self.base_layout = FlowLayout() +# self.base_layout.enable_grid_optimizations(value=False) +# self.base_layout.setContentsMargins(0, 0, 0, 0) +# self.setLayout(self.base_layout) - tag_widget.on_remove.connect( - lambda tag_id=tag.id: ( - self.remove_tag(tag_id), - self.driver.preview_panel.update_widgets(), - ) - ) - tag_widget.on_edit.connect(lambda t=tag: self.edit_tag(t)) - self.base_layout.addWidget(tag_widget) - # Move or add the '+' button. - if is_recycled: - self.base_layout.addWidget(self.base_layout.takeAt(0).widget()) - else: - self.base_layout.addWidget(self.add_button) +# self.add_button = QPushButton() +# self.add_button.setCursor(Qt.CursorShape.PointingHandCursor) +# self.add_button.setMinimumSize(23, 23) +# self.add_button.setMaximumSize(23, 23) +# self.add_button.setText("+") +# self.add_button.setStyleSheet( +# f"QPushButton{{" +# f"background: #1e1e1e;" +# f"color: #FFFFFF;" +# f"font-weight: bold;" +# f"border-color: #333333;" +# f"border-radius: 6px;" +# f"border-style:solid;" +# f"border-width:{math.ceil(self.devicePixelRatio())}px;" +# f"padding-bottom: 5px;" +# f"font-size: 20px;" +# f"}}" +# f"QPushButton::hover" +# f"{{" +# f"border-color: #CCCCCC;" +# f"background: #555555;" +# f"}}" +# ) +# tsp = TagSearchPanel(self.driver.lib) +# tsp.tag_chosen.connect(lambda x: self.add_tag_callback(x)) +# self.add_modal = PanelModal(tsp, title) +# Translations.translate_with_setter(self.add_modal.setWindowTitle, "tag.add.plural") +# self.add_button.clicked.connect( +# lambda: ( +# tsp.update_tags(), +# self.add_modal.show(), +# ) +# ) - # Handles an edge case where there are no more tags and the '+' button - # doesn't move all the way to the left. - if self.base_layout.itemAt(0) and not self.base_layout.itemAt(1): - self.base_layout.update() - def edit_tag(self, tag: Tag): - assert isinstance(tag, Tag), f"tag is {type(tag)}" - build_tag_panel = BuildTagPanel(self.driver.lib, tag=tag) +# self.set_tags(field.tags) - self.edit_modal = PanelModal( - build_tag_panel, - title=tag.name, # TODO - display name including subtags - done_callback=self.driver.preview_panel.update_widgets, - has_save=True, - ) - Translations.translate_with_setter(self.edit_modal.setWindowTitle, "tag.edit") - # TODO - this was update_tag() - self.edit_modal.saved.connect( - lambda: self.driver.lib.update_tag( - build_tag_panel.build_tag(), - subtag_ids=set(build_tag_panel.subtag_ids), - alias_names=set(build_tag_panel.alias_names), - alias_ids=set(build_tag_panel.alias_ids), - ) - ) - self.edit_modal.show() +# def set_field(self, field: TagBoxField): +# self.field = field - def add_tag_callback(self, tag_id: int): - logger.info("add_tag_callback", tag_id=tag_id, selected=self.driver.selected) +# def set_tags(self, tags: typing.Iterable[Tag]): +# tags_ = sorted(list(tags), key=lambda tag: tag.name) +# is_recycled = False +# while self.base_layout.itemAt(0) and self.base_layout.itemAt(1): +# self.base_layout.takeAt(0).widget().deleteLater() +# is_recycled = True - tag = self.driver.lib.get_tag(tag_id=tag_id) - for idx in self.driver.selected: - entry: Entry = self.driver.frame_content[idx] +# for tag in tags_: +# tag_widget = TagWidget(tag, has_edit=True, has_remove=True) +# tag_widget.on_click.connect( +# lambda tag_id=tag.id: ( +# self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"), +# self.driver.filter_items(FilterState.from_tag_id(tag_id)), +# ) +# ) - if not self.driver.lib.add_field_tag(entry, tag, self.field.type_key): - # TODO - add some visible error - self.error_occurred.emit(Exception("Failed to add tag")) - self.updated.emit() +# tag_widget.on_remove.connect( +# lambda tag_id=tag.id: ( +# self.remove_tag(tag_id), +# self.driver.preview_panel.update_widgets(), +# ) +# ) +# tag_widget.on_edit.connect(lambda t=tag: self.edit_tag(t)) +# self.base_layout.addWidget(tag_widget) - if tag_id in (TAG_FAVORITE, TAG_ARCHIVED): - self.driver.update_badges() +# # Move or add the '+' button. +# if is_recycled: +# self.base_layout.addWidget(self.base_layout.takeAt(0).widget()) +# else: +# self.base_layout.addWidget(self.add_button) - def edit_tag_callback(self, tag: Tag): - self.driver.lib.update_tag(tag) +# # Handles an edge case where there are no more tags and the '+' button +# # doesn't move all the way to the left. +# if self.base_layout.itemAt(0) and not self.base_layout.itemAt(1): +# self.base_layout.update() - def remove_tag(self, tag_id: int): - logger.info( - "remove_tag", - selected=self.driver.selected, - field_type=self.field.type, - ) - for grid_idx in self.driver.selected: - entry = self.driver.frame_content[grid_idx] - self.driver.lib.remove_field_tag(entry, tag_id, self.field.type_key) +# self.edit_modal = PanelModal( +# build_tag_panel, +# title=tag.name, # TODO - display name including subtags +# done_callback=self.driver.preview_panel.update_widgets, +# has_save=True, +# ) +# Translations.translate_with_setter(self.edit_modal.setWindowTitle, "tag.edit") +# # TODO - this was update_tag() +# self.edit_modal.saved.connect( +# lambda: self.driver.lib.update_tag( +# build_tag_panel.build_tag(), +# subtag_ids=set(build_tag_panel.subtag_ids), +# alias_names=set(build_tag_panel.alias_names), +# alias_ids=set(build_tag_panel.alias_ids), +# ) +# ) +# self.edit_modal.show() - self.updated.emit() +# def edit_tag(self, tag: Tag): +# assert isinstance(tag, Tag), f"tag is {type(tag)}" +# build_tag_panel = BuildTagPanel(self.driver.lib, tag=tag) - if tag_id in (TAG_FAVORITE, TAG_ARCHIVED): - self.driver.update_badges() + +# def add_tag_callback(self, tag_id: int): +# logger.info("add_tag_callback", tag_id=tag_id, selected=self.driver.selected) + +# tag = self.driver.lib.get_tag(tag_id=tag_id) +# for idx in self.driver.selected: +# entry: Entry = self.driver.frame_content[idx] + +# if not self.driver.lib.add_field_tag(entry, tag, self.field.type_key): +# # TODO - add some visible error +# self.error_occurred.emit(Exception("Failed to add tag")) + +# self.updated.emit() + +# if tag_id in (TAG_FAVORITE, TAG_ARCHIVED): +# self.driver.update_badges() + +# def edit_tag_callback(self, tag: Tag): +# self.driver.lib.update_tag(tag) + +# def remove_tag(self, tag_id: int): +# logger.info( +# "remove_tag", +# selected=self.driver.selected, +# field_type=self.field.type, +# ) + +# for grid_idx in self.driver.selected: +# entry = self.driver.frame_content[grid_idx] +# self.driver.lib.remove_field_tag(entry, tag_id, self.field.type_key) + +# self.updated.emit() + +# if tag_id in (TAG_FAVORITE, TAG_ARCHIVED): +# self.driver.update_badges() diff --git a/tagstudio/tests/conftest.py b/tagstudio/tests/conftest.py index 7a59d05e..72924e20 100644 --- a/tagstudio/tests/conftest.py +++ b/tagstudio/tests/conftest.py @@ -12,7 +12,6 @@ sys.path.insert(0, str(CWD.parent)) from src.core.library import Entry, Library, Tag from src.core.library import alchemy as backend from src.core.library.alchemy.enums import TagColor -from src.core.library.alchemy.fields import TagBoxField, _FieldID from src.qt.ts_qt import QtDriver @@ -90,26 +89,26 @@ def library(request): fields=lib.default_fields, ) - entry.tag_box_fields = [ - TagBoxField(type_key=_FieldID.TAGS.name, tags={tag}, position=0), - TagBoxField( - type_key=_FieldID.TAGS_META.name, - position=0, - ), - ] + # entry.tag_box_fields = [ + # TagBoxField(type_key=_FieldID.TAGS.name, tags={tag}, position=0), + # TagBoxField( + # type_key=_FieldID.TAGS_META.name, + # position=0, + # ), + # ] entry2 = Entry( folder=lib.folder, path=pathlib.Path("one/two/bar.md"), fields=lib.default_fields, ) - entry2.tag_box_fields = [ - TagBoxField( - tags={tag2}, - type_key=_FieldID.TAGS_META.name, - position=0, - ), - ] + # entry2.tag_box_fields = [ + # TagBoxField( + # tags={tag2}, + # type_key=_FieldID.TAGS_META.name, + # position=0, + # ), + # ] assert lib.add_entries([entry, entry2]) assert len(lib.tags) == 5 diff --git a/tagstudio/tests/qt/test_tag_widget.py b/tagstudio/tests/qt/test_tag_widget.py index 86158bf4..2e402f50 100644 --- a/tagstudio/tests/qt/test_tag_widget.py +++ b/tagstudio/tests/qt/test_tag_widget.py @@ -1,110 +1,110 @@ -from unittest.mock import patch +# from unittest.mock import patch -from src.core.library.alchemy.fields import _FieldID -from src.qt.modals.build_tag import BuildTagPanel -from src.qt.widgets.tag import TagWidget -from src.qt.widgets.tag_box import TagBoxWidget +# from src.core.library.alchemy.fields import _FieldID +# from src.qt.modals.build_tag import BuildTagPanel +# from src.qt.widgets.tag import TagWidget +# from src.qt.widgets.tag_box import TagBoxWidget -def test_tag_widget(qtbot, library, qt_driver): - # given - entry = next(library.get_entries(with_joins=True)) - field = entry.tag_box_fields[0] +# def test_tag_widget(qtbot, library, qt_driver): +# # given +# entry = next(library.get_entries(with_joins=True)) +# field = entry.tag_box_fields[0] - tag_widget = TagBoxWidget(field, "title", qt_driver) +# tag_widget = TagBoxWidget(field, "title", qt_driver) - qtbot.add_widget(tag_widget) +# qtbot.add_widget(tag_widget) - assert not tag_widget.add_modal.isVisible() +# assert not tag_widget.add_modal.isVisible() - # when/then check no exception is raised - tag_widget.add_button.clicked.emit() - # check `tag_widget.add_modal` is visible - assert tag_widget.add_modal.isVisible() +# # when/then check no exception is raised +# tag_widget.add_button.clicked.emit() +# # check `tag_widget.add_modal` is visible +# assert tag_widget.add_modal.isVisible() -def test_tag_widget_add_existing_raises(library, qt_driver, entry_full): - # Given - tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] - assert len(entry_full.tags) == 1 - tag = next(iter(entry_full.tags)) +# def test_tag_widget_add_existing_raises(library, qt_driver, entry_full): +# # Given +# tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] +# assert len(entry_full.tags) == 1 +# tag = next(iter(entry_full.tags)) - # When - tag_widget = TagBoxWidget(tag_field, "title", qt_driver) - tag_widget.driver.frame_content = [entry_full] - tag_widget.driver.selected = [0] +# # When +# tag_widget = TagBoxWidget(tag_field, "title", qt_driver) +# tag_widget.driver.frame_content = [entry_full] +# tag_widget.driver.selected = [0] - # Then - with patch.object(tag_widget, "error_occurred") as mocked: - tag_widget.add_modal.widget.tag_chosen.emit(tag.id) - assert mocked.emit.called +# # Then +# with patch.object(tag_widget, "error_occurred") as mocked: +# tag_widget.add_modal.widget.tag_chosen.emit(tag.id) +# assert mocked.emit.called -def test_tag_widget_add_new_pass(qtbot, library, qt_driver, generate_tag): - # Given - entry = next(library.get_entries(with_joins=True)) - field = entry.tag_box_fields[0] +# def test_tag_widget_add_new_pass(qtbot, library, qt_driver, generate_tag): +# # Given +# entry = next(library.get_entries(with_joins=True)) +# field = entry.tag_box_fields[0] - tag = generate_tag(name="new_tag") - library.add_tag(tag) +# tag = generate_tag(name="new_tag") +# library.add_tag(tag) - tag_widget = TagBoxWidget(field, "title", qt_driver) +# tag_widget = TagBoxWidget(field, "title", qt_driver) - qtbot.add_widget(tag_widget) +# qtbot.add_widget(tag_widget) - tag_widget.driver.selected = [0] - with patch.object(tag_widget, "error_occurred") as mocked: - # When - tag_widget.add_modal.widget.tag_chosen.emit(tag.id) +# tag_widget.driver.selected = [0] +# with patch.object(tag_widget, "error_occurred") as mocked: +# # When +# tag_widget.add_modal.widget.tag_chosen.emit(tag.id) - # Then - assert not mocked.emit.called +# # Then +# assert not mocked.emit.called -def test_tag_widget_remove(qtbot, qt_driver, library, entry_full): - tag = list(entry_full.tags)[0] - assert tag +# def test_tag_widget_remove(qtbot, qt_driver, library, entry_full): +# tag = list(entry_full.tags)[0] +# assert tag - assert entry_full.tag_box_fields - tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] +# assert entry_full.tag_box_fields +# tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] - tag_widget = TagBoxWidget(tag_field, "title", qt_driver) - tag_widget.driver.selected = [0] +# tag_widget = TagBoxWidget(tag_field, "title", qt_driver) +# tag_widget.driver.selected = [0] - qtbot.add_widget(tag_widget) +# qtbot.add_widget(tag_widget) - tag_widget = tag_widget.base_layout.itemAt(0).widget() - assert isinstance(tag_widget, TagWidget) +# tag_widget = tag_widget.base_layout.itemAt(0).widget() +# assert isinstance(tag_widget, TagWidget) - tag_widget.remove_button.clicked.emit() +# tag_widget.remove_button.clicked.emit() - entry = next(qt_driver.lib.get_entries(with_joins=True)) - assert not entry.tag_box_fields[0].tags +# entry = next(qt_driver.lib.get_entries(with_joins=True)) +# assert not entry.tag_box_fields[0].tags -def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): - # Given - entry = next(library.get_entries(with_joins=True)) - library.add_tag(list(entry.tags)[0]) - tag = library.get_tag(list(entry.tags)[0].id) - assert tag +# def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): +# # Given +# entry = next(library.get_entries(with_joins=True)) +# library.add_tag(list(entry.tags)[0]) +# tag = library.get_tag(list(entry.tags)[0].id) +# assert tag - assert entry_full.tag_box_fields - tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] +# assert entry_full.tag_box_fields +# tag_field = [f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name][0] - tag_box_widget = TagBoxWidget(tag_field, "title", qt_driver) - tag_box_widget.driver.selected = [0] +# tag_box_widget = TagBoxWidget(tag_field, "title", qt_driver) +# tag_box_widget.driver.selected = [0] - qtbot.add_widget(tag_box_widget) +# qtbot.add_widget(tag_box_widget) - tag_widget = tag_box_widget.base_layout.itemAt(0).widget() - assert isinstance(tag_widget, TagWidget) +# tag_widget = tag_box_widget.base_layout.itemAt(0).widget() +# assert isinstance(tag_widget, TagWidget) - # When - tag_box_widget.edit_tag(tag) +# # When +# tag_box_widget.edit_tag(tag) - # Then - panel = tag_box_widget.edit_modal.widget - assert isinstance(panel, BuildTagPanel) - assert panel.tag.name == tag.name - assert panel.name_field.text() == tag.name +# # Then +# panel = tag_box_widget.edit_modal.widget +# assert isinstance(panel, BuildTagPanel) +# assert panel.tag.name == tag.name +# assert panel.name_field.text() == tag.name diff --git a/tagstudio/tests/test_json_migration.py b/tagstudio/tests/test_json_migration.py index c8ad58e6..62078a07 100644 --- a/tagstudio/tests/test_json_migration.py +++ b/tagstudio/tests/test_json_migration.py @@ -6,6 +6,7 @@ import pathlib from time import time from src.core.enums import LibraryPrefs +from src.core.library.alchemy.library import DEFAULT_TAG_DIFF from src.qt.widgets.migration_modal import JsonMigrationModal CWD = pathlib.Path(__file__) @@ -29,7 +30,7 @@ def test_json_migration(): # Tags ===================================================================== # Count - assert len(modal.json_lib.tags) == len(modal.sql_lib.tags) + assert len(modal.json_lib.tags) == (len(modal.sql_lib.tags) - DEFAULT_TAG_DIFF) # Shorthand Parity assert modal.check_shorthand_parity() # Subtag/Parent Tag Parity