fix: implement additional feedback

This commit is contained in:
Travis Abendshien
2026-05-09 20:09:24 -07:00
parent be08cabc4f
commit bf300d80ba
6 changed files with 61 additions and 65 deletions

View File

@@ -35,6 +35,13 @@ class BaseField(Base):
def entry(self) -> Mapped[Entry]:
return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType]
@property
def class_name(self) -> str:
return self.__class__.__name__
def clone_with_entry_id(self, entry_id: int) -> BaseField: # pyright: ignore
raise NotImplementedError()
value: Any # pyright: ignore
@@ -59,6 +66,12 @@ class TextField(BaseField):
def __hash__(self) -> int:
return hash((self.name, self.value, self.is_multiline))
@override
def clone_with_entry_id(self, entry_id: int) -> TextField:
return TextField(
name=self.name, entry_id=entry_id, value=self.value, is_multiline=self.is_multiline
)
class DatetimeField(BaseField):
__tablename__ = "datetime_fields"
@@ -76,6 +89,10 @@ class DatetimeField(BaseField):
def __hash__(self) -> int:
return hash((self.name, self.value))
@override
def clone_with_entry_id(self, entry_id: int) -> DatetimeField:
return DatetimeField(name=self.name, entry_id=entry_id, value=self.value)
class BaseFieldTemplate(Base):
__abstract__ = True
@@ -88,15 +105,30 @@ class BaseFieldTemplate(Base):
def name(self) -> Mapped[str]:
return mapped_column(nullable=False, default="")
@property
def class_name(self) -> str:
return self.__class__.__name__
def to_field(self, value: Any | None = None) -> BaseField: # pyright: ignore
raise NotImplementedError()
class TextFieldTemplate(BaseFieldTemplate):
__tablename__ = "text_field_templates"
is_multiline: Mapped[bool] = mapped_column(nullable=False, default=False)
@override
def to_field(self, value: str | None = None) -> TextField:
return TextField(name=self.name, value=value, is_multiline=self.is_multiline)
class DatetimeFieldTemplate(BaseFieldTemplate):
__tablename__ = "datetime_field_templates"
@override
def to_field(self, value: str | None = None) -> DatetimeField:
return DatetimeField(name=self.name, value=value)
# Used for migrating legacy libraries.
# Legacy JSON libraries (<v9.4) use an integer ID.

View File

@@ -335,13 +335,15 @@ class Library:
value=value,
is_multiline=bool(field_info["is_multiline"]),
)
self.add_field_to_entry(entry_id=(entry.id + 1), field=text_field)
self.add_field_to_entries(
entry_ids=(entry.id + 1), field=text_field
)
elif field_info["type"] == DatetimeField:
datetime_field = DatetimeField(
name=str(field_info["name"]), value=value
)
self.add_field_to_entry(
entry_id=(entry.id + 1), field=datetime_field
self.add_field_to_entries(
entry_ids=(entry.id + 1), field=datetime_field
)
except Exception as e:
logger.error(
@@ -1312,11 +1314,7 @@ class Library:
with Session(self.engine) as session:
update_stmt = (
update(field_type)
.where(
and_(
field_type.id == field.id,
)
)
.where(and_(field_type.id == field.id, field_type.entry_id.in_(entry_ids)))
.values(value=value, is_multiline=is_multiline)
)
@@ -1338,57 +1336,30 @@ class Library:
with Session(self.engine) as session:
update_stmt = (
update(field_type)
.where(
and_(
field_type.id == field.id,
)
)
.where(and_(field_type.id == field.id, field_type.entry_id.in_(entry_ids)))
.values(value=value)
)
session.execute(update_stmt)
session.commit()
def add_field_to_entry(self, entry_id: int, field: BaseField) -> bool:
def add_field_to_entries(self, entry_ids: list[int] | int, field: BaseField) -> bool:
"""Add a field object to an Entry."""
if type(field) is TextField:
logger.info(
"[Library] Adding TextField to entry",
entry_id=entry_id,
name=field.name,
value=field.value,
is_multiline=field.is_multiline,
)
if isinstance(entry_ids, int):
entry_ids = [entry_ids]
field = TextField(
entry_id=entry_id,
name=field.name,
value=field.value,
is_multiline=field.is_multiline,
)
logger.info(
"[Library] Adding field to entry",
type=field.class_name,
entry_ids=entry_ids,
name=field.name,
value=field.value,
)
with Session(self.engine) as session:
with Session(self.engine) as session:
for entry_id in entry_ids:
try:
session.add(field)
session.commit()
except IntegrityError as e:
logger.error(e)
session.rollback()
return False
elif type(field) is DatetimeField:
logger.info(
"[Library] Adding DatetimeField to entry",
entry_id=entry_id,
name=field.name,
value=field.value,
)
field = DatetimeField(entry_id=entry_id, name=field.name, value=field.value)
with Session(self.engine) as session:
try:
session.add(field)
session.add(field.clone_with_entry_id(entry_id))
session.commit()
except IntegrityError as e:
logger.error(e)
@@ -1951,7 +1922,7 @@ class Library:
for entry in entries:
for field in all_fields:
if field not in entry.fields:
self.add_field_to_entry(entry_id=entry.id, field=field)
self.add_field_to_entries(entry_ids=entry.id, field=field)
def merge_entries(self, from_entry: Entry, into_entry: Entry) -> bool:
"""Add fields and tags from the first entry to the second, and then delete the first."""

View File

@@ -78,7 +78,7 @@ class AddFieldModal(QWidget):
self.list_widget.clear()
for field_template in self.lib.field_templates:
field_name_key: str = FIELD_TYPE_KEYS.get(
field_template.__class__.__name__, "field_type.unknown"
field_template.class_name, "field_type.unknown"
)
item = QListWidgetItem(f"{field_template.name} ({Translations[field_name_key]})")
item.setData(Qt.ItemDataRole.UserRole, field_template)

View File

@@ -28,9 +28,7 @@ from tagstudio.core.library.alchemy.fields import (
BaseField,
BaseFieldTemplate,
DatetimeField,
DatetimeFieldTemplate,
TextField,
TextFieldTemplate,
)
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Entry, Tag
@@ -224,14 +222,9 @@ class FieldContainers(QWidget):
logger.info(
"[FieldContainers][add_field_to_selected] Adding field",
name=template.name,
type=template.__class__.__name__,
type=template.class_name,
)
if type(template) is TextFieldTemplate:
text_field = TextField(name=template.name, is_multiline=template.is_multiline)
self.lib.add_field_to_entry(entry_id, text_field)
elif type(template) is DatetimeFieldTemplate:
datetime_field = DatetimeField(name=template.name)
self.lib.add_field_to_entry(entry_id, datetime_field)
self.lib.add_field_to_entries(entry_id, template.to_field())
def add_tags_to_selected(self, tags: int | list[int]):
"""Add list of tags to one or more selected items.
@@ -265,7 +258,7 @@ class FieldContainers(QWidget):
"[FieldContainers][write_container]",
index=index,
name=field.name,
type=field.__class__.__name__,
type=field.class_name,
)
if len(self.containers) < (index + 1):
container = FieldContainer()
@@ -275,7 +268,7 @@ class FieldContainers(QWidget):
container = self.containers[index]
# Set field title
field_name_key: str = FIELD_TYPE_KEYS.get(field.__class__.__name__, "field_type.unknown")
field_name_key: str = FIELD_TYPE_KEYS.get(field.class_name, "field_type.unknown")
title = f"{field.name} ({Translations[field_name_key]})"
# Single-line Text

View File

@@ -1228,7 +1228,7 @@ class QtDriver(DriverMixin, QObject):
if field.type_key == e.type_key and field.value == e.value:
exists = True
if not exists:
self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value)
self.lib.add_field_to_entries(id, field_id=field.type_key, value=field.value)
self.lib.add_tags_to_entries(id, self.copy_buffer["tags"])
if len(self.selected) > 1:
if TAG_ARCHIVED in self.copy_buffer["tags"]:

View File

@@ -213,7 +213,7 @@ def test_remove_text_field_entry_with_multiple_fields(library: Library, entry_fu
# When
# add identical field
assert library.add_field_to_entry(entry_full.id, field=title_field)
assert library.add_field_to_entries(entry_full.id, field=title_field)
# remove entry field
library.remove_entry_field(title_field, [entry_full.id])
@@ -239,7 +239,7 @@ def test_update_entry_with_multiple_identical_text_fields(library: Library, entr
# When
# add identical field
empty_title = TextField(name="Title", value="")
library.add_field_to_entry(entry_full.id, field=empty_title)
library.add_field_to_entries(entry_full.id, field=empty_title)
# update one of the fields
library.update_text_field(entry_full.id, title_field, "new value", title_field.is_multiline)