feat(parity): migrate json libraries to sqlite (#604)

* feat(ui): add PagedPanel widget

* feat(ui): add MigrationModal widget

* feat: add basic json to sql conversion

* fix: chose `poolclass` based on file or memory db

* feat: migrate tag colors from json to sql

* feat: migrate entry fields from json to sql

- fix: tag name column no longer has unique constraint
- fix: tags are referenced by id in db queries
- fix: tag_search_panel no longer queries db on initialization; does not regress to empty search window when shown
- fix: tag name search no longer uses library grid FilterState object
- fix: tag name search now respects tag limit

* set default `is_new` case

* fix: limit correct tag query

* feat: migrate tag aliases and subtags from json to sql

* add migration timer

* fix(tests): fix broken tests

* rename methods, add docstrings

* revert tag id search, split tag name search

* fix: use correct type in sidecar macro

* tests: add json migration tests

* fix: drop leading dot from json extensions

* add special characters to json db test

* tests: add file path and entry field parity checks

* fix(ui): tag manager no longer starts empty

* fix: read old windows paths as posix

Addresses #298

* tests: add posix + windows paths to json library

* tests: add subtag, alias, and shorthand parity tests

* tests: ensure no none values in parity checks

* tests: add tag color test, use tag id in tag tests

* tests: fix and optimize tests

* tests: add discrepancy tracker

* refactor: reduce duplicate UI code

* fix: load non-sequential entry ids

* fix(ui): sort tags in the preview panel

* tests(fix): prioritize `None` check over equality

* fix(tests): fix multi "same tag field type" tests

* ui: increase height of migration modal

* feat: add progress bar to migration ui

* fix(ui): sql values update earlier

* refactor: use `get_color_from_str` in test

* refactor: migrate tags before aliases and subtags

* remove unused assertion

* refactor: use `json_migration_req` flag
This commit is contained in:
Travis Abendshien
2024-11-30 13:00:08 -08:00
committed by GitHub
parent b7e652ad8d
commit ef68603322
17 changed files with 1244 additions and 80 deletions

6
.gitignore vendored
View File

@@ -55,7 +55,6 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
tagstudio/tests/fixtures/library/*
# Translations
*.mo
@@ -255,11 +254,14 @@ compile_commands.json
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,qt
# TagStudio
.TagStudio
!*/tests/**/.TagStudio
tagstudio/tests/fixtures/library/*
tagstudio/tests/fixtures/json_library/.TagStudio/*.sqlite
TagStudio.ini
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,qt
.envrc
.direnv

View File

@@ -42,6 +42,13 @@ class TagColor(enum.IntEnum):
COOL_GRAY = 36
OLIVE = 37
@staticmethod
def get_color_from_str(color_name: str) -> "TagColor":
for color in TagColor:
if color.name == color_name.upper().replace(" ", "_"):
return color
return TagColor.DEFAULT
class SearchMode(enum.IntEnum):
"""Operational modes for item searching."""

View File

@@ -1,5 +1,6 @@
import re
import shutil
import time
import unicodedata
from dataclasses import dataclass
from datetime import UTC, datetime
@@ -9,9 +10,11 @@ from typing import Any, Iterator, Type
from uuid import uuid4
import structlog
from humanfriendly import format_timespan
from sqlalchemy import (
URL,
Engine,
NullPool,
and_,
create_engine,
delete,
@@ -29,6 +32,7 @@ from sqlalchemy.orm import (
make_transient,
selectinload,
)
from src.core.library.json.library import Library as JsonLibrary # type: ignore
from ...constants import (
BACKUP_FOLDER_NAME,
@@ -122,6 +126,7 @@ class LibraryStatus:
success: bool
library_path: Path | None = None
message: str | None = None
json_migration_req: bool = False
class Library:
@@ -133,7 +138,8 @@ class Library:
folder: Folder | None
included_files: set[Path] = set()
FILENAME: str = "ts_library.sqlite"
SQL_FILENAME: str = "ts_library.sqlite"
JSON_FILENAME: str = "ts_library.json"
def close(self):
if self.engine:
@@ -143,32 +149,119 @@ class Library:
self.folder = None
self.included_files = set()
def migrate_json_to_sqlite(self, json_lib: JsonLibrary):
"""Migrate JSON library data to the SQLite database."""
logger.info("Starting Library Conversion...")
start_time = time.time()
folder: Folder = Folder(path=self.library_dir, uuid=str(uuid4()))
# Tags
for tag in json_lib.tags:
self.add_tag(
Tag(
id=tag.id,
name=tag.name,
shorthand=tag.shorthand,
color=TagColor.get_color_from_str(tag.color),
)
)
# Tag Aliases
for tag in json_lib.tags:
for alias in tag.aliases:
self.add_alias(name=alias, tag_id=tag.id)
# Tag Subtags
for tag in json_lib.tags:
for subtag_id in tag.subtag_ids:
self.add_subtag(parent_id=tag.id, child_id=subtag_id)
# Entries
self.add_entries(
[
Entry(
path=entry.path / entry.filename,
folder=folder,
fields=[],
id=entry.id + 1, # JSON IDs start at 0 instead of 1
)
for entry in json_lib.entries
]
)
for entry in json_lib.entries:
for field in entry.fields:
for k, v in field.items():
self.add_entry_field_type(
entry_ids=(entry.id + 1), # JSON IDs start at 0 instead of 1
field_id=self.get_field_name_from_id(k),
value=v,
)
# Preferences
self.set_prefs(LibraryPrefs.EXTENSION_LIST, [x.strip(".") for x in json_lib.ext_list])
self.set_prefs(LibraryPrefs.IS_EXCLUDE_LIST, json_lib.is_exclude_list)
end_time = time.time()
logger.info(f"Library Converted! ({format_timespan(end_time-start_time)})")
def get_field_name_from_id(self, field_id: int) -> _FieldID:
for f in _FieldID:
if field_id == f.value.id:
return f
return None
def open_library(self, library_dir: Path, storage_path: str | None = None) -> LibraryStatus:
is_new: bool = True
if storage_path == ":memory:":
self.storage_path = storage_path
is_new = True
return self.open_sqlite_library(library_dir, is_new)
else:
self.verify_ts_folders(library_dir)
self.storage_path = library_dir / TS_FOLDER_NAME / self.FILENAME
is_new = not self.storage_path.exists()
self.storage_path = library_dir / TS_FOLDER_NAME / self.SQL_FILENAME
if self.verify_ts_folder(library_dir) and (is_new := not self.storage_path.exists()):
json_path = library_dir / TS_FOLDER_NAME / self.JSON_FILENAME
if json_path.exists():
return LibraryStatus(
success=False,
library_path=library_dir,
message="[JSON] Legacy v9.4 library requires conversion to v9.5+",
json_migration_req=True,
)
return self.open_sqlite_library(library_dir, is_new)
def open_sqlite_library(
self, library_dir: Path, is_new: bool, add_default_data: bool = True
) -> LibraryStatus:
connection_string = URL.create(
drivername="sqlite",
database=str(self.storage_path),
)
# NOTE: File-based databases should use NullPool to create new DB connection in order to
# keep connections on separate threads, which prevents the DB files from being locked
# even after a connection has been closed.
# SingletonThreadPool (the default for :memory:) should still be used for in-memory DBs.
# More info can be found on the SQLAlchemy docs:
# https://docs.sqlalchemy.org/en/20/changelog/migration_07.html
# Under -> sqlite-the-sqlite-dialect-now-uses-nullpool-for-file-based-databases
poolclass = None if self.storage_path == ":memory:" else NullPool
logger.info("opening library", library_dir=library_dir, connection_string=connection_string)
self.engine = create_engine(connection_string)
logger.info(
"Opening SQLite Library", library_dir=library_dir, connection_string=connection_string
)
self.engine = create_engine(connection_string, poolclass=poolclass)
with Session(self.engine) as session:
make_tables(self.engine)
tags = get_default_tags()
try:
session.add_all(tags)
session.commit()
except IntegrityError:
# default tags may exist already
session.rollback()
if add_default_data:
tags = get_default_tags()
try:
session.add_all(tags)
session.commit()
except IntegrityError:
# default tags may exist already
session.rollback()
# dont check db version when creating new library
if not is_new:
@@ -219,7 +312,6 @@ class Library:
db_version=db_version.value,
expected=LibraryPrefs.DB_VERSION.default,
)
# TODO - handle migration
return LibraryStatus(
success=False,
message=(
@@ -354,8 +446,12 @@ class Library:
return list(tags_list)
def verify_ts_folders(self, library_dir: Path) -> None:
"""Verify/create folders required by TagStudio."""
def verify_ts_folder(self, library_dir: Path) -> bool:
"""Verify/create folders required by TagStudio.
Returns:
bool: True if path exists, False if it needed to be created.
"""
if library_dir is None:
raise ValueError("No path set.")
@@ -366,6 +462,8 @@ class Library:
if not full_ts_path.exists():
logger.info("creating library directory", dir=full_ts_path)
full_ts_path.mkdir(parents=True, exist_ok=True)
return False
return True
def add_entries(self, items: list[Entry]) -> list[int]:
"""Add multiple Entry records to the Library."""
@@ -507,21 +605,23 @@ class Library:
def search_tags(
self,
search: FilterState,
name: str,
) -> list[Tag]:
"""Return a list of Tag records matching the query."""
tag_limit = 100
with Session(self.engine) as session:
query = select(Tag)
query = query.options(
selectinload(Tag.subtags),
selectinload(Tag.aliases),
)
).limit(tag_limit)
if search.tag:
if name:
query = query.where(
or_(
Tag.name.icontains(search.tag),
Tag.shorthand.icontains(search.tag),
Tag.name.icontains(name),
Tag.shorthand.icontains(name),
)
)
@@ -531,7 +631,7 @@ class Library:
logger.info(
"searching tags",
search=search,
search=name,
statement=str(query),
results=len(res),
)
@@ -694,7 +794,7 @@ class Library:
*,
field: ValueType | None = None,
field_id: _FieldID | str | None = None,
value: str | datetime | list[str] | None = None,
value: str | datetime | list[int] | None = None,
) -> bool:
logger.info(
"add_field_to_entry",
@@ -727,8 +827,11 @@ class Library:
if value:
assert isinstance(value, list)
for tag in value:
field_model.tags.add(Tag(name=tag))
with Session(self.engine) as session:
for tag_id in list(set(value)):
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
field_model.tags.add(tag)
session.flush()
elif field.type == FieldTypeEnum.DATETIME:
field_model = DatetimeField(
@@ -760,6 +863,28 @@ class Library:
)
return True
def tag_from_strings(self, strings: list[str] | str) -> list[int]:
"""Create a Tag from a given string."""
# TODO: Port over tag searching with aliases fallbacks
# and context clue ranking for string searches.
tags: list[int] = []
if isinstance(strings, str):
strings = [strings]
with Session(self.engine) as session:
for string in strings:
tag = session.scalar(select(Tag).where(Tag.name == string))
if tag:
tags.append(tag.id)
else:
new = session.add(Tag(name=string))
if new:
tags.append(new.id)
session.flush()
session.commit()
return tags
def add_tag(
self,
tag: Tag,
@@ -852,7 +977,7 @@ class Library:
target_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME / filename
shutil.copy2(
self.library_dir / TS_FOLDER_NAME / self.FILENAME,
self.library_dir / TS_FOLDER_NAME / self.SQL_FILENAME,
target_path,
)
@@ -879,15 +1004,15 @@ class Library:
return alias
def add_subtag(self, base_id: int, new_tag_id: int) -> bool:
if base_id == new_tag_id:
def add_subtag(self, parent_id: int, child_id: int) -> bool:
if parent_id == child_id:
return False
# open session and save as parent tag
with Session(self.engine) as session:
subtag = TagSubtag(
parent_id=base_id,
child_id=new_tag_id,
parent_id=parent_id,
child_id=child_id,
)
try:
@@ -899,6 +1024,22 @@ class Library:
logger.exception("IntegrityError")
return False
def add_alias(self, name: str, tag_id: int) -> bool:
with Session(self.engine) as session:
alias = TagAlias(
name=name,
tag_id=tag_id,
)
try:
session.add(alias)
session.commit()
return True
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
return False
def remove_subtag(self, base_id: int, remove_tag_id: int) -> bool:
with Session(self.engine) as session:
p_id = base_id

View File

@@ -43,7 +43,7 @@ class Tag(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(unique=True)
name: Mapped[str]
shorthand: Mapped[str | None]
color: Mapped[TagColor]
icon: Mapped[str | None]
@@ -78,14 +78,14 @@ class Tag(Base):
def __init__(
self,
name: str,
id: int | None = None,
name: str | None = None,
shorthand: str | None = None,
aliases: set[TagAlias] | None = None,
parent_tags: set["Tag"] | None = None,
subtags: set["Tag"] | None = None,
icon: str | None = None,
color: TagColor = TagColor.DEFAULT,
id: int | None = None,
):
self.name = name
self.aliases = aliases or set()
@@ -177,10 +177,11 @@ class Entry(Base):
path: Path,
folder: Folder,
fields: list[BaseField],
id: int | None = None,
) -> None:
self.path = path
self.folder = folder
self.id = id
self.suffix = path.suffix.lstrip(".").lower()
for field in fields:

View File

@@ -414,17 +414,17 @@ class Library:
"""Verifies/creates folders required by TagStudio."""
full_ts_path = self.library_dir / TS_FOLDER_NAME
full_backup_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME
full_collage_path = self.library_dir / TS_FOLDER_NAME / COLLAGE_FOLDER_NAME
# full_backup_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME
# full_collage_path = self.library_dir / TS_FOLDER_NAME / COLLAGE_FOLDER_NAME
if not os.path.isdir(full_ts_path):
os.mkdir(full_ts_path)
if not os.path.isdir(full_backup_path):
os.mkdir(full_backup_path)
# if not os.path.isdir(full_backup_path):
# os.mkdir(full_backup_path)
if not os.path.isdir(full_collage_path):
os.mkdir(full_collage_path)
# if not os.path.isdir(full_collage_path):
# os.mkdir(full_collage_path)
def verify_default_tags(self, tag_list: list) -> list:
"""
@@ -449,7 +449,7 @@ class Library:
return_code = OpenStatus.CORRUPTED
_path: Path = self._fix_lib_path(path)
logger.info("opening library", path=_path)
logger.info("Opening JSON Library", path=_path)
if (_path / TS_FOLDER_NAME / "ts_library.json").exists():
try:
with open(
@@ -554,7 +554,7 @@ class Library:
self._next_entry_id += 1
filename = entry.get("filename", "")
e_path = entry.get("path", "")
e_path = entry.get("path", "").replace("\\", "/")
fields: list = []
if "fields" in entry:
# Cast JSON str keys to ints

View File

@@ -2,7 +2,9 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import structlog
from PySide6.QtCore import QSize, Qt, Signal
from PySide6.QtGui import QShowEvent
from PySide6.QtWidgets import (
QFrame,
QHBoxLayout,
@@ -12,11 +14,16 @@ from PySide6.QtWidgets import (
QWidget,
)
from src.core.library import Library, Tag
from src.core.library.alchemy.enums import FilterState
from src.qt.modals.build_tag import BuildTagPanel
from src.qt.widgets.panel import PanelModal, PanelWidget
from src.qt.widgets.tag import TagWidget
logger = structlog.get_logger(__name__)
# TODO: This class shares the majority of its code with tag_search.py.
# It should either be made DRY, or be replaced with the intended and more robust
# Tag Management tab/pane outlined on the Feature Roadmap.
class TagDatabasePanel(PanelWidget):
tag_chosen = Signal(int)
@@ -24,8 +31,8 @@ class TagDatabasePanel(PanelWidget):
def __init__(self, library: Library):
super().__init__()
self.lib: Library = library
self.is_initialized: bool = False
self.first_tag_id = -1
self.tag_limit = 30
self.setMinimumSize(300, 400)
self.root_layout = QVBoxLayout(self)
@@ -54,7 +61,6 @@ class TagDatabasePanel(PanelWidget):
self.root_layout.addWidget(self.search_field)
self.root_layout.addWidget(self.scroll_area)
self.update_tags()
def on_return(self, text: str):
if text and self.first_tag_id >= 0:
@@ -67,12 +73,13 @@ class TagDatabasePanel(PanelWidget):
def update_tags(self, query: str | None = None):
# TODO: Look at recycling rather than deleting and re-initializing
logger.info("[Tag Manager Modal] Updating Tags")
while self.scroll_layout.itemAt(0):
self.scroll_layout.takeAt(0).widget().deleteLater()
tags = self.lib.search_tags(FilterState(path=query, page_size=self.tag_limit))
tags_results = self.lib.search_tags(name=query)
for tag in tags:
for tag in tags_results:
container = QWidget()
row = QHBoxLayout(container)
row.setContentsMargins(0, 0, 0, 0)
@@ -101,3 +108,9 @@ class TagDatabasePanel(PanelWidget):
def edit_tag_callback(self, btp: BuildTagPanel):
self.lib.update_tag(btp.build_tag(), btp.subtag_ids, btp.alias_names, btp.alias_ids)
self.update_tags(self.search_field.text())
def showEvent(self, event: QShowEvent) -> None: # noqa N802
if not self.is_initialized:
self.update_tags()
self.is_initialized = True
return super().showEvent(event)

View File

@@ -7,6 +7,7 @@ import math
import structlog
from PySide6.QtCore import QSize, Qt, Signal
from PySide6.QtGui import QShowEvent
from PySide6.QtWidgets import (
QFrame,
QHBoxLayout,
@@ -17,7 +18,6 @@ from PySide6.QtWidgets import (
QWidget,
)
from src.core.library import Library
from src.core.library.alchemy.enums import FilterState
from src.core.palette import ColorType, get_tag_color
from src.qt.widgets.panel import PanelWidget
from src.qt.widgets.tag import TagWidget
@@ -32,8 +32,8 @@ class TagSearchPanel(PanelWidget):
super().__init__()
self.lib = library
self.exclude = exclude
self.is_initialized: bool = False
self.first_tag_id = None
self.tag_limit = 100
self.setMinimumSize(300, 400)
self.root_layout = QVBoxLayout(self)
self.root_layout.setContentsMargins(6, 0, 6, 0)
@@ -61,11 +61,9 @@ class TagSearchPanel(PanelWidget):
self.root_layout.addWidget(self.search_field)
self.root_layout.addWidget(self.scroll_area)
self.update_tags()
def on_return(self, text: str):
if text and self.first_tag_id is not None:
# callback(self.first_tag_id)
self.tag_chosen.emit(self.first_tag_id)
self.search_field.setText("")
self.update_tags()
@@ -73,20 +71,17 @@ class TagSearchPanel(PanelWidget):
self.search_field.setFocus()
self.parentWidget().hide()
def update_tags(self, name: str | None = None):
def update_tags(self, query: str | None = None):
logger.info("[Tag Search Modal] Updating Tags")
while self.scroll_layout.count():
self.scroll_layout.takeAt(0).widget().deleteLater()
found_tags = self.lib.search_tags(
FilterState(
path=name,
page_size=self.tag_limit,
)
)
tag_results = self.lib.search_tags(name=query)
for tag in found_tags:
for tag in tag_results:
if self.exclude is not None and tag.id in self.exclude:
continue
c = QWidget()
layout = QHBoxLayout(c)
layout.setContentsMargins(0, 0, 0, 0)
@@ -123,3 +118,9 @@ class TagSearchPanel(PanelWidget):
self.scroll_layout.addWidget(c)
self.search_field.setFocus()
def showEvent(self, event: QShowEvent) -> None: # noqa N802
if not self.is_initialized:
self.update_tags()
self.is_initialized = True
return super().showEvent(event)

View File

@@ -89,6 +89,7 @@ from src.qt.modals.folders_to_tags import FoldersToTagsModal
from src.qt.modals.tag_database import TagDatabasePanel
from src.qt.resource_manager import ResourceManager
from src.qt.widgets.item_thumb import BadgeType, ItemThumb
from src.qt.widgets.migration_modal import JsonMigrationModal
from src.qt.widgets.panel import PanelModal
from src.qt.widgets.preview_panel import PreviewPanel
from src.qt.widgets.progress import ProgressWidget
@@ -468,6 +469,7 @@ class QtDriver(DriverMixin, QObject):
self.thumb_renderers: list[ThumbRenderer] = []
self.filter = FilterState()
self.init_library_window()
self.migration_modal: JsonMigrationModal = None
path_result = self.evaluate_path(self.args.open)
# check status of library path evaluating
@@ -807,6 +809,8 @@ class QtDriver(DriverMixin, QObject):
elif name == MacroID.SIDECAR:
parsed_items = TagStudioCore.get_gdl_sidecar(ful_path, source)
for field_id, value in parsed_items.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], str):
value = self.lib.tag_from_strings(value)
self.lib.add_entry_field_type(
entry.id,
field_id=field_id,
@@ -1187,14 +1191,27 @@ class QtDriver(DriverMixin, QObject):
self.settings.endGroup()
self.settings.sync()
def open_library(self, path: Path) -> LibraryStatus:
def open_library(self, path: Path) -> None:
"""Open a TagStudio library."""
open_message: str = f'Opening Library "{str(path)}"...'
self.main_window.landing_widget.set_status_label(open_message)
self.main_window.statusbar.showMessage(open_message, 3)
self.main_window.repaint()
open_status = self.lib.open_library(path)
open_status: LibraryStatus = self.lib.open_library(path)
# Migration is required
if open_status.json_migration_req:
self.migration_modal = JsonMigrationModal(path)
self.migration_modal.migration_finished.connect(
lambda: self.init_library(path, self.lib.open_library(path))
)
self.main_window.landing_widget.set_status_label("")
self.migration_modal.paged_panel.show()
else:
self.init_library(path, open_status)
def init_library(self, path: Path, open_status: LibraryStatus):
if not open_status.success:
self.show_error_message(open_status.message or "Error opening library.")
return open_status

View File

@@ -0,0 +1,784 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from pathlib import Path
import structlog
from PySide6.QtCore import QObject, Qt, QThreadPool, Signal
from PySide6.QtWidgets import (
QApplication,
QGridLayout,
QHBoxLayout,
QLabel,
QMessageBox,
QProgressDialog,
QSizePolicy,
QVBoxLayout,
QWidget,
)
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from src.core.constants import TS_FOLDER_NAME
from src.core.enums import LibraryPrefs
from src.core.library.alchemy.enums import FieldTypeEnum, TagColor
from src.core.library.alchemy.fields import TagBoxField, _FieldID
from src.core.library.alchemy.joins import TagField, TagSubtag
from src.core.library.alchemy.library import Library as SqliteLibrary
from src.core.library.alchemy.models import Entry, Tag, TagAlias
from src.core.library.json.library import Library as JsonLibrary # type: ignore
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.helpers.qbutton_wrapper import QPushButtonWrapper
from src.qt.widgets.paged_panel.paged_body_wrapper import PagedBodyWrapper
from src.qt.widgets.paged_panel.paged_panel import PagedPanel
from src.qt.widgets.paged_panel.paged_panel_state import PagedPanelState
logger = structlog.get_logger(__name__)
class JsonMigrationModal(QObject):
"""A modal for data migration from v9.4 JSON to v9.5+ SQLite."""
migration_cancelled = Signal()
migration_finished = Signal()
def __init__(self, path: Path):
super().__init__()
self.done: bool = False
self.path: Path = path
self.stack: list[PagedPanelState] = []
self.json_lib: JsonLibrary = None
self.sql_lib: SqliteLibrary = None
self.is_migration_initialized: bool = False
self.discrepancies: list[str] = []
self.title: str = f'Save Format Migration: "{self.path}"'
self.warning: str = "<b><a style='color: #e22c3c'>(!)</a></b>"
self.old_entry_count: int = 0
self.old_tag_count: int = 0
self.old_ext_count: int = 0
self.old_ext_type: bool = None
self.field_parity: bool = False
self.path_parity: bool = False
self.shorthand_parity: bool = False
self.subtag_parity: bool = False
self.alias_parity: bool = False
self.color_parity: bool = False
self.init_page_info()
self.init_page_convert()
self.paged_panel: PagedPanel = PagedPanel((700, 640), self.stack)
def init_page_info(self) -> None:
"""Initialize the migration info page."""
body_wrapper: PagedBodyWrapper = PagedBodyWrapper()
body_label: QLabel = QLabel(
"Library save files created with TagStudio versions <b>9.4 and below</b> will "
"need to be migrated to the new <b>v9.5+</b> format."
"<br>"
"<h2>What you need to know:</h2>"
"<ul>"
"<li>Your existing library save file will <b><i>NOT</i></b> be deleted</li>"
"<li>Your personal files will <b><i>NOT</i></b> be deleted, moved, or modified</li>"
"<li>The new v9.5+ save format can not be opened in earlier versions of TagStudio</li>"
"</ul>"
)
body_label.setWordWrap(True)
body_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
body_wrapper.layout().addWidget(body_label)
body_wrapper.layout().setContentsMargins(0, 36, 0, 0)
cancel_button: QPushButtonWrapper = QPushButtonWrapper("Cancel")
next_button: QPushButtonWrapper = QPushButtonWrapper("Continue")
cancel_button.clicked.connect(self.migration_cancelled.emit)
self.stack.append(
PagedPanelState(
title=self.title,
body_wrapper=body_wrapper,
buttons=[cancel_button, 1, next_button],
connect_to_back=[cancel_button],
connect_to_next=[next_button],
)
)
def init_page_convert(self) -> None:
"""Initialize the migration conversion page."""
self.body_wrapper_01: PagedBodyWrapper = PagedBodyWrapper()
body_container: QWidget = QWidget()
body_container_layout: QHBoxLayout = QHBoxLayout(body_container)
body_container_layout.setContentsMargins(0, 0, 0, 0)
tab: str = " "
self.match_text: str = "Matched"
self.differ_text: str = "Discrepancy"
entries_text: str = "Entries:"
tags_text: str = "Tags:"
shorthand_text: str = tab + "Shorthands:"
subtags_text: str = tab + "Parent Tags:"
aliases_text: str = tab + "Aliases:"
colors_text: str = tab + "Colors:"
ext_text: str = "File Extension List:"
ext_type_text: str = "Extension List Type:"
desc_text: str = (
"<br>Start and preview the results of the library migration process. "
'The converted library will <i>not</i> be used unless you click "Finish Migration". '
"<br><br>"
'Library data should either have matching values or a feature a "Matched" label. '
'Values that do not match will be displayed in red and feature a "<b>(!)</b>" '
"symbol next to them."
"<br><center><i>"
"This process may take up to several minutes for larger libraries."
"</i></center>"
)
path_parity_text: str = tab + "Paths:"
field_parity_text: str = tab + "Fields:"
self.entries_row: int = 0
self.path_row: int = 1
self.fields_row: int = 2
self.tags_row: int = 3
self.shorthands_row: int = 4
self.subtags_row: int = 5
self.aliases_row: int = 6
self.colors_row: int = 7
self.ext_row: int = 8
self.ext_type_row: int = 9
old_lib_container: QWidget = QWidget()
old_lib_layout: QVBoxLayout = QVBoxLayout(old_lib_container)
old_lib_title: QLabel = QLabel("<h2>v9.4 Library</h2>")
old_lib_title.setAlignment(Qt.AlignmentFlag.AlignCenter)
old_lib_layout.addWidget(old_lib_title)
old_content_container: QWidget = QWidget()
self.old_content_layout: QGridLayout = QGridLayout(old_content_container)
self.old_content_layout.setContentsMargins(0, 0, 0, 0)
self.old_content_layout.setSpacing(3)
self.old_content_layout.addWidget(QLabel(entries_text), self.entries_row, 0)
self.old_content_layout.addWidget(QLabel(path_parity_text), self.path_row, 0)
self.old_content_layout.addWidget(QLabel(field_parity_text), self.fields_row, 0)
self.old_content_layout.addWidget(QLabel(tags_text), self.tags_row, 0)
self.old_content_layout.addWidget(QLabel(shorthand_text), self.shorthands_row, 0)
self.old_content_layout.addWidget(QLabel(subtags_text), self.subtags_row, 0)
self.old_content_layout.addWidget(QLabel(aliases_text), self.aliases_row, 0)
self.old_content_layout.addWidget(QLabel(colors_text), self.colors_row, 0)
self.old_content_layout.addWidget(QLabel(ext_text), self.ext_row, 0)
self.old_content_layout.addWidget(QLabel(ext_type_text), self.ext_type_row, 0)
old_entry_count: QLabel = QLabel()
old_entry_count.setAlignment(Qt.AlignmentFlag.AlignRight)
old_path_value: QLabel = QLabel()
old_path_value.setAlignment(Qt.AlignmentFlag.AlignRight)
old_field_value: QLabel = QLabel()
old_field_value.setAlignment(Qt.AlignmentFlag.AlignRight)
old_tag_count: QLabel = QLabel()
old_tag_count.setAlignment(Qt.AlignmentFlag.AlignRight)
old_shorthand_count: QLabel = QLabel()
old_shorthand_count.setAlignment(Qt.AlignmentFlag.AlignRight)
old_subtag_value: QLabel = QLabel()
old_subtag_value.setAlignment(Qt.AlignmentFlag.AlignRight)
old_alias_value: QLabel = QLabel()
old_alias_value.setAlignment(Qt.AlignmentFlag.AlignRight)
old_color_value: QLabel = QLabel()
old_color_value.setAlignment(Qt.AlignmentFlag.AlignRight)
old_ext_count: QLabel = QLabel()
old_ext_count.setAlignment(Qt.AlignmentFlag.AlignRight)
old_ext_type: QLabel = QLabel()
old_ext_type.setAlignment(Qt.AlignmentFlag.AlignRight)
self.old_content_layout.addWidget(old_entry_count, self.entries_row, 1)
self.old_content_layout.addWidget(old_path_value, self.path_row, 1)
self.old_content_layout.addWidget(old_field_value, self.fields_row, 1)
self.old_content_layout.addWidget(old_tag_count, self.tags_row, 1)
self.old_content_layout.addWidget(old_shorthand_count, self.shorthands_row, 1)
self.old_content_layout.addWidget(old_subtag_value, self.subtags_row, 1)
self.old_content_layout.addWidget(old_alias_value, self.aliases_row, 1)
self.old_content_layout.addWidget(old_color_value, self.colors_row, 1)
self.old_content_layout.addWidget(old_ext_count, self.ext_row, 1)
self.old_content_layout.addWidget(old_ext_type, self.ext_type_row, 1)
self.old_content_layout.addWidget(QLabel(), self.path_row, 2)
self.old_content_layout.addWidget(QLabel(), self.fields_row, 2)
self.old_content_layout.addWidget(QLabel(), self.shorthands_row, 2)
self.old_content_layout.addWidget(QLabel(), self.subtags_row, 2)
self.old_content_layout.addWidget(QLabel(), self.aliases_row, 2)
self.old_content_layout.addWidget(QLabel(), self.colors_row, 2)
old_lib_layout.addWidget(old_content_container)
new_lib_container: QWidget = QWidget()
new_lib_layout: QVBoxLayout = QVBoxLayout(new_lib_container)
new_lib_title: QLabel = QLabel("<h2>v9.5+ Library</h2>")
new_lib_title.setAlignment(Qt.AlignmentFlag.AlignCenter)
new_lib_layout.addWidget(new_lib_title)
new_content_container: QWidget = QWidget()
self.new_content_layout: QGridLayout = QGridLayout(new_content_container)
self.new_content_layout.setContentsMargins(0, 0, 0, 0)
self.new_content_layout.setSpacing(3)
self.new_content_layout.addWidget(QLabel(entries_text), self.entries_row, 0)
self.new_content_layout.addWidget(QLabel(path_parity_text), self.path_row, 0)
self.new_content_layout.addWidget(QLabel(field_parity_text), self.fields_row, 0)
self.new_content_layout.addWidget(QLabel(tags_text), self.tags_row, 0)
self.new_content_layout.addWidget(QLabel(shorthand_text), self.shorthands_row, 0)
self.new_content_layout.addWidget(QLabel(subtags_text), self.subtags_row, 0)
self.new_content_layout.addWidget(QLabel(aliases_text), self.aliases_row, 0)
self.new_content_layout.addWidget(QLabel(colors_text), self.colors_row, 0)
self.new_content_layout.addWidget(QLabel(ext_text), self.ext_row, 0)
self.new_content_layout.addWidget(QLabel(ext_type_text), self.ext_type_row, 0)
new_entry_count: QLabel = QLabel()
new_entry_count.setAlignment(Qt.AlignmentFlag.AlignRight)
path_parity_value: QLabel = QLabel()
path_parity_value.setAlignment(Qt.AlignmentFlag.AlignRight)
field_parity_value: QLabel = QLabel()
field_parity_value.setAlignment(Qt.AlignmentFlag.AlignRight)
new_tag_count: QLabel = QLabel()
new_tag_count.setAlignment(Qt.AlignmentFlag.AlignRight)
new_shorthand_count: QLabel = QLabel()
new_shorthand_count.setAlignment(Qt.AlignmentFlag.AlignRight)
subtag_parity_value: QLabel = QLabel()
subtag_parity_value.setAlignment(Qt.AlignmentFlag.AlignRight)
alias_parity_value: QLabel = QLabel()
alias_parity_value.setAlignment(Qt.AlignmentFlag.AlignRight)
new_color_value: QLabel = QLabel()
new_color_value.setAlignment(Qt.AlignmentFlag.AlignRight)
new_ext_count: QLabel = QLabel()
new_ext_count.setAlignment(Qt.AlignmentFlag.AlignRight)
new_ext_type: QLabel = QLabel()
new_ext_type.setAlignment(Qt.AlignmentFlag.AlignRight)
self.new_content_layout.addWidget(new_entry_count, self.entries_row, 1)
self.new_content_layout.addWidget(path_parity_value, self.path_row, 1)
self.new_content_layout.addWidget(field_parity_value, self.fields_row, 1)
self.new_content_layout.addWidget(new_tag_count, self.tags_row, 1)
self.new_content_layout.addWidget(new_shorthand_count, self.shorthands_row, 1)
self.new_content_layout.addWidget(subtag_parity_value, self.subtags_row, 1)
self.new_content_layout.addWidget(alias_parity_value, self.aliases_row, 1)
self.new_content_layout.addWidget(new_color_value, self.colors_row, 1)
self.new_content_layout.addWidget(new_ext_count, self.ext_row, 1)
self.new_content_layout.addWidget(new_ext_type, self.ext_type_row, 1)
self.new_content_layout.addWidget(QLabel(), self.entries_row, 2)
self.new_content_layout.addWidget(QLabel(), self.path_row, 2)
self.new_content_layout.addWidget(QLabel(), self.fields_row, 2)
self.new_content_layout.addWidget(QLabel(), self.shorthands_row, 2)
self.new_content_layout.addWidget(QLabel(), self.tags_row, 2)
self.new_content_layout.addWidget(QLabel(), self.subtags_row, 2)
self.new_content_layout.addWidget(QLabel(), self.aliases_row, 2)
self.new_content_layout.addWidget(QLabel(), self.colors_row, 2)
self.new_content_layout.addWidget(QLabel(), self.ext_row, 2)
self.new_content_layout.addWidget(QLabel(), self.ext_type_row, 2)
new_lib_layout.addWidget(new_content_container)
desc_label = QLabel(desc_text)
desc_label.setWordWrap(True)
body_container_layout.addStretch(2)
body_container_layout.addWidget(old_lib_container)
body_container_layout.addStretch(1)
body_container_layout.addWidget(new_lib_container)
body_container_layout.addStretch(2)
self.body_wrapper_01.layout().addWidget(body_container)
self.body_wrapper_01.layout().addWidget(desc_label)
self.body_wrapper_01.layout().setSpacing(12)
back_button: QPushButtonWrapper = QPushButtonWrapper("Back")
start_button: QPushButtonWrapper = QPushButtonWrapper("Start and Preview")
start_button.setMinimumWidth(120)
start_button.clicked.connect(self.migrate)
start_button.clicked.connect(lambda: finish_button.setDisabled(False))
start_button.clicked.connect(lambda: start_button.setDisabled(True))
finish_button: QPushButtonWrapper = QPushButtonWrapper("Finish Migration")
finish_button.setMinimumWidth(120)
finish_button.setDisabled(True)
finish_button.clicked.connect(self.finish_migration)
finish_button.clicked.connect(self.migration_finished.emit)
self.stack.append(
PagedPanelState(
title=self.title,
body_wrapper=self.body_wrapper_01,
buttons=[back_button, 1, start_button, 1, finish_button],
connect_to_back=[back_button],
connect_to_next=[finish_button],
)
)
def migrate(self, skip_ui: bool = False):
"""Open and migrate the JSON library to SQLite."""
if not self.is_migration_initialized:
self.paged_panel.update_frame()
self.paged_panel.update()
# Open the JSON Library
self.json_lib = JsonLibrary()
self.json_lib.open_library(self.path)
# Update JSON UI
self.update_json_entry_count(len(self.json_lib.entries))
self.update_json_tag_count(len(self.json_lib.tags))
self.update_json_ext_count(len(self.json_lib.ext_list))
self.update_json_ext_type(self.json_lib.is_exclude_list)
self.migration_progress(skip_ui=skip_ui)
self.is_migration_initialized = True
def migration_progress(self, skip_ui: bool = False):
"""Initialize the progress bar and iterator for the library migration."""
pb = QProgressDialog(
labelText="",
cancelButtonText="",
minimum=0,
maximum=0,
)
pb.setCancelButton(None)
self.body_wrapper_01.layout().addWidget(pb)
iterator = FunctionIterator(self.migration_iterator)
iterator.value.connect(
lambda x: (
pb.setLabelText(f"<h4>{x}</h4>"),
self.update_sql_value_ui(show_msg_box=False)
if x == "Checking for Parity..."
else (),
self.update_parity_ui() if x == "Checking for Parity..." else (),
)
)
r = CustomRunnable(iterator.run)
r.done.connect(
lambda: (
self.update_sql_value_ui(show_msg_box=not skip_ui),
pb.setMinimum(1),
pb.setValue(1),
)
)
QThreadPool.globalInstance().start(r)
def migration_iterator(self):
"""Iterate over the library migration process."""
try:
# Convert JSON Library to SQLite
yield "Creating SQL Database Tables..."
self.sql_lib = SqliteLibrary()
self.temp_path: Path = (
self.json_lib.library_dir / TS_FOLDER_NAME / "migration_ts_library.sqlite"
)
self.sql_lib.storage_path = self.temp_path
if self.temp_path.exists():
logger.info('Temporary migration file "temp_path" already exists. Removing...')
self.temp_path.unlink()
self.sql_lib.open_sqlite_library(
self.json_lib.library_dir, is_new=True, add_default_data=False
)
yield f"Migrating {len(self.json_lib.entries):,d} File Entries..."
self.sql_lib.migrate_json_to_sqlite(self.json_lib)
yield "Checking for Parity..."
check_set = set()
check_set.add(self.check_field_parity())
check_set.add(self.check_path_parity())
check_set.add(self.check_shorthand_parity())
check_set.add(self.check_subtag_parity())
check_set.add(self.check_alias_parity())
check_set.add(self.check_color_parity())
self.update_parity_ui()
if False not in check_set:
yield "Migration Complete!"
else:
yield "Migration Complete, Discrepancies Found"
self.done = True
except Exception as e:
yield f"Error: {type(e).__name__}"
self.done = True
def update_parity_ui(self):
"""Update all parity values UI."""
self.update_parity_value(self.fields_row, self.field_parity)
self.update_parity_value(self.path_row, self.path_parity)
self.update_parity_value(self.shorthands_row, self.shorthand_parity)
self.update_parity_value(self.subtags_row, self.subtag_parity)
self.update_parity_value(self.aliases_row, self.alias_parity)
self.update_parity_value(self.colors_row, self.color_parity)
self.sql_lib.close()
def update_sql_value_ui(self, show_msg_box: bool = True):
"""Update the SQL value count UI."""
self.update_sql_value(
self.entries_row,
self.sql_lib.entries_count,
self.old_entry_count,
)
self.update_sql_value(
self.tags_row,
len(self.sql_lib.tags),
self.old_tag_count,
)
self.update_sql_value(
self.ext_row,
len(self.sql_lib.prefs(LibraryPrefs.EXTENSION_LIST)),
self.old_ext_count,
)
self.update_sql_value(
self.ext_type_row,
self.sql_lib.prefs(LibraryPrefs.IS_EXCLUDE_LIST),
self.old_ext_type,
)
logger.info("Parity check complete!")
if self.discrepancies:
logger.warning("Discrepancies found:")
logger.warning("\n".join(self.discrepancies))
QApplication.beep()
if not show_msg_box:
return
msg_box = QMessageBox()
msg_box.setWindowTitle("Library Discrepancies Found")
msg_box.setText(
"Discrepancies were found between the original and converted library formats. "
"Please review and choose to whether continue with the migration or to cancel."
)
msg_box.setDetailedText("\n".join(self.discrepancies))
msg_box.setIcon(QMessageBox.Icon.Warning)
msg_box.exec()
def finish_migration(self):
"""Finish the migration upon user approval."""
final_name = self.json_lib.library_dir / TS_FOLDER_NAME / SqliteLibrary.SQL_FILENAME
if self.temp_path.exists():
self.temp_path.rename(final_name)
def update_json_entry_count(self, value: int):
self.old_entry_count = value
label: QLabel = self.old_content_layout.itemAtPosition(self.entries_row, 1).widget() # type:ignore
label.setText(self.color_value_default(value))
def update_json_tag_count(self, value: int):
self.old_tag_count = value
label: QLabel = self.old_content_layout.itemAtPosition(self.tags_row, 1).widget() # type:ignore
label.setText(self.color_value_default(value))
def update_json_ext_count(self, value: int):
self.old_ext_count = value
label: QLabel = self.old_content_layout.itemAtPosition(self.ext_row, 1).widget() # type:ignore
label.setText(self.color_value_default(value))
def update_json_ext_type(self, value: bool):
self.old_ext_type = value
label: QLabel = self.old_content_layout.itemAtPosition(self.ext_type_row, 1).widget() # type:ignore
label.setText(self.color_value_default(value))
def update_sql_value(self, row: int, value: int | bool, old_value: int | bool):
label: QLabel = self.new_content_layout.itemAtPosition(row, 1).widget() # type:ignore
warning_icon: QLabel = self.new_content_layout.itemAtPosition(row, 2).widget() # type:ignore
label.setText(self.color_value_conditional(old_value, value))
warning_icon.setText("" if old_value == value else self.warning)
def update_parity_value(self, row: int, value: bool):
result: str = self.match_text if value else self.differ_text
old_label: QLabel = self.old_content_layout.itemAtPosition(row, 1).widget() # type:ignore
new_label: QLabel = self.new_content_layout.itemAtPosition(row, 1).widget() # type:ignore
old_warning_icon: QLabel = self.old_content_layout.itemAtPosition(row, 2).widget() # type:ignore
new_warning_icon: QLabel = self.new_content_layout.itemAtPosition(row, 2).widget() # type:ignore
old_label.setText(self.color_value_conditional(self.match_text, result))
new_label.setText(self.color_value_conditional(self.match_text, result))
old_warning_icon.setText("" if value else self.warning)
new_warning_icon.setText("" if value else self.warning)
def color_value_default(self, value: int) -> str:
"""Apply the default color to a value."""
return str(f"<b><a style='color: #3b87f0'>{value}</a></b>")
def color_value_conditional(self, old_value: int | str, new_value: int | str) -> str:
"""Apply a conditional color to a value."""
red: str = "#e22c3c"
green: str = "#28bb48"
color = green if old_value == new_value else red
return str(f"<b><a style='color: {color}'>{new_value}</a></b>")
def check_field_parity(self) -> bool:
"""Check if all JSON field data matches the new SQL field data."""
def sanitize_field(session, entry: Entry, value, type, type_key):
if type is FieldTypeEnum.TAGS:
tags = list(
session.scalars(
select(Tag.id)
.join(TagField)
.join(TagBoxField)
.where(
and_(
TagBoxField.entry_id == entry.id,
TagBoxField.id == TagField.field_id,
TagBoxField.type_key == type_key,
)
)
)
)
return set(tags) if tags else None
else:
return value if value else None
def sanitize_json_field(value):
if isinstance(value, list):
return set(value) if value else None
else:
return value if value else None
with Session(self.sql_lib.engine) as session:
for json_entry in self.json_lib.entries:
sql_fields: list[tuple] = []
json_fields: list[tuple] = []
sql_entry: Entry = session.scalar(
select(Entry).where(Entry.id == json_entry.id + 1)
)
if not sql_entry:
logger.info(
"[Field Comparison]",
message=f"NEW (SQL): SQL Entry ID mismatch: {json_entry.id+1}",
)
self.discrepancies.append(
f"[Field Comparison]:\nNEW (SQL): SQL Entry ID not found: {json_entry.id+1}"
)
self.field_parity = False
return self.field_parity
for sf in sql_entry.fields:
sql_fields.append(
(
sql_entry.id,
sf.type.key,
sanitize_field(session, sql_entry, sf.value, sf.type.type, sf.type_key),
)
)
sql_fields.sort()
# NOTE: The JSON database allowed for separate tag fields of the same type with
# different values. The SQL database does not, and instead merges these values
# across all instances of that field on an entry.
# TODO: ROADMAP: "Tag Categories" will merge all field tags onto the entry.
# All visual separation from there will be data-driven from the tag itself.
meta_tags_count: int = 0
content_tags_count: int = 0
tags_count: int = 0
merged_meta_tags: set[int] = set()
merged_content_tags: set[int] = set()
merged_tags: set[int] = set()
for jf in json_entry.fields:
key: str = self.sql_lib.get_field_name_from_id(list(jf.keys())[0]).name
value = sanitize_json_field(list(jf.values())[0])
if key == _FieldID.TAGS_META.name:
meta_tags_count += 1
merged_meta_tags = merged_meta_tags.union(value or [])
elif key == _FieldID.TAGS_CONTENT.name:
content_tags_count += 1
merged_content_tags = merged_content_tags.union(value or [])
elif key == _FieldID.TAGS.name:
tags_count += 1
merged_tags = merged_tags.union(value or [])
else:
# JSON IDs start at 0 instead of 1
json_fields.append((json_entry.id + 1, key, value))
if meta_tags_count:
for _ in range(0, meta_tags_count):
json_fields.append(
(
json_entry.id + 1,
_FieldID.TAGS_META.name,
merged_meta_tags if merged_meta_tags else None,
)
)
if content_tags_count:
for _ in range(0, content_tags_count):
json_fields.append(
(
json_entry.id + 1,
_FieldID.TAGS_CONTENT.name,
merged_content_tags if merged_content_tags else None,
)
)
if tags_count:
for _ in range(0, tags_count):
json_fields.append(
(
json_entry.id + 1,
_FieldID.TAGS.name,
merged_tags if merged_tags else None,
)
)
json_fields.sort()
if not (
json_fields is not None
and sql_fields is not None
and (json_fields == sql_fields)
):
self.discrepancies.append(
f"[Field Comparison]:\nOLD (JSON):{json_fields}\nNEW (SQL):{sql_fields}"
)
self.field_parity = False
return self.field_parity
logger.info(
"[Field Comparison]",
fields="\n".join([str(x) for x in zip(json_fields, sql_fields)]),
)
self.field_parity = True
return self.field_parity
def check_path_parity(self) -> bool:
"""Check if all JSON file paths match the new SQL paths."""
with Session(self.sql_lib.engine) as session:
json_paths: list = sorted([x.path / x.filename for x in self.json_lib.entries])
sql_paths: list = sorted(list(session.scalars(select(Entry.path))))
self.path_parity = (
json_paths is not None and sql_paths is not None and (json_paths == sql_paths)
)
return self.path_parity
def check_subtag_parity(self) -> bool:
"""Check if all JSON subtags match the new SQL subtags."""
sql_subtags: set[int] = None
json_subtags: set[int] = None
with Session(self.sql_lib.engine) as session:
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_subtags = set(
session.scalars(select(TagSubtag.child_id).where(TagSubtag.parent_id == tag.id))
)
# JSON tags allowed self-parenting; SQL tags no longer allow this.
json_subtags = set(self.json_lib.get_tag(tag_id).subtag_ids).difference(
set([self.json_lib.get_tag(tag_id).id])
)
logger.info(
"[Subtag Parity]",
tag_id=tag_id,
json_subtags=json_subtags,
sql_subtags=sql_subtags,
)
if not (
sql_subtags is not None
and json_subtags is not None
and (sql_subtags == json_subtags)
):
self.discrepancies.append(
f"[Subtag Parity]:\nOLD (JSON):{json_subtags}\nNEW (SQL):{sql_subtags}"
)
self.subtag_parity = False
return self.subtag_parity
self.subtag_parity = True
return self.subtag_parity
def check_ext_type(self) -> bool:
return self.json_lib.is_exclude_list == self.sql_lib.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
def check_alias_parity(self) -> bool:
"""Check if all JSON aliases match the new SQL aliases."""
sql_aliases: set[str] = None
json_aliases: set[str] = None
with Session(self.sql_lib.engine) as session:
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_aliases = set(
session.scalars(select(TagAlias.name).where(TagAlias.tag_id == tag.id))
)
json_aliases = set(self.json_lib.get_tag(tag_id).aliases)
logger.info(
"[Alias Parity]",
tag_id=tag_id,
json_aliases=json_aliases,
sql_aliases=sql_aliases,
)
if not (
sql_aliases is not None
and json_aliases is not None
and (sql_aliases == json_aliases)
):
self.discrepancies.append(
f"[Alias Parity]:\nOLD (JSON):{json_aliases}\nNEW (SQL):{sql_aliases}"
)
self.alias_parity = False
return self.alias_parity
self.alias_parity = True
return self.alias_parity
def check_shorthand_parity(self) -> bool:
"""Check if all JSON shorthands match the new SQL shorthands."""
sql_shorthand: str = None
json_shorthand: str = None
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_shorthand = tag.shorthand
json_shorthand = self.json_lib.get_tag(tag_id).shorthand
logger.info(
"[Shorthand Parity]",
tag_id=tag_id,
json_shorthand=json_shorthand,
sql_shorthand=sql_shorthand,
)
if not (
sql_shorthand is not None
and json_shorthand is not None
and (sql_shorthand == json_shorthand)
):
self.discrepancies.append(
f"[Shorthand Parity]:\nOLD (JSON):{json_shorthand}\nNEW (SQL):{sql_shorthand}"
)
self.shorthand_parity = False
return self.shorthand_parity
self.shorthand_parity = True
return self.shorthand_parity
def check_color_parity(self) -> bool:
"""Check if all JSON tag colors match the new SQL tag colors."""
sql_color: str = None
json_color: str = None
for tag in self.sql_lib.tags:
tag_id = tag.id # Tag IDs start at 0
sql_color = tag.color.name
json_color = (
TagColor.get_color_from_str(self.json_lib.get_tag(tag_id).color).name
if self.json_lib.get_tag(tag_id).color != ""
else TagColor.DEFAULT.name
)
logger.info(
"[Color Parity]",
tag_id=tag_id,
json_color=json_color,
sql_color=sql_color,
)
if not (sql_color is not None and json_color is not None and (sql_color == json_color)):
self.discrepancies.append(
f"[Color Parity]:\nOLD (JSON):{json_color}\nNEW (SQL):{sql_color}"
)
self.color_parity = False
return self.color_parity
self.color_parity = True
return self.color_parity

View File

@@ -0,0 +1,20 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from PySide6.QtCore import Qt
from PySide6.QtWidgets import (
QVBoxLayout,
QWidget,
)
class PagedBodyWrapper(QWidget):
"""A state object for paged panels."""
def __init__(self):
super().__init__()
layout: QVBoxLayout = QVBoxLayout(self)
layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(0)

View File

@@ -0,0 +1,112 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import structlog
from PySide6.QtCore import Qt
from PySide6.QtWidgets import (
QHBoxLayout,
QLabel,
QVBoxLayout,
QWidget,
)
from src.qt.widgets.paged_panel.paged_panel_state import PagedPanelState
logger = structlog.get_logger(__name__)
class PagedPanel(QWidget):
"""A paginated modal panel."""
def __init__(self, size: tuple[int, int], stack: list[PagedPanelState]):
super().__init__()
self._stack: list[PagedPanelState] = stack
self._index: int = 0
self.setMinimumSize(*size)
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.root_layout = QVBoxLayout(self)
self.root_layout.setObjectName("baseLayout")
self.root_layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.root_layout.setContentsMargins(0, 0, 0, 0)
self.content_container = QWidget()
self.content_layout = QVBoxLayout(self.content_container)
self.content_layout.setContentsMargins(12, 12, 12, 12)
self.title_label = QLabel()
self.title_label.setObjectName("fieldTitle")
self.title_label.setWordWrap(True)
self.title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.body_container = QWidget()
self.body_container.setObjectName("bodyContainer")
self.body_layout = QVBoxLayout(self.body_container)
self.body_layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.body_layout.setContentsMargins(0, 0, 0, 0)
self.body_layout.setSpacing(0)
self.button_nav_container = QWidget()
self.button_nav_layout = QHBoxLayout(self.button_nav_container)
self.root_layout.addWidget(self.content_container)
self.content_layout.addWidget(self.title_label)
self.content_layout.addWidget(self.body_container)
self.content_layout.addStretch(1)
self.root_layout.addWidget(self.button_nav_container)
self.init_connections()
self.update_frame()
def init_connections(self):
"""Initialize button navigation connections."""
for frame in self._stack:
for button in frame.connect_to_back:
button.clicked.connect(self.back)
for button in frame.connect_to_next:
button.clicked.connect(self.next)
def back(self):
"""Navigate backward in the state stack. Close if out of bounds."""
if self._index > 0:
self._index = self._index - 1
self.update_frame()
else:
self.close()
def next(self):
"""Navigate forward in the state stack. Close if out of bounds."""
if self._index < len(self._stack) - 1:
self._index = self._index + 1
self.update_frame()
else:
self.close()
def update_frame(self):
"""Update the widgets with the current frame's content."""
frame: PagedPanelState = self._stack[self._index]
# Update Title
self.setWindowTitle(frame.title)
self.title_label.setText(f"<h1>{frame.title}</h1>")
# Update Body Widget
if self.body_layout.itemAt(0):
self.body_layout.itemAt(0).widget().setHidden(True)
self.body_layout.removeWidget(self.body_layout.itemAt(0).widget())
self.body_layout.addWidget(frame.body_wrapper)
self.body_layout.itemAt(0).widget().setHidden(False)
# Update Button Widgets
while self.button_nav_layout.count():
if _ := self.button_nav_layout.takeAt(0).widget():
_.setHidden(True)
for item in frame.buttons:
if isinstance(item, QWidget):
self.button_nav_layout.addWidget(item)
item.setHidden(False)
elif isinstance(item, int):
self.button_nav_layout.addStretch(item)

View File

@@ -0,0 +1,25 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from PySide6.QtWidgets import QPushButton
from src.qt.widgets.paged_panel.paged_body_wrapper import PagedBodyWrapper
class PagedPanelState:
"""A state object for paged panels."""
def __init__(
self,
title: str,
body_wrapper: PagedBodyWrapper,
buttons: list[QPushButton | int],
connect_to_back=list[QPushButton],
connect_to_next=list[QPushButton],
):
self.title: str = title
self.body_wrapper: PagedBodyWrapper = body_wrapper
self.buttons: list[QPushButton | int] = buttons
self.connect_to_back: list[QPushButton] = connect_to_back
self.connect_to_next: list[QPushButton] = connect_to_next

View File

@@ -896,10 +896,6 @@ class PreviewPanel(QWidget):
logger.error("Failed to disconnect inner_container.updated")
else:
logger.info(
"inner_container is not instance of TagBoxWidget",
container=inner_container,
)
inner_container = TagBoxWidget(
field,
title,

View File

@@ -89,12 +89,13 @@ class TagBoxWidget(FieldWidget):
self.field = field
def set_tags(self, tags: typing.Iterable[Tag]):
tags_ = sorted(list(tags), key=lambda tag: tag.name)
is_recycled = False
while self.base_layout.itemAt(0) and self.base_layout.itemAt(1):
self.base_layout.takeAt(0).widget().deleteLater()
is_recycled = True
for tag in tags:
for tag in tags_:
tag_widget = TagWidget(tag, has_edit=True, has_remove=True)
tag_widget.on_click.connect(
lambda tag_id=tag.id: (

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,49 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import pathlib
from time import time
from src.core.enums import LibraryPrefs
from src.qt.widgets.migration_modal import JsonMigrationModal
CWD = pathlib.Path(__file__)
def test_json_migration():
modal = JsonMigrationModal(CWD.parent / "fixtures" / "json_library")
modal.migrate(skip_ui=True)
start = time()
while not modal.done and (time() - start < 60):
pass
# Entries ==================================================================
# Count
assert len(modal.json_lib.entries) == modal.sql_lib.entries_count
# Path Parity
assert modal.check_path_parity()
# Field Parity
assert modal.check_field_parity()
# Tags =====================================================================
# Count
assert len(modal.json_lib.tags) == len(modal.sql_lib.tags)
# Shorthand Parity
assert modal.check_shorthand_parity()
# Subtag/Parent Tag Parity
assert modal.check_subtag_parity()
# Alias Parity
assert modal.check_alias_parity()
# Color Parity
assert modal.check_color_parity()
# Extension Filter List ====================================================
# Count
assert len(modal.json_lib.ext_list) == len(modal.sql_lib.prefs(LibraryPrefs.EXTENSION_LIST))
# List Type
assert modal.check_ext_type()
# No Leading Dot
for ext in modal.sql_lib.prefs(LibraryPrefs.EXTENSION_LIST):
assert ext[0] != "."

View File

@@ -85,7 +85,7 @@ def test_library_add_file(library):
def test_create_tag(library, generate_tag):
# tag already exists
assert not library.add_tag(generate_tag("foo"))
assert not library.add_tag(generate_tag("foo", id=1000))
# new tag name
tag = library.add_tag(generate_tag("xxx", id=123))
@@ -98,7 +98,7 @@ def test_create_tag(library, generate_tag):
def test_tag_subtag_itself(library, generate_tag):
# tag already exists
assert not library.add_tag(generate_tag("foo"))
assert not library.add_tag(generate_tag("foo", id=1000))
# new tag name
tag = library.add_tag(generate_tag("xxx", id=123))
@@ -132,19 +132,13 @@ def test_library_search(library, generate_tag, entry_full):
def test_tag_search(library):
tag = library.tags[0]
assert library.search_tags(
FilterState(tag=tag.name.lower()),
)
assert library.search_tags(tag.name.lower())
assert library.search_tags(
FilterState(tag=tag.name.upper()),
)
assert library.search_tags(tag.name.upper())
assert library.search_tags(FilterState(tag=tag.name[2:-2]))
assert library.search_tags(tag.name[2:-2])
assert not library.search_tags(
FilterState(tag=tag.name * 2),
)
assert not library.search_tags(tag.name * 2)
def test_get_entry(library, entry_min):