From bf300d80bac6c3aa392f39f5cd8b13b1ab1b7be3 Mon Sep 17 00:00:00 2001 From: Travis Abendshien <46939827+CyanVoxel@users.noreply.github.com> Date: Sat, 9 May 2026 20:09:24 -0700 Subject: [PATCH] fix: implement additional feedback --- src/tagstudio/core/library/alchemy/fields.py | 32 +++++++++ src/tagstudio/core/library/alchemy/library.py | 71 ++++++------------- src/tagstudio/qt/mixed/add_field.py | 2 +- src/tagstudio/qt/mixed/field_containers.py | 15 ++-- src/tagstudio/qt/ts_qt.py | 2 +- tests/test_library.py | 4 +- 6 files changed, 61 insertions(+), 65 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/fields.py b/src/tagstudio/core/library/alchemy/fields.py index 4331f578..5d4d6c96 100644 --- a/src/tagstudio/core/library/alchemy/fields.py +++ b/src/tagstudio/core/library/alchemy/fields.py @@ -35,6 +35,13 @@ class BaseField(Base): def entry(self) -> Mapped[Entry]: return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType] + @property + def class_name(self) -> str: + return self.__class__.__name__ + + def clone_with_entry_id(self, entry_id: int) -> BaseField: # pyright: ignore + raise NotImplementedError() + value: Any # pyright: ignore @@ -59,6 +66,12 @@ class TextField(BaseField): def __hash__(self) -> int: return hash((self.name, self.value, self.is_multiline)) + @override + def clone_with_entry_id(self, entry_id: int) -> TextField: + return TextField( + name=self.name, entry_id=entry_id, value=self.value, is_multiline=self.is_multiline + ) + class DatetimeField(BaseField): __tablename__ = "datetime_fields" @@ -76,6 +89,10 @@ class DatetimeField(BaseField): def __hash__(self) -> int: return hash((self.name, self.value)) + @override + def clone_with_entry_id(self, entry_id: int) -> DatetimeField: + return DatetimeField(name=self.name, entry_id=entry_id, value=self.value) + class BaseFieldTemplate(Base): __abstract__ = True @@ -88,15 +105,30 @@ class BaseFieldTemplate(Base): def name(self) -> Mapped[str]: return mapped_column(nullable=False, default="") + @property + def class_name(self) -> str: + return self.__class__.__name__ + + def to_field(self, value: Any | None = None) -> BaseField: # pyright: ignore + raise NotImplementedError() + class TextFieldTemplate(BaseFieldTemplate): __tablename__ = "text_field_templates" is_multiline: Mapped[bool] = mapped_column(nullable=False, default=False) + @override + def to_field(self, value: str | None = None) -> TextField: + return TextField(name=self.name, value=value, is_multiline=self.is_multiline) + class DatetimeFieldTemplate(BaseFieldTemplate): __tablename__ = "datetime_field_templates" + @override + def to_field(self, value: str | None = None) -> DatetimeField: + return DatetimeField(name=self.name, value=value) + # Used for migrating legacy libraries. # Legacy JSON libraries ( bool: + def add_field_to_entries(self, entry_ids: list[int] | int, field: BaseField) -> bool: """Add a field object to an Entry.""" - if type(field) is TextField: - logger.info( - "[Library] Adding TextField to entry", - entry_id=entry_id, - name=field.name, - value=field.value, - is_multiline=field.is_multiline, - ) + if isinstance(entry_ids, int): + entry_ids = [entry_ids] - field = TextField( - entry_id=entry_id, - name=field.name, - value=field.value, - is_multiline=field.is_multiline, - ) + logger.info( + "[Library] Adding field to entry", + type=field.class_name, + entry_ids=entry_ids, + name=field.name, + value=field.value, + ) - with Session(self.engine) as session: + with Session(self.engine) as session: + for entry_id in entry_ids: try: - session.add(field) - session.commit() - except IntegrityError as e: - logger.error(e) - session.rollback() - return False - - elif type(field) is DatetimeField: - logger.info( - "[Library] Adding DatetimeField to entry", - entry_id=entry_id, - name=field.name, - value=field.value, - ) - - field = DatetimeField(entry_id=entry_id, name=field.name, value=field.value) - - with Session(self.engine) as session: - try: - session.add(field) + session.add(field.clone_with_entry_id(entry_id)) session.commit() except IntegrityError as e: logger.error(e) @@ -1951,7 +1922,7 @@ class Library: for entry in entries: for field in all_fields: if field not in entry.fields: - self.add_field_to_entry(entry_id=entry.id, field=field) + self.add_field_to_entries(entry_ids=entry.id, field=field) def merge_entries(self, from_entry: Entry, into_entry: Entry) -> bool: """Add fields and tags from the first entry to the second, and then delete the first.""" diff --git a/src/tagstudio/qt/mixed/add_field.py b/src/tagstudio/qt/mixed/add_field.py index bfaac931..ba83ab58 100644 --- a/src/tagstudio/qt/mixed/add_field.py +++ b/src/tagstudio/qt/mixed/add_field.py @@ -78,7 +78,7 @@ class AddFieldModal(QWidget): self.list_widget.clear() for field_template in self.lib.field_templates: field_name_key: str = FIELD_TYPE_KEYS.get( - field_template.__class__.__name__, "field_type.unknown" + field_template.class_name, "field_type.unknown" ) item = QListWidgetItem(f"{field_template.name} ({Translations[field_name_key]})") item.setData(Qt.ItemDataRole.UserRole, field_template) diff --git a/src/tagstudio/qt/mixed/field_containers.py b/src/tagstudio/qt/mixed/field_containers.py index 2a7da707..12f82577 100644 --- a/src/tagstudio/qt/mixed/field_containers.py +++ b/src/tagstudio/qt/mixed/field_containers.py @@ -28,9 +28,7 @@ from tagstudio.core.library.alchemy.fields import ( BaseField, BaseFieldTemplate, DatetimeField, - DatetimeFieldTemplate, TextField, - TextFieldTemplate, ) from tagstudio.core.library.alchemy.library import Library from tagstudio.core.library.alchemy.models import Entry, Tag @@ -224,14 +222,9 @@ class FieldContainers(QWidget): logger.info( "[FieldContainers][add_field_to_selected] Adding field", name=template.name, - type=template.__class__.__name__, + type=template.class_name, ) - if type(template) is TextFieldTemplate: - text_field = TextField(name=template.name, is_multiline=template.is_multiline) - self.lib.add_field_to_entry(entry_id, text_field) - elif type(template) is DatetimeFieldTemplate: - datetime_field = DatetimeField(name=template.name) - self.lib.add_field_to_entry(entry_id, datetime_field) + self.lib.add_field_to_entries(entry_id, template.to_field()) def add_tags_to_selected(self, tags: int | list[int]): """Add list of tags to one or more selected items. @@ -265,7 +258,7 @@ class FieldContainers(QWidget): "[FieldContainers][write_container]", index=index, name=field.name, - type=field.__class__.__name__, + type=field.class_name, ) if len(self.containers) < (index + 1): container = FieldContainer() @@ -275,7 +268,7 @@ class FieldContainers(QWidget): container = self.containers[index] # Set field title - field_name_key: str = FIELD_TYPE_KEYS.get(field.__class__.__name__, "field_type.unknown") + field_name_key: str = FIELD_TYPE_KEYS.get(field.class_name, "field_type.unknown") title = f"{field.name} ({Translations[field_name_key]})" # Single-line Text diff --git a/src/tagstudio/qt/ts_qt.py b/src/tagstudio/qt/ts_qt.py index 4046e93f..8ba3a8ad 100644 --- a/src/tagstudio/qt/ts_qt.py +++ b/src/tagstudio/qt/ts_qt.py @@ -1228,7 +1228,7 @@ class QtDriver(DriverMixin, QObject): if field.type_key == e.type_key and field.value == e.value: exists = True if not exists: - self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value) + self.lib.add_field_to_entries(id, field_id=field.type_key, value=field.value) self.lib.add_tags_to_entries(id, self.copy_buffer["tags"]) if len(self.selected) > 1: if TAG_ARCHIVED in self.copy_buffer["tags"]: diff --git a/tests/test_library.py b/tests/test_library.py index c8e7524c..01cfb84b 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -213,7 +213,7 @@ def test_remove_text_field_entry_with_multiple_fields(library: Library, entry_fu # When # add identical field - assert library.add_field_to_entry(entry_full.id, field=title_field) + assert library.add_field_to_entries(entry_full.id, field=title_field) # remove entry field library.remove_entry_field(title_field, [entry_full.id]) @@ -239,7 +239,7 @@ def test_update_entry_with_multiple_identical_text_fields(library: Library, entr # When # add identical field empty_title = TextField(name="Title", value="") - library.add_field_to_entry(entry_full.id, field=empty_title) + library.add_field_to_entries(entry_full.id, field=empty_title) # update one of the fields library.update_text_field(entry_full.id, title_field, "new value", title_field.is_multiline)