From be08cabc4f569ae5ee0ae512fb06516e3925e8c5 Mon Sep 17 00:00:00 2001 From: Travis Abendshien <46939827+CyanVoxel@users.noreply.github.com> Date: Sat, 9 May 2026 10:51:06 -0700 Subject: [PATCH] fix: implement review feedback and misc fixes --- src/tagstudio/core/library/alchemy/fields.py | 28 +- src/tagstudio/core/library/alchemy/library.py | 260 +++++++----------- src/tagstudio/qt/mixed/field_containers.py | 23 +- tests/macros/test_dupe_files.py | 8 +- tests/test_library.py | 22 +- 5 files changed, 153 insertions(+), 188 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/fields.py b/src/tagstudio/core/library/alchemy/fields.py index e674b955..4331f578 100644 --- a/src/tagstudio/core/library/alchemy/fields.py +++ b/src/tagstudio/core/library/alchemy/fields.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship @@ -44,12 +44,38 @@ class TextField(BaseField): value: Mapped[str | None] is_multiline: Mapped[bool] = mapped_column(nullable=False, default=False) + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextField): + return False + + return (self.name, self.value, self.is_multiline) == ( + other.name, + other.value, + other.is_multiline, + ) + + @override + def __hash__(self) -> int: + return hash((self.name, self.value, self.is_multiline)) + class DatetimeField(BaseField): __tablename__ = "datetime_fields" value: Mapped[str | None] + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, DatetimeField): + return False + + return (self.name, self.value) == (other.name, other.value) + + @override + def __hash__(self) -> int: + return hash((self.name, self.value)) + class BaseFieldTemplate(Base): __abstract__ = True diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index a87a1d78..779a9c40 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from uuid import uuid4 from warnings import catch_warnings @@ -327,24 +327,21 @@ class Library: self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=value) else: try: - if LEGACY_FIELD_MAP[legacy_field_id]["type"] == TextField: - self.add_text_field_to_entry( - entry_id=( - entry.id + 1 - ), # NOTE: JSON IDs start at 0 instead of 1 - name=str(LEGACY_FIELD_MAP[legacy_field_id]["name"]), + # NOTE: JSON IDs start at 0 instead of 1 + field_info = LEGACY_FIELD_MAP[legacy_field_id] + if field_info["type"] == TextField: + text_field = TextField( + name=str(field_info["name"]), value=value, - is_multiline=bool( - LEGACY_FIELD_MAP[legacy_field_id]["is_multiline"] - ), + is_multiline=bool(field_info["is_multiline"]), ) - elif LEGACY_FIELD_MAP[legacy_field_id]["type"] == DatetimeField: - self.add_datetime_field_to_entry( - entry_id=( - entry.id + 1 - ), # NOTE: JSON IDs start at 0 instead of 1 - name=str(LEGACY_FIELD_MAP[legacy_field_id]["name"]), - value=value, + self.add_field_to_entry(entry_id=(entry.id + 1), field=text_field) + elif field_info["type"] == DatetimeField: + datetime_field = DatetimeField( + name=str(field_info["name"]), value=value + ) + self.add_field_to_entry( + entry_id=(entry.id + 1), field=datetime_field ) except Exception as e: logger.error( @@ -490,18 +487,14 @@ class Library: # Add default field templates if is_new: - for ft in get_default_field_templates(): + for template in get_default_field_templates(): try: - if type(ft) is TextFieldTemplate: - session.add( - TextFieldTemplate(name=ft.name, is_multiline=ft.is_multiline) - ) - elif type(ft) is DatetimeFieldTemplate: - session.add(DatetimeFieldTemplate(name=ft.name)) - + session.add(template) session.commit() except IntegrityError: - logger.info("[Library] FieldTemplate already exists", field_template=ft) + logger.info( + "[Library] FieldTemplate already exists", field_template=template + ) session.rollback() # Ensure version rows are present @@ -575,7 +568,6 @@ class Library: self.__apply_db104_migrations(session, library_dir) if loaded_db_version < 200: self.__apply_db200_migrations(session) - self.__apply_db200_data_repairs(session) # Update DB_VERSION if loaded_db_version < DB_VERSION: @@ -767,55 +759,54 @@ class Library: """Migrate DB to DB_VERSION 200.""" with session: # Drop unused 'boolean_fields' and 'value_type' tables + logger.info( + "[Library][Migration][200] Dropping boolean_fields and value_type tables..." + ) session.execute(text("DROP TABLE boolean_fields")) session.execute(text("DROP TABLE value_type")) - session.commit() - logger.info("[Library][Migration][200] Dropped boolean_fields and value_type tables") # Add 'name' column to text_fields and datetime_fields tables - stmt = text('ALTER TABLE text_fields ADD COLUMN name VARCHAR NOT NULL DEFAULT ""') + logger.info("[Library][Migration][200] Adding name columns to field tables...") + stmt = text('ALTER TABLE text_fields ADD COLUMN name VARCHAR DEFAULT ""') session.execute(stmt) - stmt = text('ALTER TABLE datetime_fields ADD COLUMN name VARCHAR NOT NULL DEFAULT ""') + stmt = text('ALTER TABLE datetime_fields ADD COLUMN name VARCHAR DEFAULT ""') session.execute(stmt) - session.commit() - logger.info("[Library][Migration][200] Added name columns to field tables") # Drop unnecessary 'position' columns + logger.info("[Library][Migration][200] Dropping position columns to field tables...") session.execute(text("ALTER TABLE datetime_fields DROP COLUMN position")) session.execute(text("ALTER TABLE text_fields DROP COLUMN position")) - session.commit() - logger.info("[Library][Migration][200] Dropped position columns to field tables") # Add 'is_multiline' column to text_fields table + logger.info("[Library][Migration][200] Adding is_multiline column to text_fields...") stmt = text( "ALTER TABLE text_fields ADD COLUMN is_multiline BOOLEAN NOT NULL DEFAULT 0" ) session.execute(stmt) - session.commit() - logger.info("[Library][Migration][200] Added is_multiline column to text_fields table") + session.flush() # Move values from old `type_key` columns into new `name` columns + logger.info("[Library][Migration][200] Moving values from type_key columns to name...") session.execute(text("UPDATE text_fields SET name = type_key")) session.execute(text("UPDATE datetime_fields SET name = type_key")) - session.commit() - logger.info("[Library][Migration][200] Moved values from type_key columns to name") + session.flush() # TODO: Remove `type_key` columns from text_fields and datetime_fields tables. # See issue with dropping columns foreign keys in SQLite: # https://www.sqlite.org/lang_altertable.html#making_other_kinds_of_table_schema_changes # Change `name` values to title case + logger.info("[Library][Migration][200] Normalizing TextField names...") for text_field in session.execute(select(TextField)).scalars(): # NOTE: The only exception to the "Title Case" conversion is the "URL" field. text_field.name = text_field.name.title().replace("Url", "URL").replace("_", " ") - logger.info("[Library][Migration][200] Normalized TextField names") - session.commit() + logger.info("[Library][Migration][200] Normalizing DatetimeField names...") for datetime_field in session.execute(select(DatetimeField)).scalars(): datetime_field.name = datetime_field.name.title().replace("_", " ") - logger.info("[Library][Migration][200] Normalized DatetimeField names") - session.commit() + session.flush() # Add correct `is_multiline` values to text_fields table + logger.info("[Library][Migration][200] Updating is_multiline for legacy TEXT_BOXes...") text_boxes = [ x.get("name") for x in LEGACY_FIELD_MAP.values() if x.get("is_multiline") is True ] @@ -823,17 +814,10 @@ class Library: update(TextField).where(TextField.name.in_(text_boxes)).values(is_multiline=True) ) session.execute(update_stmt) - logger.info( - "[Library][Migration][200] Updated is_multiline columns for legacy TEXT_BOX fields" - ) - session.commit() + session.flush() - pass - - def __apply_db200_data_repairs(self, session: Session): - logger.info("[Library][Migration] Repairing data for library below version 200...") - with session: # Repair legacy "Description" fields to use is_multiline = True + logger.info("[Library][Migration][200] Repairing legacy Description fields...") desc_stmt = ( update(TextField) .where(TextField.name == "Description" and TextField.is_multiline == False) # noqa: E712 @@ -842,27 +826,26 @@ class Library: session.execute(desc_stmt) # Repair legacy "Comments" fields to use is_multiline = True + logger.info("[Library][Migration][200] Repairing legacy Comment fields...") comm_stmt = ( update(TextField) .where(TextField.name == "Comments" and TextField.is_multiline == False) # noqa: E712 .values(is_multiline=True) ) session.execute(comm_stmt) - session.commit() # Add default field templates - for ft in get_default_field_templates(): + logger.info("[Library][Migration][200] Adding default field templates...") + for template in get_default_field_templates(): try: - if type(ft) is TextFieldTemplate: - session.add(TextFieldTemplate(name=ft.name, is_multiline=ft.is_multiline)) - elif type(ft) is DatetimeFieldTemplate: - session.add(DatetimeFieldTemplate(name=ft.name)) - - session.commit() + session.add(template) + session.flush() except IntegrityError: - logger.info("[Library] FieldTemplate already exists", field_template=ft) + logger.error("[Library] FieldTemplate already exists", field_template=template) session.rollback() + session.commit() + @property def field_templates(self) -> Sequence[BaseFieldTemplate]: with Session(self.engine) as session: @@ -1297,20 +1280,20 @@ class Library: field: BaseField, entry_ids: list[int], ) -> None: - field_ = type(field) + field_type = type(field) logger.info( "remove_entry_field", field=field, - type=field_, + type=field_type, entry_ids=entry_ids, ) with Session(self.engine) as session: # remove all fields matching entry and field_type - delete_stmt = delete(field_).where( + delete_stmt = delete(field_type).where( and_( - field_.id == field.id, + field_type.id == field.id, ) ) @@ -1324,14 +1307,14 @@ class Library: if isinstance(entry_ids, int): entry_ids = [entry_ids] - field_ = type(field) + field_type = type(field) with Session(self.engine) as session: update_stmt = ( - update(field_) + update(field_type) .where( and_( - field_.id == field.id, + field_type.id == field.id, ) ) .values(value=value, is_multiline=is_multiline) @@ -1350,14 +1333,14 @@ class Library: if isinstance(entry_ids, int): entry_ids = [entry_ids] - field_ = type(field) + field_type = type(field) with Session(self.engine) as session: update_stmt = ( - update(field_) + update(field_type) .where( and_( - field_.id == field.id, + field_type.id == field.id, ) ) .values(value=value) @@ -1366,61 +1349,51 @@ class Library: session.execute(update_stmt) session.commit() - def add_text_field_to_entry( - self, entry_id: int, name: str, value: str | None = None, is_multiline: bool = False - ) -> bool: - """Add a TextField field to an Entry.""" - logger.info( - "[Library] Adding text field to entry", - entry_id=entry_id, - name=name, - value=value, - is_multiline=is_multiline, - ) + def add_field_to_entry(self, entry_id: int, field: BaseField) -> bool: + """Add a field object to an Entry.""" + if type(field) is TextField: + logger.info( + "[Library] Adding TextField to entry", + entry_id=entry_id, + name=field.name, + value=field.value, + is_multiline=field.is_multiline, + ) - field = TextField(entry_id=entry_id, name=name, value=value, is_multiline=is_multiline) + field = TextField( + entry_id=entry_id, + name=field.name, + value=field.value, + is_multiline=field.is_multiline, + ) - with Session(self.engine) as session: - try: - session.add(field) - session.flush() - session.commit() - except IntegrityError as e: - logger.error(e) - session.rollback() - return False + with Session(self.engine) as session: + try: + session.add(field) + session.commit() + except IntegrityError as e: + logger.error(e) + session.rollback() + return False - return True + elif type(field) is DatetimeField: + logger.info( + "[Library] Adding DatetimeField to entry", + entry_id=entry_id, + name=field.name, + value=field.value, + ) - def add_datetime_field_to_entry( - self, - entry_id: int, - name: str, - value: str | None = None, - ) -> bool: - """Add a DatetimeField field to an Entry.""" - logger.info( - "[Library] Adding datetime field to entry", - entry_id=entry_id, - name=name, - value=value, - ) + field = DatetimeField(entry_id=entry_id, name=field.name, value=field.value) - field = DatetimeField( - entry_id=entry_id, - name=name, - value=value, - ) - - with Session(self.engine) as session: - try: - session.add(field) - session.flush() - session.commit() - except IntegrityError as e: - logger.error(e) - session.rollback() - return False + with Session(self.engine) as session: + try: + session.add(field) + session.commit() + except IntegrityError as e: + logger.error(e) + session.rollback() + return False return True @@ -1963,55 +1936,22 @@ class Library: def mirror_entry_fields(self, entries: list[Entry]) -> None: """Mirror fields among multiple Entry items.""" - all_tuples_to_fields_map = {} + all_fields: set[BaseField] = set() + logger.info("[Library][mirror_fields]", all_fields=all_fields) # Track all fields across all entries for entry in entries: for field in entry.fields: - field_tuple: tuple | None = None - if type(field) is TextField: - field_tuple = (type(field), field.name, field.value, field.is_multiline) - elif type(field) is DatetimeField: - field_tuple = (type(field), field.name, field.value) - all_tuples_to_fields_map[field_tuple] = field + all_fields.add(field) logger.info( "[Library][mirror_fields]", entry_id=entry.id, field_count_before=len(entry.fields) ) # Apply all (remaining) fields to all entries, avoiding duplicates for entry in entries: - for field_tuple, field in all_tuples_to_fields_map.items(): # pyright: ignore[reportUnknownVariableType] - entry_field_tuples: set[tuple[Any, ...]] = set() # pyright: ignore[reportExplicitAny] - # Locally process the entry's fields into parsable tuples - for entry_field in entry.fields: - entry_field_tuple: tuple | None = None - if type(entry_field) is TextField: - entry_field_tuple = ( - type(entry_field), - entry_field.name, - entry_field.value, - entry_field.is_multiline, - ) - entry_field_tuples.add(entry_field_tuple) - elif type(entry_field) is DatetimeField: - entry_field_tuple = (type(entry_field), entry_field.name, entry_field.value) - entry_field_tuples.add(entry_field_tuple) - - if field_tuple not in entry_field_tuples: - if type(field) is TextField: - self.add_text_field_to_entry( - entry_id=entry.id, - name=field.name, - value=field.value, - is_multiline=field.is_multiline, - ) - elif type(field) is DatetimeField: - self.add_datetime_field_to_entry( - entry_id=entry.id, name=field.name, value=field.value - ) - logger.info( - "[Library][mirror_fields]", entry_id=entry.id, field_count_after=len(entry.fields) - ) + for field in all_fields: + if field not in entry.fields: + self.add_field_to_entry(entry_id=entry.id, field=field) def merge_entries(self, from_entry: Entry, into_entry: Entry) -> bool: """Add fields and tags from the first entry to the second, and then delete the first.""" diff --git a/src/tagstudio/qt/mixed/field_containers.py b/src/tagstudio/qt/mixed/field_containers.py index 594738f7..2a7da707 100644 --- a/src/tagstudio/qt/mixed/field_containers.py +++ b/src/tagstudio/qt/mixed/field_containers.py @@ -220,23 +220,18 @@ class FieldContainers(QWidget): ) for entry_id in self.driver.selected: for field in field_list: - field_: BaseFieldTemplate = field.data(Qt.ItemDataRole.UserRole) + template: BaseFieldTemplate = field.data(Qt.ItemDataRole.UserRole) logger.info( "[FieldContainers][add_field_to_selected] Adding field", - name=field_.name, - type=field_.__class__.__name__, + name=template.name, + type=template.__class__.__name__, ) - if type(field_) is TextFieldTemplate: - self.lib.add_text_field_to_entry( - entry_id=entry_id, - name=field_.name, - is_multiline=field_.is_multiline, - ) - elif type(field_) is DatetimeFieldTemplate: - self.lib.add_datetime_field_to_entry( - entry_id=entry_id, - name=field_.name, - ) + if type(template) is TextFieldTemplate: + text_field = TextField(name=template.name, is_multiline=template.is_multiline) + self.lib.add_field_to_entry(entry_id, text_field) + elif type(template) is DatetimeFieldTemplate: + datetime_field = DatetimeField(name=template.name) + self.lib.add_field_to_entry(entry_id, datetime_field) def add_tags_to_selected(self, tags: int | list[int]): """Add list of tags to one or more selected items. diff --git a/tests/macros/test_dupe_files.py b/tests/macros/test_dupe_files.py index dac71127..05053555 100644 --- a/tests/macros/test_dupe_files.py +++ b/tests/macros/test_dupe_files.py @@ -4,7 +4,7 @@ from pathlib import Path -from tagstudio.core.library.alchemy.fields import TextField +from tagstudio.core.library.alchemy.fields import BaseField, TextField from tagstudio.core.library.alchemy.library import Library from tagstudio.core.library.alchemy.models import Entry from tagstudio.core.library.alchemy.registries.dupe_files_registry import DupeFilesRegistry @@ -17,16 +17,18 @@ def test_refresh_dupe_files(library: Library): library.library_dir = Path("/tmp/") folder = unwrap(library.folder) + fields: list[BaseField] = [TextField(name="Title", value="I'm a Test Title")] + entry = Entry( folder=folder, path=Path("bar/foo.txt"), - fields=[TextField(name="Title", value="I'm a Test Title")], + fields=fields, ) entry2 = Entry( folder=folder, path=Path("foo/foo.txt"), - fields=[TextField(name="Title", value="I'm a Test Title")], + fields=fields, ) library.add_entries([entry, entry2]) diff --git a/tests/test_library.py b/tests/test_library.py index f2447391..c8e7524c 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -207,18 +207,13 @@ def test_remove_entry_field(library: Library, entry_full: Entry): assert not entry.text_fields -def test_remove_text_field_entry_with_multiple_field(library: Library, entry_full: Entry): +def test_remove_text_field_entry_with_multiple_fields(library: Library, entry_full: Entry): # Given title_field = entry_full.text_fields[0] # When # add identical field - assert library.add_text_field_to_entry( - entry_full.id, - name=title_field.name, - value=title_field.value, - is_multiline=title_field.is_multiline, - ) + assert library.add_field_to_entry(entry_full.id, field=title_field) # remove entry field library.remove_entry_field(title_field, [entry_full.id]) @@ -243,7 +238,8 @@ def test_update_entry_with_multiple_identical_text_fields(library: Library, entr # When # add identical field - library.add_text_field_to_entry(entry_full.id, name="Title", value="") + empty_title = TextField(name="Title", value="") + library.add_field_to_entry(entry_full.id, field=empty_title) # update one of the fields library.update_text_field(entry_full.id, title_field, "new value", title_field.is_multiline) @@ -268,7 +264,8 @@ def test_mirror_entry_fields(library: Library): folder=unwrap(library.folder), path=Path("notes.txt"), fields=[ - TextField(name="Notes", value="These are my notes.\nNo peeking!", is_multiline=True) + TextField(name="Notes", value="These are my notes.\nNo peeking!", is_multiline=True), + TextField(name="Title", value="I'm a Test Title"), ], ) entry_c = Entry( @@ -291,7 +288,7 @@ def test_mirror_entry_fields(library: Library): assert entry_b_.fields[0].name == "Notes" assert entry_c_.fields[0].name == "Date Published" assert len(entry_a_.fields) == 2 - assert len(entry_b_.fields) == 1 + assert len(entry_b_.fields) == 2 assert len(entry_c_.fields) == 1 # Mirror fields between entries @@ -302,6 +299,11 @@ def test_mirror_entry_fields(library: Library): entry_b_mirrored = unwrap(library.get_entry_full(entry_b_id)) entry_c_mirrored = unwrap(library.get_entry_full(entry_c_id)) + for entry in [entry_a_mirrored, entry_b_mirrored, entry_c_mirrored]: + logger.info( + "[Library][mirror_fields]", entry_id=entry.id, field_count_after=len(entry.fields) + ) + # Assert presence of all fields on all entries assert len(entry_a_mirrored.fields) == 4 assert len(entry_b_mirrored.fields) == 4