refactor!: eradicate use of the term "subtag"

- Removes ambiguity between the use of the term "parent tag" and "subtag"
- Fixes inconstancies between the use of the term "subtag" to refer to either parent tags or child tags
- Fixes duplicate and ambiguous subtags mapped relationship for the Tag model
- Does NOT fix tests
This commit is contained in:
Travis Abendshien
2025-01-05 19:12:18 -08:00
parent af511d8986
commit 6461eebb48
14 changed files with 178 additions and 202 deletions

View File

@@ -8,8 +8,8 @@ from sqlalchemy.orm import Mapped, mapped_column
from .db import Base
class TagSubtag(Base):
__tablename__ = "tag_subtags"
class TagParent(Base):
__tablename__ = "tag_parents"
parent_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)
child_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)

View File

@@ -54,7 +54,7 @@ from .fields import (
TextField,
_FieldID,
)
from .joins import TagEntry, TagSubtag
from .joins import TagEntry, TagParent
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType
from .visitors import SQLBoolExpressionBuilder
@@ -85,7 +85,7 @@ def get_default_tags() -> tuple[Tag, ...]:
id=TAG_ARCHIVED,
name="Archived",
aliases={TagAlias(name="Archive")},
subtags={meta_tag},
parent_tags={meta_tag},
color=TagColor.RED,
)
favorite_tag = Tag(
@@ -95,7 +95,7 @@ def get_default_tags() -> tuple[Tag, ...]:
TagAlias(name="Favorited"),
TagAlias(name="Favorites"),
},
subtags={meta_tag},
parent_tags={meta_tag},
color=TagColor.YELLOW,
)
@@ -192,10 +192,10 @@ class Library:
break
self.add_alias(name=alias, tag_id=tag.id)
# Tag Subtags
# Parent Tags (Previously known as "Subtags" in JSON)
for tag in json_lib.tags:
for subtag_id in tag.subtag_ids:
self.add_subtag(parent_id=tag.id, child_id=subtag_id)
for child_id in tag.subtag_ids:
self.add_parent_tag(parent_id=tag.id, child_id=child_id)
# Entries
self.add_entries(
@@ -216,8 +216,8 @@ class Library:
if k in {6, 7, 8}:
self.add_tags_to_entry(entry_id=entry.id + 1, tag_ids=v)
else:
self.add_entry_field_type(
entry_ids=(entry.id + 1), # JSON IDs start at 0 instead of 1
self.add_field_to_entry(
entry_id=(entry.id + 1), # JSON IDs start at 0 instead of 1
field_id=self.get_field_name_from_id(k),
value=v,
)
@@ -405,7 +405,7 @@ class Library:
.options(
selectinload(Entry.tags).options(
joinedload(Tag.aliases),
joinedload(Tag.subtags),
joinedload(Tag.parent_tags),
)
)
)
@@ -430,7 +430,7 @@ class Library:
selectinload(Entry.datetime_fields),
selectinload(Entry.tags).options(
selectinload(Tag.aliases),
selectinload(Tag.subtags),
selectinload(Tag.parent_tags),
),
)
statement = statement.distinct()
@@ -476,8 +476,8 @@ class Library:
@property
def tags(self) -> list[Tag]:
with Session(self.engine) as session:
# load all tags and join subtags
tags_query = select(Tag).options(selectinload(Tag.subtags))
# load all tags and join parent tags
tags_query = select(Tag).options(selectinload(Tag.parent_tags))
tags = session.scalars(tags_query).unique()
tags_list = list(tags)
@@ -577,7 +577,7 @@ class Library:
selectinload(Entry.text_fields),
selectinload(Entry.datetime_fields),
selectinload(Entry.tags).options(
selectinload(Tag.aliases), selectinload(Tag.subtags)
selectinload(Tag.aliases), selectinload(Tag.parent_tags)
),
)
@@ -613,7 +613,7 @@ class Library:
with Session(self.engine) as session:
query = select(Tag)
query = query.options(
selectinload(Tag.subtags),
selectinload(Tag.parent_tags),
selectinload(Tag.aliases),
).limit(tag_limit)
@@ -641,22 +641,22 @@ class Library:
return res
def get_all_child_tag_ids(self, tag_id: int) -> list[int]:
"""Recursively traverse a Tag's subtags and return a list of all children tags."""
all_subtags: set[int] = {tag_id}
"""Recursively traverse a Tag's parent tags and return a list of all children tags."""
all_parent_ids: set[int] = {tag_id}
with Session(self.engine) as session:
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
tag: Tag | None = session.scalar(select(Tag).where(Tag.id == tag_id))
if tag is None:
raise ValueError(f"No tag found with id {tag_id}.")
subtag_ids = tag.subtag_ids
parent_ids = tag.parent_ids
all_subtags.update(subtag_ids)
all_parent_ids.update(parent_ids)
for sub_id in subtag_ids:
all_subtags.update(self.get_all_child_tag_ids(sub_id))
for child_id in parent_ids:
all_parent_ids.update(self.get_all_child_tag_ids(child_id))
return list(all_subtags)
return list(all_parent_ids)
def update_entry_path(self, entry_id: int | Entry, path: Path) -> None:
if isinstance(entry_id, Entry):
@@ -679,11 +679,11 @@ class Library:
def remove_tag(self, tag: Tag):
with Session(self.engine, expire_on_commit=False) as session:
try:
subtags = session.scalars(
select(TagSubtag).where(TagSubtag.parent_id == tag.id)
parent_tags = session.scalars(
select(TagParent).where(TagParent.parent_id == tag.id)
).all()
tags_query = select(Tag).options(
selectinload(Tag.subtags), selectinload(Tag.aliases)
selectinload(Tag.parent_tags), selectinload(Tag.aliases)
)
tag = session.scalar(tags_query.where(Tag.id == tag.id))
aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id))
@@ -691,9 +691,9 @@ class Library:
for alias in aliases or []:
session.delete(alias)
for subtag in subtags or []:
session.delete(subtag)
session.expunge(subtag)
for parent_tag in parent_tags or []:
session.delete(parent_tag)
session.expunge(parent_tag)
session.delete(tag)
session.commit()
@@ -810,17 +810,17 @@ class Library:
session.expunge(field)
return field
def add_entry_field_type(
def add_field_to_entry(
self,
entry_ids: list[int] | int,
entry_id: int,
*,
field: ValueType | None = None,
field_id: _FieldID | str | None = None,
value: str | datetime | list[int] | None = None,
value: str | datetime | None = None,
) -> bool:
logger.info(
"add_field_to_entry",
entry_ids=entry_ids,
entry_id=entry_id,
field_type=field,
field_id=field_id,
value=value,
@@ -828,9 +828,6 @@ class Library:
# supply only instance or ID, not both
assert bool(field) != (field_id is not None)
if isinstance(entry_ids, int):
entry_ids = [entry_ids]
if not field:
if isinstance(field_id, _FieldID):
field_id = field_id.name
@@ -853,11 +850,9 @@ class Library:
with Session(self.engine) as session:
try:
for entry_id in entry_ids:
field_model.entry_id = entry_id
session.add(field_model)
session.flush()
field_model.entry_id = entry_id
session.add(field_model)
session.flush()
session.commit()
except IntegrityError as e:
logger.exception(e)
@@ -869,7 +864,7 @@ class Library:
self.update_field_position(
field_class=type(field_model),
field_type=field.key,
entry_ids=entry_ids,
entry_ids=entry_id,
)
return True
@@ -898,7 +893,7 @@ class Library:
def add_tag(
self,
tag: Tag,
subtag_ids: list[int] | set[int] | None = None,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
) -> Tag | None:
@@ -907,8 +902,8 @@ class Library:
session.add(tag)
session.flush()
if subtag_ids is not None:
self.update_subtags(tag, subtag_ids, session)
if parent_ids is not None:
self.update_parent_tags(tag, parent_ids, session)
if alias_ids is not None and alias_names is not None:
self.update_aliases(tag, alias_ids, alias_names, session)
@@ -978,12 +973,14 @@ class Library:
def get_tag(self, tag_id: int) -> Tag:
with Session(self.engine) as session:
tags_query = select(Tag).options(selectinload(Tag.subtags), selectinload(Tag.aliases))
tags_query = select(Tag).options(
selectinload(Tag.parent_tags), selectinload(Tag.aliases)
)
tag = session.scalar(tags_query.where(Tag.id == tag_id))
session.expunge(tag)
for subtag in tag.subtags:
session.expunge(subtag)
for parent in tag.parent_tags:
session.expunge(parent)
for alias in tag.aliases:
session.expunge(alias)
@@ -1006,19 +1003,19 @@ class Library:
return alias
def add_subtag(self, parent_id: int, child_id: int) -> bool:
def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
if parent_id == child_id:
return False
# open session and save as parent tag
with Session(self.engine) as session:
subtag = TagSubtag(
parent_tag = TagParent(
parent_id=parent_id,
child_id=child_id,
)
try:
session.add(subtag)
session.add(parent_tag)
session.commit()
return True
except IntegrityError:
@@ -1045,11 +1042,11 @@ class Library:
logger.exception("IntegrityError")
return False
def remove_subtag(self, base_id: int, remove_tag_id: int) -> bool:
def remove_parent_tag(self, base_id: int, remove_tag_id: int) -> bool:
with Session(self.engine) as session:
p_id = base_id
r_id = remove_tag_id
remove = session.query(TagSubtag).filter_by(parent_id=p_id, child_id=r_id).one()
remove = session.query(TagParent).filter_by(parent_id=p_id, child_id=r_id).one()
session.delete(remove)
session.commit()
@@ -1058,12 +1055,12 @@ class Library:
def update_tag(
self,
tag: Tag,
subtag_ids: list[int] | set[int] | None = None,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
) -> None:
"""Edit a Tag in the Library."""
self.add_tag(tag, subtag_ids, alias_names, alias_ids)
self.add_tag(tag, parent_ids, alias_names, alias_ids)
def update_aliases(self, tag, alias_ids, alias_names, session):
prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all()
@@ -1079,28 +1076,30 @@ class Library:
alias = TagAlias(alias_name, tag.id)
session.add(alias)
def update_subtags(self, tag, subtag_ids, session):
if tag.id in subtag_ids:
subtag_ids.remove(tag.id)
def update_parent_tags(self, tag, parent_ids, session):
if tag.id in parent_ids:
parent_ids.remove(tag.id)
# load all tag's subtag to know which to remove
prev_subtags = session.scalars(select(TagSubtag).where(TagSubtag.parent_id == tag.id)).all()
# load all tag's parent tags to know which to remove
prev_parent_tags = session.scalars(
select(TagParent).where(TagParent.parent_id == tag.id)
).all()
for subtag in prev_subtags:
if subtag.child_id not in subtag_ids:
session.delete(subtag)
for parent_tag in prev_parent_tags:
if parent_tag.child_id not in parent_ids:
session.delete(parent_tag)
else:
# no change, remove from list
subtag_ids.remove(subtag.child_id)
parent_ids.remove(parent_tag.child_id)
# create remaining items
for subtag_id in subtag_ids:
# add new subtag
subtag = TagSubtag(
for parent_id in parent_ids:
# add new parent tag
parent_tag = TagParent(
parent_id=tag.id,
child_id=subtag_id,
child_id=parent_id,
)
session.add(subtag)
session.add(parent_tag)
def prefs(self, key: LibraryPrefs):
# load given item from Preferences table
@@ -1130,8 +1129,8 @@ class Library:
for entry in entries:
for field_key, field in fields.items():
if field_key not in existing_fields:
self.add_entry_field_type(
entry_ids=entry.id,
self.add_field_to_entry(
entry_id=entry.id,
field_id=field.type_key,
value=field.value,
)

View File

@@ -17,16 +17,14 @@ from .fields import (
FieldTypeEnum,
TextField,
)
from .joins import TagSubtag
from .joins import TagParent
class TagAlias(Base):
__tablename__ = "tag_aliases"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False)
tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"))
tag: Mapped["Tag"] = relationship(back_populates="aliases")
@@ -54,22 +52,15 @@ class Tag(Base):
aliases: Mapped[set[TagAlias]] = relationship(back_populates="tag")
parent_tags: Mapped[set["Tag"]] = relationship(
secondary=TagSubtag.__tablename__,
primaryjoin="Tag.id == TagSubtag.child_id",
secondaryjoin="Tag.id == TagSubtag.parent_id",
back_populates="subtags",
)
subtags: Mapped[set["Tag"]] = relationship(
secondary=TagSubtag.__tablename__,
primaryjoin="Tag.id == TagSubtag.parent_id",
secondaryjoin="Tag.id == TagSubtag.child_id",
secondary=TagParent.__tablename__,
primaryjoin="Tag.id == TagParent.parent_id",
secondaryjoin="Tag.id == TagParent.child_id",
back_populates="parent_tags",
)
@property
def subtag_ids(self) -> list[int]:
return [tag.id for tag in self.subtags]
def parent_ids(self) -> list[int]:
return [tag.id for tag in self.parent_tags]
@property
def alias_strings(self) -> list[str]:
@@ -86,7 +77,6 @@ class Tag(Base):
shorthand: str | None = None,
aliases: set[TagAlias] | None = None,
parent_tags: set["Tag"] | None = None,
subtags: set["Tag"] | None = None,
icon: str | None = None,
color: TagColor = TagColor.DEFAULT,
is_category: bool = False,
@@ -94,7 +84,6 @@ class Tag(Base):
self.name = name
self.aliases = aliases or set()
self.parent_tags = parent_tags or set()
self.subtags = subtags or set()
self.color = color
self.icon = icon
self.shorthand = shorthand

View File

@@ -23,16 +23,17 @@ else:
logger = structlog.get_logger(__name__)
# TODO: Reevaluate after subtags -> parent tags name change
CHILDREN_QUERY = text("""
-- Note for this entire query that tag_subtags.child_id is the parent id and tag_subtags.parent_id is the child id due to bad naming
WITH RECURSIVE Subtags AS (
-- Note for this entire query that tag_parents.child_id is the parent id and tag_parents.parent_id is the child id due to bad naming
WITH RECURSIVE ChildTags AS (
SELECT :tag_id AS child_id
UNION ALL
SELECT ts.parent_id AS child_id
FROM tag_subtags ts
INNER JOIN Subtags s ON ts.child_id = s.child_id
SELECT tp.parent_id AS child_id
FROM tag_parents tp
INNER JOIN ChildTags c ON tp.child_id = c.child_id
)
SELECT * FROM Subtags;
SELECT * FROM ChildTags;
""") # noqa: E501

View File

@@ -126,7 +126,7 @@ class TagStudioCore:
is_new = field["id"] not in entry_field_types
field_key = field["id"]
if is_new:
lib.add_entry_field_type(entry.id, field_key, field["value"])
lib.add_field_to_entry(entry.id, field_key, field["value"])
else:
lib.update_entry_field(entry.id, field_key, field["value"])

View File

@@ -118,21 +118,21 @@ class BuildTagPanel(PanelWidget):
self.alias_add_button.clicked.connect(self.add_alias_callback)
# Parent Tags ----------------------------------------------------------
self.subtags_widget = QWidget()
self.subtags_layout = QVBoxLayout(self.subtags_widget)
self.subtags_layout.setStretch(1, 1)
self.subtags_layout.setContentsMargins(0, 0, 0, 0)
self.subtags_layout.setSpacing(0)
self.subtags_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.parent_tags_widget = QWidget()
self.parent_tags_layout = QVBoxLayout(self.parent_tags_widget)
self.parent_tags_layout.setStretch(1, 1)
self.parent_tags_layout.setContentsMargins(0, 0, 0, 0)
self.parent_tags_layout.setSpacing(0)
self.parent_tags_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.subtags_title = QLabel()
Translations.translate_qobject(self.subtags_title, "tag.parent_tags")
self.subtags_layout.addWidget(self.subtags_title)
self.parent_tags_title = QLabel()
Translations.translate_qobject(self.parent_tags_title, "tag.parent_tags")
self.parent_tags_layout.addWidget(self.parent_tags_title)
self.scroll_contents = QWidget()
self.subtags_scroll_layout = QVBoxLayout(self.scroll_contents)
self.subtags_scroll_layout.setContentsMargins(6, 0, 6, 0)
self.subtags_scroll_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
self.parent_tags_scroll_layout = QVBoxLayout(self.scroll_contents)
self.parent_tags_scroll_layout.setContentsMargins(6, 0, 6, 0)
self.parent_tags_scroll_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
self.scroll_area = QScrollArea()
# self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
@@ -142,23 +142,23 @@ class BuildTagPanel(PanelWidget):
self.scroll_area.setWidget(self.scroll_contents)
# self.scroll_area.setMinimumHeight(60)
self.subtags_layout.addWidget(self.scroll_area)
self.parent_tags_layout.addWidget(self.scroll_area)
self.subtags_add_button = QPushButton()
self.subtags_add_button.setCursor(Qt.CursorShape.PointingHandCursor)
self.subtags_add_button.setText("+")
self.subtags_layout.addWidget(self.subtags_add_button)
self.parent_tags_add_button = QPushButton()
self.parent_tags_add_button.setCursor(Qt.CursorShape.PointingHandCursor)
self.parent_tags_add_button.setText("+")
self.parent_tags_layout.addWidget(self.parent_tags_add_button)
exclude_ids: list[int] = list()
if tag is not None:
exclude_ids.append(tag.id)
tsp = TagSearchPanel(self.lib, exclude_ids)
tsp.tag_chosen.connect(lambda x: self.add_subtag_callback(x))
tsp.tag_chosen.connect(lambda x: self.add_parent_tag_callback(x))
self.add_tag_modal = PanelModal(tsp)
Translations.translate_with_setter(self.add_tag_modal.setTitle, "tag.parent_tags.add")
Translations.translate_with_setter(self.add_tag_modal.setWindowTitle, "tag.parent_tags.add")
self.subtags_add_button.clicked.connect(self.add_tag_modal.show)
self.parent_tags_add_button.clicked.connect(self.add_tag_modal.show)
# Color ----------------------------------------------------------------
self.color_widget = QWidget()
@@ -230,13 +230,13 @@ class BuildTagPanel(PanelWidget):
self.root_layout.addWidget(self.aliases_widget)
self.root_layout.addWidget(self.aliases_table)
self.root_layout.addWidget(self.alias_add_button)
self.root_layout.addWidget(self.subtags_widget)
self.root_layout.addWidget(self.parent_tags_widget)
self.root_layout.addWidget(self.color_widget)
self.root_layout.addWidget(QLabel("<h3>Properties</h3>"))
self.root_layout.addWidget(self.cat_widget)
# self.parent().done.connect(self.update_tag)
self.subtag_ids: set[int] = set()
self.parent_ids: set[int] = set()
self.alias_ids: list[int] = []
self.alias_names: list[str] = []
self.new_alias_names: dict = {}
@@ -276,15 +276,15 @@ class BuildTagPanel(PanelWidget):
if isinstance(focused_widget, CustomTableItem):
self.add_alias_callback()
def add_subtag_callback(self, tag_id: int):
logger.info("add_subtag_callback", tag_id=tag_id)
self.subtag_ids.add(tag_id)
self.set_subtags()
def add_parent_tag_callback(self, tag_id: int):
logger.info("add_parent_tag_callback", tag_id=tag_id)
self.parent_ids.add(tag_id)
self.set_parent_tags()
def remove_subtag_callback(self, tag_id: int):
logger.info("removing subtag", tag_id=tag_id)
self.subtag_ids.remove(tag_id)
self.set_subtags()
def remove_parent_tag_callback(self, tag_id: int):
logger.info("remove_parent_tag_callback", tag_id=tag_id)
self.parent_ids.remove(tag_id)
self.set_parent_tags()
def add_alias_callback(self):
logger.info("add_alias_callback")
@@ -305,20 +305,20 @@ class BuildTagPanel(PanelWidget):
self.alias_ids.remove(alias_id)
self._set_aliases()
def set_subtags(self):
while self.subtags_scroll_layout.itemAt(0):
self.subtags_scroll_layout.takeAt(0).widget().deleteLater()
def set_parent_tags(self):
while self.parent_tags_scroll_layout.itemAt(0):
self.parent_tags_scroll_layout.takeAt(0).widget().deleteLater()
c = QWidget()
layout = QVBoxLayout(c)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(3)
for tag_id in self.subtag_ids:
for tag_id in self.parent_ids:
tag = self.lib.get_tag(tag_id)
tw = TagWidget(tag, has_edit=False, has_remove=True)
tw.on_remove.connect(lambda t=tag_id: self.remove_subtag_callback(t))
tw.on_remove.connect(lambda t=tag_id: self.remove_parent_tag_callback(t))
layout.addWidget(tw)
self.subtags_scroll_layout.addWidget(c)
self.parent_tags_scroll_layout.addWidget(c)
def add_aliases(self):
names: set[str] = set()
@@ -392,9 +392,9 @@ class BuildTagPanel(PanelWidget):
self.alias_ids.append(alias_id)
self._set_aliases()
for subtag in tag.subtag_ids:
self.subtag_ids.add(subtag)
self.set_subtags()
for parent_id in tag.parent_ids:
self.parent_ids.add(parent_id)
self.set_parent_tags()
# select item in self.color_field where the userData value matched tag.color
for i in range(self.color_field.count()):

View File

@@ -41,7 +41,7 @@ def add_folders_to_tree(library: Library, tree: BranchData, items: tuple[str, ..
branch = tree
for folder in items:
if folder not in branch.dirs:
# TODO - subtags
# TODO: Reimplement parent tags
new_tag = Tag(name=folder)
library.add_tag(new_tag)
branch.dirs[folder] = BranchData(tag=new_tag)
@@ -81,11 +81,11 @@ def reverse_tag(library: Library, tag: Tag, items: list[Tag] | None) -> list[Tag
items = items or []
items.append(tag)
if not tag.subtag_ids:
if not tag.parent_ids:
items.reverse()
return items
for subtag_id in tag.subtag_ids:
for subtag_id in tag.parent_ids:
subtag = library.get_tag(subtag_id)
return reverse_tag(library, subtag, items)

View File

@@ -87,7 +87,7 @@ class TagDatabasePanel(PanelWidget):
lambda: (
self.lib.add_tag(
tag=panel.build_tag(),
subtag_ids=panel.subtag_ids,
parent_ids=panel.parent_ids,
alias_names=panel.alias_names,
alias_ids=panel.alias_ids,
),
@@ -169,7 +169,7 @@ class TagDatabasePanel(PanelWidget):
def edit_tag_callback(self, btp: BuildTagPanel):
self.lib.update_tag(
btp.build_tag(), set(btp.subtag_ids), set(btp.alias_names), set(btp.alias_ids)
btp.build_tag(), set(btp.parent_ids), set(btp.alias_names), set(btp.alias_ids)
)
self.update_tags(self.search_field.text())

View File

@@ -682,7 +682,7 @@ class QtDriver(DriverMixin, QObject):
lambda: (
self.lib.add_tag(
panel.build_tag(),
set(panel.subtag_ids),
set(panel.parent_ids),
set(panel.alias_names),
set(panel.alias_ids),
),
@@ -858,7 +858,7 @@ class QtDriver(DriverMixin, QObject):
for field_id, value in parsed_items.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], str):
value = self.lib.tag_from_strings(value)
self.lib.add_entry_field_type(
self.lib.add_field_to_entry(
entry.id,
field_id=field_id,
value=value,
@@ -867,7 +867,7 @@ class QtDriver(DriverMixin, QObject):
elif name == MacroID.BUILD_URL:
url = TagStudioCore.build_url(entry, source)
if url is not None:
self.lib.add_entry_field_type(entry.id, field_id=_FieldID.SOURCE, value=url)
self.lib.add_field_to_entry(entry.id, field_id=_FieldID.SOURCE, value=url)
elif name == MacroID.MATCH:
TagStudioCore.match_conditions(self.lib, entry.id)
elif name == MacroID.CLEAN_URL:

View File

@@ -22,7 +22,7 @@ from sqlalchemy.orm import Session
from src.core.constants import TS_FOLDER_NAME
from src.core.enums import LibraryPrefs
from src.core.library.alchemy.enums import TagColor
from src.core.library.alchemy.joins import TagSubtag
from src.core.library.alchemy.joins import TagParent
from src.core.library.alchemy.library import DEFAULT_TAG_DIFF
from src.core.library.alchemy.library import Library as SqliteLibrary
from src.core.library.alchemy.models import Entry, TagAlias
@@ -115,7 +115,7 @@ class JsonMigrationModal(QObject):
entries_text: str = Translations["json_migration.heading.entires"]
tags_text: str = Translations["json_migration.heading.tags"]
shorthand_text: str = tab + Translations["json_migration.heading.shorthands"]
subtags_text: str = tab + Translations["json_migration.heading.parent_tags"]
parent_tags_text: str = tab + Translations["json_migration.heading.parent_tags"]
aliases_text: str = tab + Translations["json_migration.heading.aliases"]
colors_text: str = tab + Translations["json_migration.heading.colors"]
ext_text: str = Translations["json_migration.heading.file_extension_list"]
@@ -129,7 +129,7 @@ class JsonMigrationModal(QObject):
self.fields_row: int = 2
self.tags_row: int = 3
self.shorthands_row: int = 4
self.subtags_row: int = 5
self.parent_tags_row: int = 5
self.aliases_row: int = 6
self.colors_row: int = 7
self.ext_row: int = 8
@@ -151,7 +151,7 @@ class JsonMigrationModal(QObject):
self.old_content_layout.addWidget(QLabel(field_parity_text), self.fields_row, 0)
self.old_content_layout.addWidget(QLabel(tags_text), self.tags_row, 0)
self.old_content_layout.addWidget(QLabel(shorthand_text), self.shorthands_row, 0)
self.old_content_layout.addWidget(QLabel(subtags_text), self.subtags_row, 0)
self.old_content_layout.addWidget(QLabel(parent_tags_text), self.parent_tags_row, 0)
self.old_content_layout.addWidget(QLabel(aliases_text), self.aliases_row, 0)
self.old_content_layout.addWidget(QLabel(colors_text), self.colors_row, 0)
self.old_content_layout.addWidget(QLabel(ext_text), self.ext_row, 0)
@@ -183,7 +183,7 @@ class JsonMigrationModal(QObject):
self.old_content_layout.addWidget(old_field_value, self.fields_row, 1)
self.old_content_layout.addWidget(old_tag_count, self.tags_row, 1)
self.old_content_layout.addWidget(old_shorthand_count, self.shorthands_row, 1)
self.old_content_layout.addWidget(old_subtag_value, self.subtags_row, 1)
self.old_content_layout.addWidget(old_subtag_value, self.parent_tags_row, 1)
self.old_content_layout.addWidget(old_alias_value, self.aliases_row, 1)
self.old_content_layout.addWidget(old_color_value, self.colors_row, 1)
self.old_content_layout.addWidget(old_ext_count, self.ext_row, 1)
@@ -192,7 +192,7 @@ class JsonMigrationModal(QObject):
self.old_content_layout.addWidget(QLabel(), self.path_row, 2)
self.old_content_layout.addWidget(QLabel(), self.fields_row, 2)
self.old_content_layout.addWidget(QLabel(), self.shorthands_row, 2)
self.old_content_layout.addWidget(QLabel(), self.subtags_row, 2)
self.old_content_layout.addWidget(QLabel(), self.parent_tags_row, 2)
self.old_content_layout.addWidget(QLabel(), self.aliases_row, 2)
self.old_content_layout.addWidget(QLabel(), self.colors_row, 2)
@@ -214,7 +214,7 @@ class JsonMigrationModal(QObject):
self.new_content_layout.addWidget(QLabel(field_parity_text), self.fields_row, 0)
self.new_content_layout.addWidget(QLabel(tags_text), self.tags_row, 0)
self.new_content_layout.addWidget(QLabel(shorthand_text), self.shorthands_row, 0)
self.new_content_layout.addWidget(QLabel(subtags_text), self.subtags_row, 0)
self.new_content_layout.addWidget(QLabel(parent_tags_text), self.parent_tags_row, 0)
self.new_content_layout.addWidget(QLabel(aliases_text), self.aliases_row, 0)
self.new_content_layout.addWidget(QLabel(colors_text), self.colors_row, 0)
self.new_content_layout.addWidget(QLabel(ext_text), self.ext_row, 0)
@@ -246,7 +246,7 @@ class JsonMigrationModal(QObject):
self.new_content_layout.addWidget(field_parity_value, self.fields_row, 1)
self.new_content_layout.addWidget(new_tag_count, self.tags_row, 1)
self.new_content_layout.addWidget(new_shorthand_count, self.shorthands_row, 1)
self.new_content_layout.addWidget(subtag_parity_value, self.subtags_row, 1)
self.new_content_layout.addWidget(subtag_parity_value, self.parent_tags_row, 1)
self.new_content_layout.addWidget(alias_parity_value, self.aliases_row, 1)
self.new_content_layout.addWidget(new_color_value, self.colors_row, 1)
self.new_content_layout.addWidget(new_ext_count, self.ext_row, 1)
@@ -257,7 +257,7 @@ class JsonMigrationModal(QObject):
self.new_content_layout.addWidget(QLabel(), self.fields_row, 2)
self.new_content_layout.addWidget(QLabel(), self.shorthands_row, 2)
self.new_content_layout.addWidget(QLabel(), self.tags_row, 2)
self.new_content_layout.addWidget(QLabel(), self.subtags_row, 2)
self.new_content_layout.addWidget(QLabel(), self.parent_tags_row, 2)
self.new_content_layout.addWidget(QLabel(), self.aliases_row, 2)
self.new_content_layout.addWidget(QLabel(), self.colors_row, 2)
self.new_content_layout.addWidget(QLabel(), self.ext_row, 2)
@@ -396,7 +396,7 @@ class JsonMigrationModal(QObject):
self.update_parity_value(self.fields_row, self.field_parity)
self.update_parity_value(self.path_row, self.path_parity)
self.update_parity_value(self.shorthands_row, self.shorthand_parity)
self.update_parity_value(self.subtags_row, self.subtag_parity)
self.update_parity_value(self.parent_tags_row, self.subtag_parity)
self.update_parity_value(self.aliases_row, self.alias_parity)
self.update_parity_value(self.colors_row, self.color_parity)
self.sql_lib.close()
@@ -639,39 +639,39 @@ class JsonMigrationModal(QObject):
return self.path_parity
def check_subtag_parity(self) -> bool:
"""Check if all JSON subtags match the new SQL subtags."""
sql_subtags: set[int] = None
json_subtags: set[int] = None
"""Check if all JSON parent tags match the new SQL parent tags."""
sql_parent_tags: set[int] = None
json_parent_tags: set[int] = None
with Session(self.sql_lib.engine) as session:
for tag in self.sql_lib.tags:
if tag.id in range(0, 1000):
break
tag_id = tag.id # Tag IDs start at 0
sql_subtags = set(
session.scalars(select(TagSubtag.child_id).where(TagSubtag.parent_id == tag.id))
sql_parent_tags = set(
session.scalars(select(TagParent.child_id).where(TagParent.parent_id == tag.id))
)
# sql_subtags = sql_subtags.difference([x for x in range(0, 1000)])
# sql_parent_tags = sql_parent_tags.difference([x for x in range(0, 1000)])
# JSON tags allowed self-parenting; SQL tags no longer allow this.
json_subtags = set(self.json_lib.get_tag(tag_id).subtag_ids)
json_subtags.discard(tag_id)
json_parent_tags = set(self.json_lib.get_tag(tag_id).subtag_ids)
json_parent_tags.discard(tag_id)
logger.info(
"[Subtag Parity]",
tag_id=tag_id,
json_subtags=json_subtags,
sql_subtags=sql_subtags,
json_parent_tags=json_parent_tags,
sql_parent_tags=sql_parent_tags,
)
if not (
sql_subtags is not None
and json_subtags is not None
and (sql_subtags == json_subtags)
sql_parent_tags is not None
and json_parent_tags is not None
and (sql_parent_tags == json_parent_tags)
):
self.discrepancies.append(
f"[Subtag Parity][Tag ID: {tag_id}]:"
f"\nOLD (JSON):{json_subtags}\nNEW (SQL):{sql_subtags}"
f"\nOLD (JSON):{json_parent_tags}\nNEW (SQL):{sql_parent_tags}"
)
self.subtag_parity = False
return self.subtag_parity

View File

@@ -163,7 +163,7 @@ class FieldContainers(QWidget):
"""
tag_obj = self.lib.get_tag(tag_id) # Get full object
if p_ids is None:
p_ids = tag_obj.subtag_ids
p_ids = tag_obj.parent_ids
for p_id in p_ids:
if cluster_map.get(p_id) is None:
@@ -172,10 +172,10 @@ class FieldContainers(QWidget):
if tag.id not in cluster_map[p_id]:
cluster_map[p_id].add(tag.id)
p_tag = self.lib.get_tag(p_id) # Get full object
if p_tag.subtag_ids:
if p_tag.parent_ids:
add_to_cluster(
tag_id,
[sub_id for sub_id in p_tag.subtag_ids if sub_id != tag_id],
[sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id],
)
exhausted.add(p_id)
exhausted.add(tag_id)
@@ -240,7 +240,7 @@ class FieldContainers(QWidget):
)
for entry_id in self.driver.selected:
for field_item in field_list:
self.lib.add_entry_field_type(
self.lib.add_field_to_entry(
entry_id,
field_id=field_item.data(Qt.ItemDataRole.UserRole),
)

View File

@@ -76,7 +76,7 @@ class TagBoxWidget(FieldWidget):
self.edit_modal = PanelModal(
build_tag_panel,
tag.name, # TODO - display name including subtags
tag.name, # TODO - display name including parent tags
"Edit Tag",
done_callback=self.driver.preview_panel.update_widgets,
has_save=True,
@@ -85,7 +85,7 @@ class TagBoxWidget(FieldWidget):
self.edit_modal.saved.connect(
lambda: self.driver.lib.update_tag(
build_tag_panel.build_tag(),
subtag_ids=set(build_tag_panel.subtag_ids),
parent_ids=set(build_tag_panel.parent_ids),
alias_names=set(build_tag_panel.alias_names),
alias_ids=set(build_tag_panel.alias_ids),
)

View File

@@ -11,9 +11,9 @@ def test_build_tag_panel_add_sub_tag_callback(library, generate_tag):
panel: BuildTagPanel = BuildTagPanel(library, child)
panel.add_subtag_callback(parent.id)
panel.add_parent_tag_callback(parent.id)
assert len(panel.subtag_ids) == 1
assert len(panel.parent_ids) == 1
def test_build_tag_panel_remove_subtag_callback(library, generate_tag):
@@ -30,9 +30,9 @@ def test_build_tag_panel_remove_subtag_callback(library, generate_tag):
panel: BuildTagPanel = BuildTagPanel(library, child)
panel.remove_subtag_callback(parent.id)
panel.remove_parent_tag_callback(parent.id)
assert len(panel.subtag_ids) == 0
assert len(panel.parent_ids) == 0
import os
@@ -79,14 +79,14 @@ def test_build_tag_panel_set_subtags(library, generate_tag):
assert parent
assert child
library.add_subtag(child.id, parent.id)
library.add_parent_tag(child.id, parent.id)
child = library.get_tag(child.id)
panel: BuildTagPanel = BuildTagPanel(library, child)
assert len(panel.subtag_ids) == 1
assert panel.subtags_scroll_layout.count() == 1
assert len(panel.parent_ids) == 1
assert panel.parent_tags_scroll_layout.count() == 1
def test_build_tag_panel_add_aliases(library, generate_tag):

View File

@@ -18,9 +18,6 @@ def test_library_add_alias(library, generate_tag):
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, subtag_ids, alias_names, alias_ids)
# Note: ask if it is expected behaviour that you need to re-request
# for the tag. Or if the tag in memory should be updated
alias_ids = library.get_tag(tag.id).alias_ids
assert len(alias_ids) == 1
@@ -35,7 +32,6 @@ def test_library_get_alias(library, generate_tag):
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, subtag_ids, alias_names, alias_ids)
alias_ids = library.get_tag(tag.id).alias_ids
assert library.get_alias(tag.id, alias_ids[0]).name == "test_alias"
@@ -50,9 +46,7 @@ def test_library_update_alias(library, generate_tag):
alias_names: set[str] = set()
alias_names.add("test_alias")
library.update_tag(tag, subtag_ids, alias_names, alias_ids)
tag = library.get_tag(tag.id)
alias_ids = tag.alias_ids
alias_ids = library.get_tag(tag.id).alias_ids
assert library.get_alias(tag.id, alias_ids[0]).name == "test_alias"
@@ -61,7 +55,6 @@ def test_library_update_alias(library, generate_tag):
library.update_tag(tag, subtag_ids, alias_names, alias_ids)
tag = library.get_tag(tag.id)
assert len(tag.alias_ids) == 1
assert library.get_alias(tag.id, tag.alias_ids[0]).name == "alias_update"
@@ -77,9 +70,7 @@ def test_library_add_file(library):
)
assert not library.has_path_entry(entry.path)
assert library.add_entries([entry])
assert library.has_path_entry(entry.path)
@@ -107,7 +98,7 @@ def test_tag_subtag_itself(library, generate_tag):
library.update_tag(tag, {tag.id}, {}, {})
tag = library.get_tag(tag.id)
assert len(tag.subtag_ids) == 0
assert len(tag.parent_ids) == 0
def test_library_search(library, generate_tag, entry_full):
@@ -133,11 +124,8 @@ def test_tag_search(library):
tag = library.tags[0]
assert library.search_tags(tag.name.lower())
assert library.search_tags(tag.name.upper())
assert library.search_tags(tag.name[2:-2])
assert not library.search_tags(tag.name * 2)
@@ -168,11 +156,10 @@ def test_add_field_to_entry(library):
)
# meta tags + content tags
assert len(entry.tag_box_fields) == 2
assert library.add_entries([entry])
# When
library.add_entry_field_type(entry.id, field_id=_FieldID.TAGS)
library.add_field_to_entry(entry.id, field_id=_FieldID.TAGS)
# Then
entry = [x for x in library.get_entries(with_joins=True) if x.path == entry.path][0]
@@ -195,22 +182,22 @@ def test_add_field_tag(library: Library, entry_full, generate_tag):
assert [x.name for x in tag_field.tags if x.name == tag_name]
def test_subtags_add(library, generate_tag):
def test_parents_add(library, generate_tag):
# Given
tag = library.tags[0]
tag: Tag = library.tags[0]
assert tag.id is not None
subtag = generate_tag("subtag1")
subtag = library.add_tag(subtag)
assert subtag.id is not None
parent_tag = generate_tag("subtag1")
parent_tag = library.add_tag(parent_tag)
assert parent_tag.id is not None
# When
assert library.add_subtag(tag.id, subtag.id)
assert library.add_parent_tag(tag.id, parent_tag.id)
# Then
assert tag.id is not None
tag = library.get_tag(tag.id)
assert tag.subtag_ids
assert tag.parent_ids
def test_remove_tag(library, generate_tag):