mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-05-10 22:33:59 +00:00
fix: implement additional feedback
This commit is contained in:
@@ -35,6 +35,13 @@ class BaseField(Base):
|
||||
def entry(self) -> Mapped[Entry]:
|
||||
return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType]
|
||||
|
||||
@property
|
||||
def class_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def clone_with_entry_id(self, entry_id: int) -> BaseField: # pyright: ignore
|
||||
raise NotImplementedError()
|
||||
|
||||
value: Any # pyright: ignore
|
||||
|
||||
|
||||
@@ -59,6 +66,12 @@ class TextField(BaseField):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.name, self.value, self.is_multiline))
|
||||
|
||||
@override
|
||||
def clone_with_entry_id(self, entry_id: int) -> TextField:
|
||||
return TextField(
|
||||
name=self.name, entry_id=entry_id, value=self.value, is_multiline=self.is_multiline
|
||||
)
|
||||
|
||||
|
||||
class DatetimeField(BaseField):
|
||||
__tablename__ = "datetime_fields"
|
||||
@@ -76,6 +89,10 @@ class DatetimeField(BaseField):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.name, self.value))
|
||||
|
||||
@override
|
||||
def clone_with_entry_id(self, entry_id: int) -> DatetimeField:
|
||||
return DatetimeField(name=self.name, entry_id=entry_id, value=self.value)
|
||||
|
||||
|
||||
class BaseFieldTemplate(Base):
|
||||
__abstract__ = True
|
||||
@@ -88,15 +105,30 @@ class BaseFieldTemplate(Base):
|
||||
def name(self) -> Mapped[str]:
|
||||
return mapped_column(nullable=False, default="")
|
||||
|
||||
@property
|
||||
def class_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_field(self, value: Any | None = None) -> BaseField: # pyright: ignore
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TextFieldTemplate(BaseFieldTemplate):
|
||||
__tablename__ = "text_field_templates"
|
||||
is_multiline: Mapped[bool] = mapped_column(nullable=False, default=False)
|
||||
|
||||
@override
|
||||
def to_field(self, value: str | None = None) -> TextField:
|
||||
return TextField(name=self.name, value=value, is_multiline=self.is_multiline)
|
||||
|
||||
|
||||
class DatetimeFieldTemplate(BaseFieldTemplate):
|
||||
__tablename__ = "datetime_field_templates"
|
||||
|
||||
@override
|
||||
def to_field(self, value: str | None = None) -> DatetimeField:
|
||||
return DatetimeField(name=self.name, value=value)
|
||||
|
||||
|
||||
# Used for migrating legacy libraries.
|
||||
# Legacy JSON libraries (<v9.4) use an integer ID.
|
||||
|
||||
@@ -335,13 +335,15 @@ class Library:
|
||||
value=value,
|
||||
is_multiline=bool(field_info["is_multiline"]),
|
||||
)
|
||||
self.add_field_to_entry(entry_id=(entry.id + 1), field=text_field)
|
||||
self.add_field_to_entries(
|
||||
entry_ids=(entry.id + 1), field=text_field
|
||||
)
|
||||
elif field_info["type"] == DatetimeField:
|
||||
datetime_field = DatetimeField(
|
||||
name=str(field_info["name"]), value=value
|
||||
)
|
||||
self.add_field_to_entry(
|
||||
entry_id=(entry.id + 1), field=datetime_field
|
||||
self.add_field_to_entries(
|
||||
entry_ids=(entry.id + 1), field=datetime_field
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -1312,11 +1314,7 @@ class Library:
|
||||
with Session(self.engine) as session:
|
||||
update_stmt = (
|
||||
update(field_type)
|
||||
.where(
|
||||
and_(
|
||||
field_type.id == field.id,
|
||||
)
|
||||
)
|
||||
.where(and_(field_type.id == field.id, field_type.entry_id.in_(entry_ids)))
|
||||
.values(value=value, is_multiline=is_multiline)
|
||||
)
|
||||
|
||||
@@ -1338,57 +1336,30 @@ class Library:
|
||||
with Session(self.engine) as session:
|
||||
update_stmt = (
|
||||
update(field_type)
|
||||
.where(
|
||||
and_(
|
||||
field_type.id == field.id,
|
||||
)
|
||||
)
|
||||
.where(and_(field_type.id == field.id, field_type.entry_id.in_(entry_ids)))
|
||||
.values(value=value)
|
||||
)
|
||||
|
||||
session.execute(update_stmt)
|
||||
session.commit()
|
||||
|
||||
def add_field_to_entry(self, entry_id: int, field: BaseField) -> bool:
|
||||
def add_field_to_entries(self, entry_ids: list[int] | int, field: BaseField) -> bool:
|
||||
"""Add a field object to an Entry."""
|
||||
if type(field) is TextField:
|
||||
logger.info(
|
||||
"[Library] Adding TextField to entry",
|
||||
entry_id=entry_id,
|
||||
name=field.name,
|
||||
value=field.value,
|
||||
is_multiline=field.is_multiline,
|
||||
)
|
||||
if isinstance(entry_ids, int):
|
||||
entry_ids = [entry_ids]
|
||||
|
||||
field = TextField(
|
||||
entry_id=entry_id,
|
||||
name=field.name,
|
||||
value=field.value,
|
||||
is_multiline=field.is_multiline,
|
||||
)
|
||||
logger.info(
|
||||
"[Library] Adding field to entry",
|
||||
type=field.class_name,
|
||||
entry_ids=entry_ids,
|
||||
name=field.name,
|
||||
value=field.value,
|
||||
)
|
||||
|
||||
with Session(self.engine) as session:
|
||||
with Session(self.engine) as session:
|
||||
for entry_id in entry_ids:
|
||||
try:
|
||||
session.add(field)
|
||||
session.commit()
|
||||
except IntegrityError as e:
|
||||
logger.error(e)
|
||||
session.rollback()
|
||||
return False
|
||||
|
||||
elif type(field) is DatetimeField:
|
||||
logger.info(
|
||||
"[Library] Adding DatetimeField to entry",
|
||||
entry_id=entry_id,
|
||||
name=field.name,
|
||||
value=field.value,
|
||||
)
|
||||
|
||||
field = DatetimeField(entry_id=entry_id, name=field.name, value=field.value)
|
||||
|
||||
with Session(self.engine) as session:
|
||||
try:
|
||||
session.add(field)
|
||||
session.add(field.clone_with_entry_id(entry_id))
|
||||
session.commit()
|
||||
except IntegrityError as e:
|
||||
logger.error(e)
|
||||
@@ -1951,7 +1922,7 @@ class Library:
|
||||
for entry in entries:
|
||||
for field in all_fields:
|
||||
if field not in entry.fields:
|
||||
self.add_field_to_entry(entry_id=entry.id, field=field)
|
||||
self.add_field_to_entries(entry_ids=entry.id, field=field)
|
||||
|
||||
def merge_entries(self, from_entry: Entry, into_entry: Entry) -> bool:
|
||||
"""Add fields and tags from the first entry to the second, and then delete the first."""
|
||||
|
||||
@@ -78,7 +78,7 @@ class AddFieldModal(QWidget):
|
||||
self.list_widget.clear()
|
||||
for field_template in self.lib.field_templates:
|
||||
field_name_key: str = FIELD_TYPE_KEYS.get(
|
||||
field_template.__class__.__name__, "field_type.unknown"
|
||||
field_template.class_name, "field_type.unknown"
|
||||
)
|
||||
item = QListWidgetItem(f"{field_template.name} ({Translations[field_name_key]})")
|
||||
item.setData(Qt.ItemDataRole.UserRole, field_template)
|
||||
|
||||
@@ -28,9 +28,7 @@ from tagstudio.core.library.alchemy.fields import (
|
||||
BaseField,
|
||||
BaseFieldTemplate,
|
||||
DatetimeField,
|
||||
DatetimeFieldTemplate,
|
||||
TextField,
|
||||
TextFieldTemplate,
|
||||
)
|
||||
from tagstudio.core.library.alchemy.library import Library
|
||||
from tagstudio.core.library.alchemy.models import Entry, Tag
|
||||
@@ -224,14 +222,9 @@ class FieldContainers(QWidget):
|
||||
logger.info(
|
||||
"[FieldContainers][add_field_to_selected] Adding field",
|
||||
name=template.name,
|
||||
type=template.__class__.__name__,
|
||||
type=template.class_name,
|
||||
)
|
||||
if type(template) is TextFieldTemplate:
|
||||
text_field = TextField(name=template.name, is_multiline=template.is_multiline)
|
||||
self.lib.add_field_to_entry(entry_id, text_field)
|
||||
elif type(template) is DatetimeFieldTemplate:
|
||||
datetime_field = DatetimeField(name=template.name)
|
||||
self.lib.add_field_to_entry(entry_id, datetime_field)
|
||||
self.lib.add_field_to_entries(entry_id, template.to_field())
|
||||
|
||||
def add_tags_to_selected(self, tags: int | list[int]):
|
||||
"""Add list of tags to one or more selected items.
|
||||
@@ -265,7 +258,7 @@ class FieldContainers(QWidget):
|
||||
"[FieldContainers][write_container]",
|
||||
index=index,
|
||||
name=field.name,
|
||||
type=field.__class__.__name__,
|
||||
type=field.class_name,
|
||||
)
|
||||
if len(self.containers) < (index + 1):
|
||||
container = FieldContainer()
|
||||
@@ -275,7 +268,7 @@ class FieldContainers(QWidget):
|
||||
container = self.containers[index]
|
||||
|
||||
# Set field title
|
||||
field_name_key: str = FIELD_TYPE_KEYS.get(field.__class__.__name__, "field_type.unknown")
|
||||
field_name_key: str = FIELD_TYPE_KEYS.get(field.class_name, "field_type.unknown")
|
||||
title = f"{field.name} ({Translations[field_name_key]})"
|
||||
|
||||
# Single-line Text
|
||||
|
||||
@@ -1228,7 +1228,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
if field.type_key == e.type_key and field.value == e.value:
|
||||
exists = True
|
||||
if not exists:
|
||||
self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value)
|
||||
self.lib.add_field_to_entries(id, field_id=field.type_key, value=field.value)
|
||||
self.lib.add_tags_to_entries(id, self.copy_buffer["tags"])
|
||||
if len(self.selected) > 1:
|
||||
if TAG_ARCHIVED in self.copy_buffer["tags"]:
|
||||
|
||||
@@ -213,7 +213,7 @@ def test_remove_text_field_entry_with_multiple_fields(library: Library, entry_fu
|
||||
|
||||
# When
|
||||
# add identical field
|
||||
assert library.add_field_to_entry(entry_full.id, field=title_field)
|
||||
assert library.add_field_to_entries(entry_full.id, field=title_field)
|
||||
|
||||
# remove entry field
|
||||
library.remove_entry_field(title_field, [entry_full.id])
|
||||
@@ -239,7 +239,7 @@ def test_update_entry_with_multiple_identical_text_fields(library: Library, entr
|
||||
# When
|
||||
# add identical field
|
||||
empty_title = TextField(name="Title", value="")
|
||||
library.add_field_to_entry(entry_full.id, field=empty_title)
|
||||
library.add_field_to_entries(entry_full.id, field=empty_title)
|
||||
|
||||
# update one of the fields
|
||||
library.update_text_field(entry_full.id, title_field, "new value", title_field.is_multiline)
|
||||
|
||||
Reference in New Issue
Block a user