mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-30 06:40:50 +00:00
feat: store Entry suffix separately (#503)
* feat: save entry suffix separately * change LibraryPrefs to allow identical values, add test
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
from enum import Enum
|
||||
|
||||
VERSION: str = "9.3.2" # Major.Minor.Patch
|
||||
VERSION_BRANCH: str = "" # Usually "" or "Pre-Release"
|
||||
|
||||
@@ -7,7 +5,6 @@ VERSION_BRANCH: str = "" # Usually "" or "Pre-Release"
|
||||
TS_FOLDER_NAME: str = ".TagStudio"
|
||||
BACKUP_FOLDER_NAME: str = "backups"
|
||||
COLLAGE_FOLDER_NAME: str = "collages"
|
||||
LIBRARY_FILENAME: str = "ts_library.json"
|
||||
|
||||
# TODO: Turn this whitelist into a user-configurable blacklist.
|
||||
IMAGE_TYPES: list[str] = [
|
||||
@@ -122,13 +119,5 @@ ALL_FILE_TYPES: list[str] = (
|
||||
+ SHORTCUT_TYPES
|
||||
)
|
||||
|
||||
|
||||
TAG_FAVORITE = 1
|
||||
TAG_ARCHIVED = 0
|
||||
|
||||
|
||||
class LibraryPrefs(Enum):
|
||||
IS_EXCLUDE_LIST = True
|
||||
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
|
||||
PAGE_SIZE: int = 500
|
||||
DB_VERSION: int = 1
|
||||
|
||||
40
tagstudio/src/core/driver.py
Normal file
40
tagstudio/src/core/driver.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from PySide6.QtCore import QSettings
|
||||
from src.core.constants import TS_FOLDER_NAME
|
||||
from src.core.enums import SettingItems
|
||||
from src.core.library.alchemy.library import LibraryStatus
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class DriverMixin:
|
||||
settings: QSettings
|
||||
|
||||
def evaluate_path(self, open_path: str | None) -> LibraryStatus:
|
||||
"""Check if the path of library is valid."""
|
||||
library_path: Path | None = None
|
||||
if open_path:
|
||||
library_path = Path(open_path)
|
||||
if not library_path.exists():
|
||||
logger.error("Path does not exist.", open_path=open_path)
|
||||
return LibraryStatus(success=False, message="Path does not exist.")
|
||||
elif self.settings.value(
|
||||
SettingItems.START_LOAD_LAST, defaultValue=True, type=bool
|
||||
) and self.settings.value(SettingItems.LAST_LIBRARY):
|
||||
library_path = Path(str(self.settings.value(SettingItems.LAST_LIBRARY)))
|
||||
if not (library_path / TS_FOLDER_NAME).exists():
|
||||
logger.error(
|
||||
"TagStudio folder does not exist.",
|
||||
library_path=library_path,
|
||||
ts_folder=TS_FOLDER_NAME,
|
||||
)
|
||||
self.settings.setValue(SettingItems.LAST_LIBRARY, "")
|
||||
# dont consider this a fatal error, just skip opening the library
|
||||
library_path = None
|
||||
|
||||
return LibraryStatus(
|
||||
success=True,
|
||||
library_path=library_path,
|
||||
)
|
||||
@@ -1,4 +1,6 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class SettingItems(str, enum.Enum):
|
||||
@@ -31,3 +33,31 @@ class MacroID(enum.Enum):
|
||||
BUILD_URL = "build_url"
|
||||
MATCH = "match"
|
||||
CLEAN_URL = "clean_url"
|
||||
|
||||
|
||||
class DefaultEnum(enum.Enum):
|
||||
"""Allow saving multiple identical values in property called .default."""
|
||||
|
||||
default: Any
|
||||
|
||||
def __new__(cls, value):
|
||||
# Create the enum instance
|
||||
obj = object.__new__(cls)
|
||||
# make value random
|
||||
obj._value_ = uuid4()
|
||||
# assign the actual value into .default property
|
||||
obj.default = value
|
||||
return obj
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
raise AttributeError("access the value via .default property instead")
|
||||
|
||||
|
||||
class LibraryPrefs(DefaultEnum):
|
||||
"""Library preferences with default value accessible via .default property."""
|
||||
|
||||
IS_EXCLUDE_LIST = True
|
||||
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
|
||||
PAGE_SIZE: int = 500
|
||||
DB_VERSION: int = 2
|
||||
|
||||
@@ -18,27 +18,27 @@ class BaseField(Base):
|
||||
__abstract__ = True
|
||||
|
||||
@declared_attr
|
||||
def id(cls) -> Mapped[int]: # noqa: N805
|
||||
def id(self) -> Mapped[int]:
|
||||
return mapped_column(primary_key=True, autoincrement=True)
|
||||
|
||||
@declared_attr
|
||||
def type_key(cls) -> Mapped[str]: # noqa: N805
|
||||
def type_key(self) -> Mapped[str]:
|
||||
return mapped_column(ForeignKey("value_type.key"))
|
||||
|
||||
@declared_attr
|
||||
def type(cls) -> Mapped[ValueType]: # noqa: N805
|
||||
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore
|
||||
def type(self) -> Mapped[ValueType]:
|
||||
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore
|
||||
|
||||
@declared_attr
|
||||
def entry_id(cls) -> Mapped[int]: # noqa: N805
|
||||
def entry_id(self) -> Mapped[int]:
|
||||
return mapped_column(ForeignKey("entries.id"))
|
||||
|
||||
@declared_attr
|
||||
def entry(cls) -> Mapped[Entry]: # noqa: N805
|
||||
return relationship(foreign_keys=[cls.entry_id]) # type: ignore
|
||||
def entry(self) -> Mapped[Entry]:
|
||||
return relationship(foreign_keys=[self.entry_id]) # type: ignore
|
||||
|
||||
@declared_attr
|
||||
def position(cls) -> Mapped[int]: # noqa: N805
|
||||
def position(self) -> Mapped[int]:
|
||||
return mapped_column(default=0)
|
||||
|
||||
def __hash__(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import unicodedata
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
@@ -34,8 +35,8 @@ from ...constants import (
|
||||
TAG_ARCHIVED,
|
||||
TAG_FAVORITE,
|
||||
TS_FOLDER_NAME,
|
||||
LibraryPrefs,
|
||||
)
|
||||
from ...enums import LibraryPrefs
|
||||
from .db import make_tables
|
||||
from .enums import FieldTypeEnum, FilterState, TagColor
|
||||
from .fields import (
|
||||
@@ -48,8 +49,6 @@ from .fields import (
|
||||
from .joins import TagField, TagSubtag
|
||||
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType
|
||||
|
||||
LIBRARY_FILENAME: str = "ts_library.sqlite"
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -115,6 +114,15 @@ class SearchResult:
|
||||
return self.items[index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LibraryStatus:
|
||||
"""Keep status of library opening operation."""
|
||||
|
||||
success: bool
|
||||
library_path: Path | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Library:
|
||||
"""Class for the Library object, and all CRUD operations made upon it."""
|
||||
|
||||
@@ -123,6 +131,8 @@ class Library:
|
||||
engine: Engine | None
|
||||
folder: Folder | None
|
||||
|
||||
FILENAME: str = "ts_library.sqlite"
|
||||
|
||||
def close(self):
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
@@ -130,23 +140,19 @@ class Library:
|
||||
self.storage_path = None
|
||||
self.folder = None
|
||||
|
||||
def open_library(self, library_dir: Path | str, storage_path: str | None = None) -> None:
|
||||
if isinstance(library_dir, str):
|
||||
library_dir = Path(library_dir)
|
||||
|
||||
self.library_dir = library_dir
|
||||
def open_library(self, library_dir: Path, storage_path: str | None = None) -> LibraryStatus:
|
||||
if storage_path == ":memory:":
|
||||
self.storage_path = storage_path
|
||||
else:
|
||||
self.verify_ts_folders(self.library_dir)
|
||||
self.storage_path = self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME
|
||||
self.verify_ts_folders(library_dir)
|
||||
self.storage_path = library_dir / TS_FOLDER_NAME / self.FILENAME
|
||||
|
||||
connection_string = URL.create(
|
||||
drivername="sqlite",
|
||||
database=str(self.storage_path),
|
||||
)
|
||||
|
||||
logger.info("opening library", connection_string=connection_string)
|
||||
logger.info("opening library", library_dir=library_dir, connection_string=connection_string)
|
||||
self.engine = create_engine(connection_string)
|
||||
with Session(self.engine) as session:
|
||||
make_tables(self.engine)
|
||||
@@ -159,9 +165,24 @@ class Library:
|
||||
# default tags may exist already
|
||||
session.rollback()
|
||||
|
||||
if "pytest" not in sys.modules:
|
||||
db_version = session.scalar(
|
||||
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
|
||||
)
|
||||
|
||||
if not db_version:
|
||||
# TODO - remove after #503 is merged and LibraryPrefs.DB_VERSION increased again
|
||||
return LibraryStatus(
|
||||
success=False,
|
||||
message=(
|
||||
"Library version mismatch.\n"
|
||||
f"Found: v0, expected: v{LibraryPrefs.DB_VERSION.default}"
|
||||
),
|
||||
)
|
||||
|
||||
for pref in LibraryPrefs:
|
||||
try:
|
||||
session.add(Preferences(key=pref.name, value=pref.value))
|
||||
session.add(Preferences(key=pref.name, value=pref.default))
|
||||
session.commit()
|
||||
except IntegrityError:
|
||||
logger.debug("preference already exists", pref=pref)
|
||||
@@ -183,11 +204,30 @@ class Library:
|
||||
logger.debug("ValueType already exists", field=field)
|
||||
session.rollback()
|
||||
|
||||
db_version = session.scalar(
|
||||
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
|
||||
)
|
||||
# if the db version is different, we cant proceed
|
||||
if db_version.value != LibraryPrefs.DB_VERSION.default:
|
||||
logger.error(
|
||||
"DB version mismatch",
|
||||
db_version=db_version.value,
|
||||
expected=LibraryPrefs.DB_VERSION.default,
|
||||
)
|
||||
# TODO - handle migration
|
||||
return LibraryStatus(
|
||||
success=False,
|
||||
message=(
|
||||
"Library version mismatch.\n"
|
||||
f"Found: v{db_version.value}, expected: v{LibraryPrefs.DB_VERSION.default}"
|
||||
),
|
||||
)
|
||||
|
||||
# check if folder matching current path exists already
|
||||
self.folder = session.scalar(select(Folder).where(Folder.path == self.library_dir))
|
||||
self.folder = session.scalar(select(Folder).where(Folder.path == library_dir))
|
||||
if not self.folder:
|
||||
folder = Folder(
|
||||
path=self.library_dir,
|
||||
path=library_dir,
|
||||
uuid=str(uuid4()),
|
||||
)
|
||||
session.add(folder)
|
||||
@@ -196,6 +236,10 @@ class Library:
|
||||
session.commit()
|
||||
self.folder = folder
|
||||
|
||||
# everything is fine, set the library path
|
||||
self.library_dir = library_dir
|
||||
return LibraryStatus(success=True, library_path=library_dir)
|
||||
|
||||
@property
|
||||
def default_fields(self) -> list[BaseField]:
|
||||
with Session(self.engine) as session:
|
||||
@@ -324,15 +368,18 @@ class Library:
|
||||
|
||||
with Session(self.engine) as session:
|
||||
# add all items
|
||||
session.add_all(items)
|
||||
session.flush()
|
||||
|
||||
try:
|
||||
session.add_all(items)
|
||||
session.commit()
|
||||
except IntegrityError:
|
||||
session.rollback()
|
||||
logger.exception("IntegrityError")
|
||||
return []
|
||||
|
||||
new_ids = [item.id for item in items]
|
||||
|
||||
session.expunge_all()
|
||||
|
||||
session.commit()
|
||||
|
||||
return new_ids
|
||||
|
||||
def remove_entries(self, entry_ids: list[int]) -> None:
|
||||
@@ -396,9 +443,9 @@ class Library:
|
||||
|
||||
if not search.id: # if `id` is set, we don't need to filter by extensions
|
||||
if extensions and is_exclude_list:
|
||||
statement = statement.where(Entry.path.notilike(f"%.{','.join(extensions)}"))
|
||||
statement = statement.where(Entry.suffix.notin_(extensions))
|
||||
elif extensions:
|
||||
statement = statement.where(Entry.path.ilike(f"%.{','.join(extensions)}"))
|
||||
statement = statement.where(Entry.suffix.in_(extensions))
|
||||
|
||||
statement = statement.options(
|
||||
selectinload(Entry.text_fields),
|
||||
@@ -770,7 +817,7 @@ class Library:
|
||||
target_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME / filename
|
||||
|
||||
shutil.copy2(
|
||||
self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME,
|
||||
self.library_dir / TS_FOLDER_NAME / self.FILENAME,
|
||||
target_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class Entry(Base):
|
||||
folder: Mapped[Folder] = relationship("Folder")
|
||||
|
||||
path: Mapped[Path] = mapped_column(PathType, unique=True)
|
||||
suffix: Mapped[str] = mapped_column()
|
||||
|
||||
text_fields: Mapped[list[TextField]] = relationship(
|
||||
back_populates="entry",
|
||||
@@ -177,6 +178,8 @@ class Entry(Base):
|
||||
self.path = path
|
||||
self.folder = folder
|
||||
|
||||
self.suffix = path.suffix.lstrip(".").lower()
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, TextField):
|
||||
self.text_fields.append(field)
|
||||
|
||||
@@ -299,6 +299,8 @@ class Collation:
|
||||
class Library:
|
||||
"""Class for the Library object, and all CRUD operations made upon it."""
|
||||
|
||||
FILENAME: str = "ts_library.json"
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Library Info =========================================================
|
||||
self.library_dir: Path = None
|
||||
|
||||
@@ -16,7 +16,7 @@ from PySide6.QtWidgets import (
|
||||
QVBoxLayout,
|
||||
QWidget,
|
||||
)
|
||||
from src.core.constants import LibraryPrefs
|
||||
from src.core.enums import LibraryPrefs
|
||||
from src.core.library import Library
|
||||
from src.qt.widgets.panel import PanelWidget
|
||||
|
||||
@@ -104,7 +104,7 @@ class FileExtensionModal(PanelWidget):
|
||||
for i in range(self.table.rowCount()):
|
||||
ext = self.table.item(i, 0)
|
||||
if ext and ext.text().strip():
|
||||
extensions.append(ext.text().strip().lower())
|
||||
extensions.append(ext.text().strip().lstrip(".").lower())
|
||||
|
||||
# save preference
|
||||
self.lib.set_prefs(LibraryPrefs.EXTENSION_LIST, extensions)
|
||||
|
||||
@@ -50,6 +50,7 @@ from PySide6.QtWidgets import (
|
||||
QLineEdit,
|
||||
QMenu,
|
||||
QMenuBar,
|
||||
QMessageBox,
|
||||
QPushButton,
|
||||
QScrollArea,
|
||||
QSplashScreen,
|
||||
@@ -58,12 +59,11 @@ from PySide6.QtWidgets import (
|
||||
from src.core.constants import (
|
||||
TAG_ARCHIVED,
|
||||
TAG_FAVORITE,
|
||||
TS_FOLDER_NAME,
|
||||
VERSION,
|
||||
VERSION_BRANCH,
|
||||
LibraryPrefs,
|
||||
)
|
||||
from src.core.enums import MacroID, SettingItems
|
||||
from src.core.driver import DriverMixin
|
||||
from src.core.enums import LibraryPrefs, MacroID, SettingItems
|
||||
from src.core.library.alchemy.enums import (
|
||||
FieldTypeEnum,
|
||||
FilterState,
|
||||
@@ -71,6 +71,7 @@ from src.core.library.alchemy.enums import (
|
||||
SearchMode,
|
||||
)
|
||||
from src.core.library.alchemy.fields import _FieldID
|
||||
from src.core.library.alchemy.library import LibraryStatus
|
||||
from src.core.ts_core import TagStudioCore
|
||||
from src.core.utils.refresh_dir import RefreshDirTracker
|
||||
from src.core.utils.web import strip_web_protocol
|
||||
@@ -120,7 +121,7 @@ class Consumer(QThread):
|
||||
pass
|
||||
|
||||
|
||||
class QtDriver(QObject):
|
||||
class QtDriver(DriverMixin, QObject):
|
||||
"""A Qt GUI frontend driver for TagStudio."""
|
||||
|
||||
SIGTERM = Signal()
|
||||
@@ -173,16 +174,15 @@ class QtDriver(QObject):
|
||||
filename=self.settings.fileName(),
|
||||
)
|
||||
|
||||
max_threads = os.cpu_count()
|
||||
for i in range(max_threads):
|
||||
# thread = threading.Thread(
|
||||
# target=self.consumer, name=f"ThumbRenderer_{i}", args=(), daemon=True
|
||||
# )
|
||||
# thread.start()
|
||||
thread = Consumer(self.thumb_job_queue)
|
||||
thread.setObjectName(f"ThumbRenderer_{i}")
|
||||
self.thumb_threads.append(thread)
|
||||
thread.start()
|
||||
def init_workers(self):
|
||||
"""Init workers for rendering thumbnails."""
|
||||
if not self.thumb_threads:
|
||||
max_threads = os.cpu_count()
|
||||
for i in range(max_threads):
|
||||
thread = Consumer(self.thumb_job_queue)
|
||||
thread.setObjectName(f"ThumbRenderer_{i}")
|
||||
self.thumb_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def open_library_from_dialog(self):
|
||||
dir = QFileDialog.getExistingDirectory(
|
||||
@@ -457,33 +457,35 @@ class QtDriver(QObject):
|
||||
self.item_thumbs: list[ItemThumb] = []
|
||||
self.thumb_renderers: list[ThumbRenderer] = []
|
||||
self.filter = FilterState()
|
||||
|
||||
self.init_library_window()
|
||||
|
||||
lib: str | None = None
|
||||
if self.args.open:
|
||||
lib = self.args.open
|
||||
elif self.settings.value(SettingItems.START_LOAD_LAST, defaultValue=True, type=bool):
|
||||
lib = str(self.settings.value(SettingItems.LAST_LIBRARY))
|
||||
|
||||
# TODO: Remove this check if the library is no longer saved with files
|
||||
if lib and not (Path(lib) / TS_FOLDER_NAME).exists():
|
||||
logger.error(f"[QT DRIVER] {TS_FOLDER_NAME} folder in {lib} does not exist.")
|
||||
self.settings.setValue(SettingItems.LAST_LIBRARY, "")
|
||||
lib = None
|
||||
|
||||
if lib:
|
||||
path_result = self.evaluate_path(self.args.open)
|
||||
# check status of library path evaluating
|
||||
if path_result.success and path_result.library_path:
|
||||
self.splash.showMessage(
|
||||
f'Opening Library "{lib}"...',
|
||||
f'Opening Library "{path_result.library_path}"...',
|
||||
int(Qt.AlignmentFlag.AlignBottom | Qt.AlignmentFlag.AlignHCenter),
|
||||
QColor("#9782ff"),
|
||||
)
|
||||
self.open_library(lib)
|
||||
self.open_library(path_result.library_path)
|
||||
|
||||
app.exec()
|
||||
|
||||
self.shutdown()
|
||||
|
||||
def show_error_message(self, message: str):
|
||||
self.main_window.statusbar.showMessage(message, Qt.AlignmentFlag.AlignLeft)
|
||||
self.main_window.landing_widget.set_status_label(message)
|
||||
self.main_window.setWindowTitle(message)
|
||||
|
||||
msg_box = QMessageBox()
|
||||
msg_box.setIcon(QMessageBox.Icon.Critical)
|
||||
msg_box.setText(message)
|
||||
msg_box.setWindowTitle("Error")
|
||||
msg_box.addButton("Close", QMessageBox.ButtonRole.AcceptRole)
|
||||
|
||||
# Show the message box
|
||||
msg_box.exec()
|
||||
|
||||
def init_library_window(self):
|
||||
# self._init_landing_page() # Taken care of inside the widget now
|
||||
self._init_thumb_grid()
|
||||
@@ -562,7 +564,7 @@ class QtDriver(QObject):
|
||||
self.main_window.statusbar.showMessage("Closing Library...")
|
||||
start_time = time.time()
|
||||
|
||||
self.settings.setValue(SettingItems.LAST_LIBRARY, self.lib.library_dir)
|
||||
self.settings.setValue(SettingItems.LAST_LIBRARY, str(self.lib.library_dir))
|
||||
self.settings.sync()
|
||||
|
||||
self.lib.close()
|
||||
@@ -1061,14 +1063,19 @@ class QtDriver(QObject):
|
||||
self.settings.endGroup()
|
||||
self.settings.sync()
|
||||
|
||||
def open_library(self, path: Path | str):
|
||||
"""Opens a TagStudio library."""
|
||||
def open_library(self, path: Path) -> LibraryStatus:
|
||||
"""Open a TagStudio library."""
|
||||
open_message: str = f'Opening Library "{str(path)}"...'
|
||||
self.main_window.landing_widget.set_status_label(open_message)
|
||||
self.main_window.statusbar.showMessage(open_message, 3)
|
||||
self.main_window.repaint()
|
||||
|
||||
self.lib.open_library(path)
|
||||
open_status = self.lib.open_library(path)
|
||||
if not open_status.success:
|
||||
self.show_error_message(open_status.message or "Error opening library.")
|
||||
return open_status
|
||||
|
||||
self.init_workers()
|
||||
|
||||
self.filter.page_size = self.lib.prefs(LibraryPrefs.PAGE_SIZE)
|
||||
|
||||
@@ -1086,3 +1093,4 @@ class QtDriver(QObject):
|
||||
self.filter_items()
|
||||
|
||||
self.main_window.toggle_landing_page(enabled=False)
|
||||
return open_status
|
||||
|
||||
@@ -41,6 +41,7 @@ class VideoPlayer(QGraphicsView):
|
||||
video_preview = None
|
||||
play_pause = None
|
||||
mute_button = None
|
||||
filepath: str | None
|
||||
|
||||
def __init__(self, driver: "QtDriver") -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -32,8 +32,8 @@ def library(request):
|
||||
library_path = request.param
|
||||
|
||||
lib = Library()
|
||||
lib.open_library(library_path, ":memory:")
|
||||
assert lib.folder
|
||||
status = lib.open_library(pathlib.Path(library_path), ":memory:")
|
||||
assert status.success
|
||||
|
||||
tag = Tag(
|
||||
name="foo",
|
||||
|
||||
@@ -2,7 +2,7 @@ import pathlib
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
from src.core.constants import LibraryPrefs
|
||||
from src.core.enums import LibraryPrefs
|
||||
from src.core.utils.refresh_dir import RefreshDirTracker
|
||||
|
||||
CWD = pathlib.Path(__file__).parent
|
||||
|
||||
66
tagstudio/tests/test_driver.py
Normal file
66
tagstudio/tests/test_driver.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from os import makedirs
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from PySide6.QtCore import QSettings
|
||||
from src.core.constants import TS_FOLDER_NAME
|
||||
from src.core.driver import DriverMixin
|
||||
from src.core.enums import SettingItems
|
||||
from src.core.library.alchemy.library import LibraryStatus
|
||||
|
||||
|
||||
class TestDriver(DriverMixin):
|
||||
def __init__(self, settings):
|
||||
self.settings = settings
|
||||
|
||||
|
||||
def test_evaluate_path_empty():
|
||||
# Given
|
||||
settings = QSettings()
|
||||
driver = TestDriver(settings)
|
||||
|
||||
# When
|
||||
result = driver.evaluate_path(None)
|
||||
|
||||
# Then
|
||||
assert result == LibraryStatus(success=True)
|
||||
|
||||
|
||||
def test_evaluate_path_missing():
|
||||
# Given
|
||||
settings = QSettings()
|
||||
driver = TestDriver(settings)
|
||||
|
||||
# When
|
||||
result = driver.evaluate_path("/0/4/5/1/")
|
||||
|
||||
# Then
|
||||
assert result == LibraryStatus(success=False, message="Path does not exist.")
|
||||
|
||||
|
||||
def test_evaluate_path_last_lib_not_exists():
|
||||
# Given
|
||||
settings = QSettings()
|
||||
settings.setValue(SettingItems.LAST_LIBRARY, "/0/4/5/1/")
|
||||
driver = TestDriver(settings)
|
||||
|
||||
# When
|
||||
result = driver.evaluate_path(None)
|
||||
|
||||
# Then
|
||||
assert result == LibraryStatus(success=True, library_path=None, message=None)
|
||||
|
||||
|
||||
def test_evaluate_path_last_lib_present():
|
||||
# Given
|
||||
settings = QSettings()
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
settings.setValue(SettingItems.LAST_LIBRARY, tmpdir)
|
||||
makedirs(Path(tmpdir) / TS_FOLDER_NAME)
|
||||
driver = TestDriver(settings)
|
||||
|
||||
# When
|
||||
result = driver.evaluate_path(None)
|
||||
|
||||
# Then
|
||||
assert result == LibraryStatus(success=True, library_path=Path(tmpdir))
|
||||
@@ -2,40 +2,27 @@ from pathlib import Path, PureWindowsPath
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
from src.core.constants import LibraryPrefs
|
||||
from src.core.library.alchemy import Entry, Library
|
||||
from src.core.enums import DefaultEnum, LibraryPrefs
|
||||
from src.core.library.alchemy import Entry
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
from src.core.library.alchemy.fields import TextField, _FieldID
|
||||
|
||||
|
||||
def test_library_bootstrap():
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
lib = Library()
|
||||
lib.open_library(tmp_dir)
|
||||
assert lib.engine
|
||||
|
||||
|
||||
def test_library_add_file():
|
||||
@pytest.mark.parametrize("library", [TemporaryDirectory()], indirect=True)
|
||||
def test_library_add_file(library):
|
||||
"""Check Entry.path handling for insert vs lookup"""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
# create file in tmp_dir
|
||||
file_path = Path(tmp_dir) / "bar.txt"
|
||||
file_path.write_text("bar")
|
||||
|
||||
lib = Library()
|
||||
lib.open_library(tmp_dir)
|
||||
entry = Entry(
|
||||
path=Path("bar.txt"),
|
||||
folder=library.folder,
|
||||
fields=library.default_fields,
|
||||
)
|
||||
|
||||
entry = Entry(
|
||||
path=file_path,
|
||||
folder=lib.folder,
|
||||
fields=lib.default_fields,
|
||||
)
|
||||
assert not library.has_path_entry(entry.path)
|
||||
|
||||
assert not lib.has_path_entry(entry.path)
|
||||
assert library.add_entries([entry])
|
||||
|
||||
assert lib.add_entries([entry])
|
||||
|
||||
assert lib.has_path_entry(entry.path) is True
|
||||
assert library.has_path_entry(entry.path)
|
||||
|
||||
|
||||
def test_create_tag(library, generate_tag):
|
||||
@@ -99,7 +86,9 @@ def test_get_entry(library, entry_min):
|
||||
|
||||
def test_entries_count(library):
|
||||
entries = [Entry(path=Path(f"{x}.txt"), folder=library.folder, fields=[]) for x in range(10)]
|
||||
library.add_entries(entries)
|
||||
new_ids = library.add_entries(entries)
|
||||
assert len(new_ids) == 10
|
||||
|
||||
results = library.search_library(
|
||||
FilterState(
|
||||
page_size=5,
|
||||
@@ -120,7 +109,7 @@ def test_add_field_to_entry(library):
|
||||
# meta tags + content tags
|
||||
assert len(entry.tag_box_fields) == 2
|
||||
|
||||
library.add_entries([entry])
|
||||
assert library.add_entries([entry])
|
||||
|
||||
# When
|
||||
library.add_entry_field_type(entry.id, field_id=_FieldID.TAGS)
|
||||
@@ -208,7 +197,7 @@ def test_search_library_case_insensitive(library):
|
||||
|
||||
def test_preferences(library):
|
||||
for pref in LibraryPrefs:
|
||||
assert library.prefs(pref) == pref.value
|
||||
assert library.prefs(pref) == pref.default
|
||||
|
||||
|
||||
def test_save_windows_path(library, generate_tag):
|
||||
@@ -394,3 +383,21 @@ def test_update_field_order(library, entry_full):
|
||||
assert entry.text_fields[0].value == "first"
|
||||
assert entry.text_fields[1].position == 1
|
||||
assert entry.text_fields[1].value == "second"
|
||||
|
||||
|
||||
def test_library_prefs_multiple_identical_vals():
|
||||
# check the preferences are inherited from DefaultEnum
|
||||
assert issubclass(LibraryPrefs, DefaultEnum)
|
||||
|
||||
# create custom settings with identical values
|
||||
class TestPrefs(DefaultEnum):
|
||||
FOO = 1
|
||||
BAR = 1
|
||||
|
||||
assert TestPrefs.FOO.default == 1
|
||||
assert TestPrefs.BAR.default == 1
|
||||
assert TestPrefs.BAR.name == "BAR"
|
||||
|
||||
# accessing .value should raise exception
|
||||
with pytest.raises(AttributeError):
|
||||
assert TestPrefs.BAR.value
|
||||
|
||||
Reference in New Issue
Block a user