mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-28 22:01:24 +00:00
refactor: fix most pyright issues in library/alchemy/ (#1103)
* refactor: fix most pyright issues in library/alchemy/ * chore: implement review feedback
This commit is contained in:
committed by
GitHub
parent
f49cb4fade
commit
cee64a8c31
@@ -85,7 +85,11 @@ ignore_errors = true
|
||||
qt_api = "pyside6"
|
||||
|
||||
[tool.pyright]
|
||||
ignore = ["src/tagstudio/qt/previews/vendored/pydub/", ".venv/**"]
|
||||
ignore = [
|
||||
".venv/**",
|
||||
"src/tagstudio/core/library/json/",
|
||||
"src/tagstudio/qt/previews/vendored/pydub/",
|
||||
]
|
||||
include = ["src/tagstudio", "tests"]
|
||||
reportAny = false
|
||||
reportIgnoreCommentWithoutRule = false
|
||||
|
||||
@@ -23,3 +23,14 @@ WITH RECURSIVE ChildTags AS (
|
||||
)
|
||||
SELECT * FROM ChildTags;
|
||||
""")
|
||||
|
||||
TAG_CHILDREN_ID_QUERY = text("""
|
||||
WITH RECURSIVE ChildTags AS (
|
||||
SELECT :tag_id AS tag_id
|
||||
UNION
|
||||
SELECT tp.child_id AS tag_id
|
||||
FROM tag_parents tp
|
||||
INNER JOIN ChildTags c ON tp.parent_id = c.tag_id
|
||||
)
|
||||
SELECT tag_id FROM ChildTags;
|
||||
""")
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import Dialect, Engine, String, TypeDecorator, create_engine, text
|
||||
@@ -19,12 +20,14 @@ class PathType(TypeDecorator):
|
||||
impl = String
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: Path, dialect: Dialect):
|
||||
@override
|
||||
def process_bind_param(self, value: Path | None, dialect: Dialect):
|
||||
if value is not None:
|
||||
return Path(value).as_posix()
|
||||
return None
|
||||
|
||||
def process_result_value(self, value: str, dialect: Dialect):
|
||||
@override
|
||||
def process_result_value(self, value: str | None, dialect: Dialect):
|
||||
if value is not None:
|
||||
return Path(value)
|
||||
return None
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
||||
@@ -32,7 +32,7 @@ class BaseField(Base):
|
||||
|
||||
@declared_attr
|
||||
def type(self) -> Mapped[ValueType]:
|
||||
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore
|
||||
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore # pyright: ignore[reportArgumentType]
|
||||
|
||||
@declared_attr
|
||||
def entry_id(self) -> Mapped[int]:
|
||||
@@ -40,19 +40,20 @@ class BaseField(Base):
|
||||
|
||||
@declared_attr
|
||||
def entry(self) -> Mapped[Entry]:
|
||||
return relationship(foreign_keys=[self.entry_id]) # type: ignore
|
||||
return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType]
|
||||
|
||||
@declared_attr
|
||||
def position(self) -> Mapped[int]:
|
||||
return mapped_column(default=0)
|
||||
|
||||
@override
|
||||
def __hash__(self):
|
||||
return hash(self.__key())
|
||||
|
||||
def __key(self):
|
||||
def __key(self): # pyright: ignore[reportUnknownParameterType]
|
||||
raise NotImplementedError
|
||||
|
||||
value: Any
|
||||
value: Any # pyright: ignore
|
||||
|
||||
|
||||
class BooleanField(BaseField):
|
||||
@@ -63,7 +64,8 @@ class BooleanField(BaseField):
|
||||
def __key(self):
|
||||
return (self.type, self.value)
|
||||
|
||||
def __eq__(self, value) -> bool:
|
||||
@override
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if isinstance(value, BooleanField):
|
||||
return self.__key() == value.__key()
|
||||
raise NotImplementedError
|
||||
@@ -74,10 +76,11 @@ class TextField(BaseField):
|
||||
|
||||
value: Mapped[str | None]
|
||||
|
||||
def __key(self) -> tuple:
|
||||
def __key(self) -> tuple[ValueType, str | None]:
|
||||
return self.type, self.value
|
||||
|
||||
def __eq__(self, value) -> bool:
|
||||
@override
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if isinstance(value, TextField):
|
||||
return self.__key() == value.__key()
|
||||
elif isinstance(value, DatetimeField):
|
||||
@@ -93,7 +96,8 @@ class DatetimeField(BaseField):
|
||||
def __key(self):
|
||||
return (self.type, self.value)
|
||||
|
||||
def __eq__(self, value) -> bool:
|
||||
@override
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if isinstance(value, DatetimeField):
|
||||
return self.__key() == value.__key()
|
||||
raise NotImplementedError
|
||||
@@ -107,7 +111,7 @@ class DefaultField:
|
||||
is_default: bool = field(default=False)
|
||||
|
||||
|
||||
class _FieldID(Enum):
|
||||
class FieldID(Enum):
|
||||
"""Only for bootstrapping content of DB table."""
|
||||
|
||||
TITLE = DefaultField(id=0, name="Title", type=FieldTypeEnum.TEXT_LINE, is_default=True)
|
||||
|
||||
@@ -2,12 +2,16 @@
|
||||
# Licensed under the GPL-3.0 License.
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
# NOTE: This file contains necessary use of deprecated first-party code until that
|
||||
# code is removed in a future version (prefs).
|
||||
# pyright: reportDeprecated=false
|
||||
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import unicodedata
|
||||
from collections.abc import Iterable, Iterator
|
||||
from collections.abc import Iterable, Iterator, MutableSequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from os import makedirs
|
||||
@@ -18,7 +22,7 @@ from warnings import catch_warnings
|
||||
|
||||
import sqlalchemy
|
||||
import structlog
|
||||
from humanfriendly import format_timespan
|
||||
from humanfriendly import format_timespan # pyright: ignore[reportUnknownVariableType]
|
||||
from sqlalchemy import (
|
||||
URL,
|
||||
ColumnExpressionArgument,
|
||||
@@ -40,6 +44,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import (
|
||||
InstanceState,
|
||||
Session,
|
||||
contains_eager,
|
||||
joinedload,
|
||||
@@ -47,6 +52,7 @@ from sqlalchemy.orm import (
|
||||
noload,
|
||||
selectinload,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from tagstudio.core.constants import (
|
||||
BACKUP_FOLDER_NAME,
|
||||
@@ -81,8 +87,8 @@ from tagstudio.core.library.alchemy.enums import (
|
||||
from tagstudio.core.library.alchemy.fields import (
|
||||
BaseField,
|
||||
DatetimeField,
|
||||
FieldID,
|
||||
TextField,
|
||||
_FieldID,
|
||||
)
|
||||
from tagstudio.core.library.alchemy.joins import TagEntry, TagParent
|
||||
from tagstudio.core.library.alchemy.models import (
|
||||
@@ -210,9 +216,9 @@ class Library:
|
||||
"""Class for the Library object, and all CRUD operations made upon it."""
|
||||
|
||||
library_dir: Path | None = None
|
||||
storage_path: Path | str | None
|
||||
storage_path: Path | str | None = None
|
||||
engine: Engine | None = None
|
||||
folder: Folder | None
|
||||
folder: Folder | None = None
|
||||
included_files: set[Path] = set()
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -224,7 +230,7 @@ class Library:
|
||||
def close(self):
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
self.library_dir: Path | None = None
|
||||
self.library_dir = None
|
||||
self.storage_path = None
|
||||
self.folder = None
|
||||
self.included_files = set()
|
||||
@@ -300,8 +306,8 @@ class Library:
|
||||
]
|
||||
)
|
||||
for entry in json_lib.entries:
|
||||
for field in entry.fields:
|
||||
for k, v in field.items():
|
||||
for field in entry.fields: # pyright: ignore[reportUnknownVariableType]
|
||||
for k, v in field.items(): # pyright: ignore[reportUnknownVariableType]
|
||||
# Old tag fields get added as tags
|
||||
if k in LEGACY_TAG_FIELD_IDS:
|
||||
self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v)
|
||||
@@ -319,8 +325,8 @@ class Library:
|
||||
end_time = time.time()
|
||||
logger.info(f"Library Converted! ({format_timespan(end_time - start_time)})")
|
||||
|
||||
def get_field_name_from_id(self, field_id: int) -> _FieldID:
|
||||
for f in _FieldID:
|
||||
def get_field_name_from_id(self, field_id: int) -> FieldID | None:
|
||||
for f in FieldID:
|
||||
if field_id == f.value.id:
|
||||
return f
|
||||
return None
|
||||
@@ -482,7 +488,7 @@ class Library:
|
||||
except IntegrityError:
|
||||
session.rollback()
|
||||
|
||||
for field in _FieldID:
|
||||
for field in FieldID:
|
||||
try:
|
||||
session.add(
|
||||
ValueType(
|
||||
@@ -562,7 +568,7 @@ class Library:
|
||||
# Repair "Description" fields with a TEXT_LINE key instead of a TEXT_BOX key.
|
||||
desc_stmd = (
|
||||
update(ValueType)
|
||||
.where(ValueType.key == _FieldID.DESCRIPTION.name)
|
||||
.where(ValueType.key == FieldID.DESCRIPTION.name)
|
||||
.values(type=FieldTypeEnum.TEXT_BOX.name)
|
||||
)
|
||||
session.execute(desc_stmd)
|
||||
@@ -697,8 +703,8 @@ class Library:
|
||||
logger.error("[ERROR][Library] Could not generate '.ts_ignore' file!", error=e)
|
||||
|
||||
# Load legacy extension data
|
||||
extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore[reportAssignmentType]
|
||||
is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore[reportAssignmentType]
|
||||
extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore
|
||||
is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore
|
||||
|
||||
# Copy extensions to '.ts_ignore' file
|
||||
if ts_ignore.exists():
|
||||
@@ -720,12 +726,6 @@ class Library:
|
||||
)
|
||||
return [x.as_field for x in types]
|
||||
|
||||
def delete_item(self, item):
|
||||
logger.info("deleting item", item=item)
|
||||
with Session(self.engine) as session:
|
||||
session.delete(item)
|
||||
session.commit()
|
||||
|
||||
def get_entry(self, entry_id: int) -> Entry | None:
|
||||
"""Load entry without joins."""
|
||||
with Session(self.engine) as session:
|
||||
@@ -794,7 +794,7 @@ class Library:
|
||||
entries = dict((e.id, e) for e in session.scalars(statement))
|
||||
return [entries[id] for id in entry_ids]
|
||||
|
||||
def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
|
||||
def get_entries_full(self, entry_ids: MutableSequence[int]) -> Iterator[Entry]:
|
||||
"""Load entry and join with all joins and all tags."""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(Entry).where(Entry.id.in_(set(entry_ids)))
|
||||
@@ -864,7 +864,7 @@ class Library:
|
||||
@property
|
||||
def entries_count(self) -> int:
|
||||
with Session(self.engine) as session:
|
||||
return session.scalar(select(func.count(Entry.id)))
|
||||
return unwrap(session.scalar(select(func.count(Entry.id))))
|
||||
|
||||
def all_entries(self, with_joins: bool = False) -> Iterator[Entry]:
|
||||
"""Load entries without joins."""
|
||||
@@ -906,7 +906,7 @@ class Library:
|
||||
|
||||
return list(tags_list)
|
||||
|
||||
def verify_ts_folder(self, library_dir: Path) -> bool:
|
||||
def verify_ts_folder(self, library_dir: Path | None) -> bool:
|
||||
"""Verify/create folders required by TagStudio.
|
||||
|
||||
Returns:
|
||||
@@ -960,7 +960,7 @@ class Library:
|
||||
with Session(self.engine) as session:
|
||||
return session.query(exists().where(Entry.path == path)).scalar()
|
||||
|
||||
def get_paths(self, glob: str | None = None, limit: int = -1) -> list[str]:
|
||||
def get_paths(self, limit: int = -1) -> list[str]:
|
||||
path_strings: list[str] = []
|
||||
with Session(self.engine) as session:
|
||||
if limit > 0:
|
||||
@@ -1020,7 +1020,7 @@ class Library:
|
||||
ids = []
|
||||
count = 0
|
||||
for row in rows:
|
||||
id, count = row._tuple()
|
||||
id, count = row._tuple() # pyright: ignore[reportPrivateUsage]
|
||||
ids.append(id)
|
||||
end_time = time.time()
|
||||
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")
|
||||
@@ -1109,17 +1109,13 @@ class Library:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def remove_tag(self, tag: Tag):
|
||||
def remove_tag(self, tag_id: int):
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
child_tags = session.scalars(
|
||||
select(TagParent).where(TagParent.child_id == tag.id)
|
||||
select(TagParent).where(TagParent.child_id == tag_id)
|
||||
).all()
|
||||
tags_query = select(Tag).options(
|
||||
selectinload(Tag.parent_tags), selectinload(Tag.aliases)
|
||||
)
|
||||
tag = session.scalar(tags_query.where(Tag.id == tag.id))
|
||||
aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id))
|
||||
aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag_id))
|
||||
|
||||
for alias in aliases or []:
|
||||
session.delete(alias)
|
||||
@@ -1130,24 +1126,18 @@ class Library:
|
||||
|
||||
disam_stmt = (
|
||||
update(Tag)
|
||||
.where(Tag.disambiguation_id == tag.id)
|
||||
.where(Tag.disambiguation_id == tag_id)
|
||||
.values(disambiguation_id=None)
|
||||
)
|
||||
session.execute(disam_stmt)
|
||||
session.flush()
|
||||
|
||||
session.delete(tag)
|
||||
session.query(Tag).filter_by(id=tag_id).delete()
|
||||
session.commit()
|
||||
session.expunge(tag)
|
||||
|
||||
return tag
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(e)
|
||||
session.rollback()
|
||||
|
||||
return None
|
||||
|
||||
def update_field_position(
|
||||
self,
|
||||
field_class: type[BaseField],
|
||||
@@ -1247,7 +1237,7 @@ class Library:
|
||||
|
||||
def get_value_type(self, field_key: str) -> ValueType:
|
||||
with Session(self.engine) as session:
|
||||
field = session.scalar(select(ValueType).where(ValueType.key == field_key))
|
||||
field = unwrap(session.scalar(select(ValueType).where(ValueType.key == field_key)))
|
||||
session.expunge(field)
|
||||
return field
|
||||
|
||||
@@ -1256,7 +1246,7 @@ class Library:
|
||||
entry_id: int,
|
||||
*,
|
||||
field: ValueType | None = None,
|
||||
field_id: _FieldID | str | None = None,
|
||||
field_id: FieldID | str | None = None,
|
||||
value: str | datetime | None = None,
|
||||
) -> bool:
|
||||
logger.info(
|
||||
@@ -1270,9 +1260,9 @@ class Library:
|
||||
assert bool(field) != (field_id is not None)
|
||||
|
||||
if not field:
|
||||
if isinstance(field_id, _FieldID):
|
||||
if isinstance(field_id, FieldID):
|
||||
field_id = field_id.name
|
||||
field = self.get_value_type(field_id)
|
||||
field = self.get_value_type(unwrap(field_id))
|
||||
|
||||
field_model: TextField | DatetimeField
|
||||
if field.type in (FieldTypeEnum.TEXT_LINE, FieldTypeEnum.TEXT_BOX):
|
||||
@@ -1407,9 +1397,9 @@ class Library:
|
||||
def add_tag(
|
||||
self,
|
||||
tag: Tag,
|
||||
parent_ids: list[int] | set[int] | None = None,
|
||||
alias_names: list[str] | set[str] | None = None,
|
||||
alias_ids: list[int] | set[int] | None = None,
|
||||
parent_ids: MutableSequence[int] | None = None,
|
||||
alias_names: MutableSequence[str] | None = None,
|
||||
alias_ids: MutableSequence[int] | None = None,
|
||||
) -> Tag | None:
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
@@ -1432,7 +1422,7 @@ class Library:
|
||||
return None
|
||||
|
||||
def add_tags_to_entries(
|
||||
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
|
||||
self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int]
|
||||
) -> int:
|
||||
"""Add one or more tags to one or more entries.
|
||||
|
||||
@@ -1461,7 +1451,7 @@ class Library:
|
||||
return total_added
|
||||
|
||||
def remove_tags_from_entries(
|
||||
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
|
||||
self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int]
|
||||
) -> bool:
|
||||
"""Remove one or more tags from one or more entries."""
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
|
||||
@@ -1619,9 +1609,9 @@ class Library:
|
||||
# When calling session.add with this tag instance sqlalchemy will
|
||||
# attempt to create TagParents that already exist.
|
||||
|
||||
state = inspect(tag)
|
||||
# Prevent sqlalchemy from thinking any fields are different from what's commited
|
||||
# commited_state contains original values for fields that have changed.
|
||||
state: InstanceState[Tag] = inspect(tag)
|
||||
# Prevent sqlalchemy from thinking any fields are different from what's committed
|
||||
# committed_state contains original values for fields that have changed.
|
||||
# empty when no fields have changed
|
||||
state.committed_state.clear()
|
||||
|
||||
@@ -1679,9 +1669,9 @@ class Library:
|
||||
def update_tag(
|
||||
self,
|
||||
tag: Tag,
|
||||
parent_ids: list[int] | set[int] | None = None,
|
||||
alias_names: list[str] | set[str] | None = None,
|
||||
alias_ids: list[int] | set[int] | None = None,
|
||||
parent_ids: MutableSequence[int] | None = None,
|
||||
alias_names: MutableSequence[str] | None = None,
|
||||
alias_ids: MutableSequence[int] | None = None,
|
||||
) -> None:
|
||||
"""Edit a Tag in the Library."""
|
||||
self.add_tag(tag, parent_ids, alias_names, alias_ids)
|
||||
@@ -1735,7 +1725,13 @@ class Library:
|
||||
else:
|
||||
self.add_color(new_color_group)
|
||||
|
||||
def update_aliases(self, tag, alias_ids, alias_names, session):
|
||||
def update_aliases(
|
||||
self,
|
||||
tag: Tag,
|
||||
alias_ids: MutableSequence[int],
|
||||
alias_names: MutableSequence[str],
|
||||
session: Session,
|
||||
):
|
||||
prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all()
|
||||
|
||||
for alias in prev_aliases:
|
||||
@@ -1749,7 +1745,7 @@ class Library:
|
||||
alias = TagAlias(alias_name, tag.id)
|
||||
session.add(alias)
|
||||
|
||||
def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session):
|
||||
def update_parent_tags(self, tag: Tag, parent_ids: MutableSequence[int], session: Session):
|
||||
if tag.id in parent_ids:
|
||||
parent_ids.remove(tag.id)
|
||||
|
||||
@@ -1822,10 +1818,11 @@ class Library:
|
||||
# by older TagStudio versions.
|
||||
engine = sqlalchemy.inspect(self.engine)
|
||||
if engine and engine.has_table("Preferences"):
|
||||
pref = session.scalar(
|
||||
select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY)
|
||||
pref = unwrap(
|
||||
session.scalar(
|
||||
select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY)
|
||||
)
|
||||
)
|
||||
assert pref is not None
|
||||
pref.value = value # pyright: ignore
|
||||
session.add(pref)
|
||||
session.commit()
|
||||
@@ -1833,15 +1830,19 @@ class Library:
|
||||
logger.error("[Library][ERROR] Couldn't add default tag color namespaces", error=e)
|
||||
session.rollback()
|
||||
|
||||
def prefs(self, key: str | LibraryPrefs):
|
||||
# TODO: Remove this once the 'preferences' table is removed.
|
||||
@deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.")
|
||||
def prefs(self, key: str | LibraryPrefs): # pyright: ignore[reportUnknownParameterType]
|
||||
# load given item from Preferences table
|
||||
with Session(self.engine) as session:
|
||||
if isinstance(key, LibraryPrefs):
|
||||
return session.scalar(select(Preferences).where(Preferences.key == key.name)).value
|
||||
return session.scalar(select(Preferences).where(Preferences.key == key.name)).value # pyright: ignore
|
||||
else:
|
||||
return session.scalar(select(Preferences).where(Preferences.key == key)).value
|
||||
return session.scalar(select(Preferences).where(Preferences.key == key)).value # pyright: ignore
|
||||
|
||||
def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None:
|
||||
# TODO: Remove this once the 'preferences' table is removed.
|
||||
@deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.")
|
||||
def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None: # pyright: ignore[reportExplicitAny]
|
||||
# set given item in Preferences table
|
||||
with Session(self.engine) as session:
|
||||
# load existing preference and update value
|
||||
@@ -1873,7 +1874,7 @@ class Library:
|
||||
|
||||
# assign the field to all entries
|
||||
for entry in entries:
|
||||
for field_key, field in fields.items():
|
||||
for field_key, field in fields.items(): # pyright: ignore[reportUnknownVariableType]
|
||||
if field_key not in existing_fields:
|
||||
self.add_field_to_entry(
|
||||
entry_id=entry.id,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
|
||||
from sqlalchemy import JSON, ForeignKey, ForeignKeyConstraint, Integer, event
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -150,25 +151,28 @@ class Tag(Base):
|
||||
self.id = id # pyright: ignore[reportAttributeAccessIssue]
|
||||
super().__init__()
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"<Tag ID: {self.id} Name: {self.name}>"
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
@override
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
def __lt__(self, other: "Tag") -> bool:
|
||||
return self.name < other.name
|
||||
|
||||
def __le__(self, other) -> bool:
|
||||
def __le__(self, other: "Tag") -> bool:
|
||||
return self.name <= other.name
|
||||
|
||||
def __gt__(self, other) -> bool:
|
||||
def __gt__(self, other: "Tag") -> bool:
|
||||
return self.name > other.name
|
||||
|
||||
def __ge__(self, other) -> bool:
|
||||
def __ge__(self, other: "Tag") -> bool:
|
||||
return self.name >= other.name
|
||||
|
||||
|
||||
@@ -233,6 +237,7 @@ class Entry(Base):
|
||||
date_modified: dt | None = None,
|
||||
date_added: dt | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.folder = folder
|
||||
self.id = id # pyright: ignore[reportAttributeAccessIssue]
|
||||
@@ -280,8 +285,8 @@ class ValueType(Base):
|
||||
key: Mapped[str] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
type: Mapped[FieldTypeEnum] = mapped_column(default=FieldTypeEnum.TEXT_LINE)
|
||||
is_default: Mapped[bool]
|
||||
position: Mapped[int]
|
||||
is_default: Mapped[bool] # pyright: ignore[reportUninitializedInstanceVariable]
|
||||
position: Mapped[int] # pyright: ignore[reportUninitializedInstanceVariable]
|
||||
|
||||
# add relations to other tables
|
||||
text_fields: Mapped[list[TextField]] = relationship("TextField", back_populates="type")
|
||||
@@ -306,7 +311,7 @@ class ValueType(Base):
|
||||
|
||||
|
||||
@event.listens_for(ValueType, "before_insert")
|
||||
def slugify_field_key(mapper, connection, target):
|
||||
def slugify_field_key(mapper, connection, target): # pyright: ignore
|
||||
"""Slugify the field key before inserting into the database."""
|
||||
if not target.key:
|
||||
from tagstudio.core.library.alchemy.library import slugify
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text
|
||||
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.operators import ilike_op
|
||||
|
||||
from tagstudio.core.library.alchemy.constants import TAG_CHILDREN_ID_QUERY
|
||||
from tagstudio.core.library.alchemy.joins import TagEntry
|
||||
from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias
|
||||
from tagstudio.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
@@ -32,17 +33,6 @@ else:
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
TAG_CHILDREN_ID_QUERY = text("""
|
||||
WITH RECURSIVE ChildTags AS (
|
||||
SELECT :tag_id AS tag_id
|
||||
UNION
|
||||
SELECT tp.child_id AS tag_id
|
||||
FROM tag_parents tp
|
||||
INNER JOIN ChildTags c ON tp.parent_id = c.tag_id
|
||||
)
|
||||
SELECT tag_id FROM ChildTags;
|
||||
""")
|
||||
|
||||
|
||||
def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
|
||||
for s in FILETYPE_EQUIVALENTS:
|
||||
@@ -56,19 +46,22 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
super().__init__()
|
||||
self.lib = lib
|
||||
|
||||
def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
|
||||
@override
|
||||
def visit_or_list(self, node: ORList) -> ColumnElement[bool]: # type: ignore
|
||||
tag_ids, bool_expressions = self.__separate_tags(node.elements, only_single=False)
|
||||
if len(tag_ids) > 0:
|
||||
bool_expressions.append(self.__entry_has_any_tags(tag_ids))
|
||||
return or_(*bool_expressions)
|
||||
|
||||
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
|
||||
@override
|
||||
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: # type: ignore
|
||||
tag_ids, bool_expressions = self.__separate_tags(node.terms, only_single=True)
|
||||
if len(tag_ids) > 0:
|
||||
bool_expressions.append(self.__entry_has_all_tags(tag_ids))
|
||||
return and_(*bool_expressions)
|
||||
|
||||
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
|
||||
@override
|
||||
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: # type: ignore
|
||||
"""Returns a Boolean Expression that is true, if the Entry satisfies the constraint."""
|
||||
if len(node.properties) != 0:
|
||||
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
|
||||
@@ -119,10 +112,12 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
# raise exception if Constraint stays unhandled
|
||||
raise NotImplementedError("This type of constraint is not implemented yet")
|
||||
|
||||
def visit_property(self, node: Property) -> ColumnElement[bool]:
|
||||
@override
|
||||
def visit_property(self, node: Property) -> ColumnElement[bool]: # type: ignore
|
||||
raise NotImplementedError("This should never be reached!")
|
||||
|
||||
def visit_not(self, node: Not) -> ColumnElement[bool]:
|
||||
@override
|
||||
def visit_not(self, node: Not) -> ColumnElement[bool]: # type: ignore
|
||||
return ~self.visit(node.child)
|
||||
|
||||
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
|
||||
@@ -143,7 +138,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
)
|
||||
if not include_children:
|
||||
return tag_ids
|
||||
outp = []
|
||||
outp: list[int] = []
|
||||
for tag_id in tag_ids:
|
||||
outp.extend(list(session.scalars(TAG_CHILDREN_ID_QUERY, {"tag_id": tag_id})))
|
||||
return outp
|
||||
@@ -174,6 +169,14 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
|
||||
elif len(ids) == 1:
|
||||
tag_ids.append(ids[0])
|
||||
continue
|
||||
case ConstraintType.FileType:
|
||||
pass
|
||||
case ConstraintType.Path:
|
||||
pass
|
||||
case ConstraintType.Special:
|
||||
pass
|
||||
case _:
|
||||
raise NotImplementedError(f"Unhandled constraint: '{term.type}'")
|
||||
|
||||
bool_expressions.append(self.visit(term))
|
||||
return tag_ids, bool_expressions
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
# Copyright (C) 2025
|
||||
# Licensed under the GPL-3.0 License.
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Generic, TypeVar, Union
|
||||
from typing import Generic, TypeVar, override
|
||||
|
||||
|
||||
class ConstraintType(Enum):
|
||||
@@ -12,7 +17,7 @@ class ConstraintType(Enum):
|
||||
Special = 5
|
||||
|
||||
@staticmethod
|
||||
def from_string(text: str) -> Union["ConstraintType", None]:
|
||||
def from_string(text: str) -> "ConstraintType | None":
|
||||
return {
|
||||
"tag": ConstraintType.Tag,
|
||||
"tag_id": ConstraintType.TagID,
|
||||
@@ -24,14 +29,16 @@ class ConstraintType(Enum):
|
||||
|
||||
|
||||
class AST:
|
||||
parent: Union["AST", None] = None
|
||||
parent: "AST | None" = None
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
class_name = self.__class__.__name__
|
||||
fields = vars(self) # Get all instance variables as a dictionary
|
||||
field_str = ", ".join(f"{key}={value}" for key, value in fields.items())
|
||||
return f"{class_name}({field_str})"
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
# Copyright (C) 2025
|
||||
# Licensed under the GPL-3.0 License.
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
|
||||
from tagstudio.core.query_lang.ast import (
|
||||
AST,
|
||||
ANDList,
|
||||
@@ -27,7 +32,7 @@ class Parser:
|
||||
if self.next_token.type == TokenType.EOF:
|
||||
return ORList([])
|
||||
out = self.__or_list()
|
||||
if self.next_token.type != TokenType.EOF:
|
||||
if self.next_token.type != TokenType.EOF: # pyright: ignore[reportUnnecessaryComparison]
|
||||
raise ParsingError(self.next_token.start, self.next_token.end, "Syntax Error")
|
||||
return out
|
||||
|
||||
@@ -41,7 +46,7 @@ class Parser:
|
||||
return ORList(terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
def __is_next_or(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR"
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR" # pyright: ignore
|
||||
|
||||
def __and_list(self) -> AST:
|
||||
elements = [self.__term()]
|
||||
@@ -67,7 +72,7 @@ class Parser:
|
||||
raise self.__syntax_error("Unexpected AND")
|
||||
|
||||
def __is_next_and(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND"
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND" # pyright: ignore
|
||||
|
||||
def __term(self) -> AST:
|
||||
if self.__is_next_not():
|
||||
@@ -85,11 +90,14 @@ class Parser:
|
||||
return self.__constraint()
|
||||
|
||||
def __is_next_not(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT"
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT" # pyright: ignore
|
||||
|
||||
def __constraint(self) -> Constraint:
|
||||
if self.next_token.type == TokenType.CONSTRAINTTYPE:
|
||||
self.last_constraint_type = self.__eat(TokenType.CONSTRAINTTYPE).value
|
||||
constraint = self.__eat(TokenType.CONSTRAINTTYPE).value
|
||||
if not isinstance(constraint, ConstraintType):
|
||||
raise self.__syntax_error()
|
||||
self.last_constraint_type = constraint
|
||||
|
||||
value = self.__literal()
|
||||
|
||||
@@ -98,7 +106,7 @@ class Parser:
|
||||
self.__eat(TokenType.SBRACKETO)
|
||||
properties.append(self.__property())
|
||||
|
||||
while self.next_token.type == TokenType.COMMA:
|
||||
while self.next_token.type == TokenType.COMMA: # pyright: ignore[reportUnnecessaryComparison]
|
||||
self.__eat(TokenType.COMMA)
|
||||
properties.append(self.__property())
|
||||
|
||||
@@ -110,11 +118,16 @@ class Parser:
|
||||
key = self.__eat(TokenType.ULITERAL).value
|
||||
self.__eat(TokenType.EQUALS)
|
||||
value = self.__literal()
|
||||
if not isinstance(key, str):
|
||||
raise self.__syntax_error()
|
||||
return Property(key, value)
|
||||
|
||||
def __literal(self) -> str:
|
||||
if self.next_token.type in [TokenType.QLITERAL, TokenType.ULITERAL]:
|
||||
return self.__eat(self.next_token.type).value
|
||||
literal = self.__eat(self.next_token.type).value
|
||||
if not isinstance(literal, str):
|
||||
raise self.__syntax_error()
|
||||
return literal
|
||||
raise self.__syntax_error()
|
||||
|
||||
def __eat(self, type: TokenType) -> Token:
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
# Copyright (C) 2025
|
||||
# Licensed under the GPL-3.0 License.
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import override
|
||||
|
||||
from tagstudio.core.query_lang.ast import ConstraintType
|
||||
from tagstudio.core.query_lang.util import ParsingError
|
||||
@@ -21,12 +26,14 @@ class TokenType(Enum):
|
||||
|
||||
class Token:
|
||||
type: TokenType
|
||||
value: Any
|
||||
value: str | ConstraintType | None
|
||||
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def __init__(self, type: TokenType, value: Any, start: int, end: int) -> None:
|
||||
def __init__(
|
||||
self, type: TokenType, value: str | ConstraintType | None, start: int, end: int
|
||||
) -> None:
|
||||
self.type = type
|
||||
self.value = value
|
||||
self.start = start
|
||||
@@ -40,9 +47,11 @@ class Token:
|
||||
def EOF(pos: int) -> "Token": # noqa: N802
|
||||
return Token.from_type(TokenType.EOF, pos)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"Token({self.type}, {self.value}, {self.start}, {self.end})" # pragma: nocover
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__() # pragma: nocover
|
||||
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
# Copyright (C) 2025
|
||||
# Licensed under the GPL-3.0 License.
|
||||
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
|
||||
|
||||
|
||||
from typing import override
|
||||
|
||||
|
||||
class ParsingError(BaseException):
|
||||
start: int
|
||||
end: int
|
||||
msg: str
|
||||
|
||||
def __init__(self, start: int, end: int, msg: str = "Syntax Error") -> None:
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.msg = msg
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"Syntax Error {self.start}->{self.end}: {self.msg}" # pragma: nocover
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__() # pragma: nocover
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
import structlog
|
||||
|
||||
from tagstudio.core.constants import TS_FOLDER_NAME
|
||||
from tagstudio.core.library.alchemy.fields import _FieldID
|
||||
from tagstudio.core.library.alchemy.fields import FieldID
|
||||
from tagstudio.core.library.alchemy.library import Library
|
||||
from tagstudio.core.library.alchemy.models import Entry
|
||||
|
||||
@@ -46,27 +46,27 @@ class TagStudioCore:
|
||||
return {}
|
||||
|
||||
if source == "twitter":
|
||||
info[_FieldID.DESCRIPTION] = json_dump["content"].strip()
|
||||
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
info[FieldID.DESCRIPTION] = json_dump["content"].strip()
|
||||
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
elif source == "instagram":
|
||||
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
elif source == "artstation":
|
||||
info[_FieldID.TITLE] = json_dump["title"].strip()
|
||||
info[_FieldID.ARTIST] = json_dump["user"]["full_name"].strip()
|
||||
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[_FieldID.TAGS] = json_dump["tags"]
|
||||
info[FieldID.TITLE] = json_dump["title"].strip()
|
||||
info[FieldID.ARTIST] = json_dump["user"]["full_name"].strip()
|
||||
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[FieldID.TAGS] = json_dump["tags"]
|
||||
# info["tags"] = [x for x in json_dump["mediums"]["name"]]
|
||||
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
elif source == "newgrounds":
|
||||
# info["title"] = json_dump["title"]
|
||||
# info["artist"] = json_dump["artist"]
|
||||
# info["description"] = json_dump["description"]
|
||||
info[_FieldID.TAGS] = json_dump["tags"]
|
||||
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
info[_FieldID.ARTIST] = json_dump["user"].strip()
|
||||
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[_FieldID.SOURCE] = json_dump["post_url"].strip()
|
||||
info[FieldID.TAGS] = json_dump["tags"]
|
||||
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
|
||||
info[FieldID.ARTIST] = json_dump["user"].strip()
|
||||
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
|
||||
info[FieldID.SOURCE] = json_dump["post_url"].strip()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error handling sidecar file.", path=_filepath)
|
||||
|
||||
@@ -71,5 +71,5 @@ class TagDatabasePanel(TagSearchPanel):
|
||||
if result != QMessageBox.Ok: # type: ignore
|
||||
return
|
||||
|
||||
self.lib.remove_tag(tag)
|
||||
self.lib.remove_tag(tag.id)
|
||||
self.update_tags()
|
||||
|
||||
@@ -55,7 +55,7 @@ from tagstudio.core.library.alchemy.enums import (
|
||||
ItemType,
|
||||
SortingModeEnum,
|
||||
)
|
||||
from tagstudio.core.library.alchemy.fields import _FieldID
|
||||
from tagstudio.core.library.alchemy.fields import FieldID
|
||||
from tagstudio.core.library.alchemy.library import Library, LibraryStatus
|
||||
from tagstudio.core.library.alchemy.models import Entry
|
||||
from tagstudio.core.library.ignore import Ignore
|
||||
@@ -1129,7 +1129,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
elif name == MacroID.BUILD_URL:
|
||||
url = TagStudioCore.build_url(entry, source)
|
||||
if url is not None:
|
||||
self.lib.add_field_to_entry(entry.id, field_id=_FieldID.SOURCE, value=url)
|
||||
self.lib.add_field_to_entry(entry.id, field_id=FieldID.SOURCE, value=url)
|
||||
elif name == MacroID.MATCH:
|
||||
TagStudioCore.match_conditions(self.lib, entry.id)
|
||||
elif name == MacroID.CLEAN_URL:
|
||||
|
||||
@@ -13,8 +13,8 @@ import structlog
|
||||
from tagstudio.core.enums import DefaultEnum, LibraryPrefs
|
||||
from tagstudio.core.library.alchemy.enums import BrowsingState
|
||||
from tagstudio.core.library.alchemy.fields import (
|
||||
FieldID, # pyright: ignore[reportPrivateUsage]
|
||||
TextField,
|
||||
_FieldID, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from tagstudio.core.library.alchemy.library import Library
|
||||
from tagstudio.core.library.alchemy.models import Entry, Tag
|
||||
@@ -174,7 +174,7 @@ def test_remove_tag(library: Library, generate_tag: Callable[..., Tag]):
|
||||
|
||||
tag_count = len(library.tags)
|
||||
|
||||
library.remove_tag(tag)
|
||||
library.remove_tag(tag.id)
|
||||
assert len(library.tags) == tag_count - 1
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry):
|
||||
path=Path("xxx"),
|
||||
fields=[
|
||||
TextField(
|
||||
type_key=_FieldID.NOTES.name,
|
||||
type_key=FieldID.NOTES.name,
|
||||
value="notes",
|
||||
position=0,
|
||||
)
|
||||
@@ -292,8 +292,8 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry):
|
||||
# make sure fields are there after getting it from the library again
|
||||
assert len(entry.fields) == 2
|
||||
assert {x.type_key for x in entry.fields} == {
|
||||
_FieldID.TITLE.name,
|
||||
_FieldID.NOTES.name,
|
||||
FieldID.TITLE.name,
|
||||
FieldID.NOTES.name,
|
||||
}
|
||||
|
||||
|
||||
@@ -308,14 +308,14 @@ def test_merge_entries(library: Library):
|
||||
folder=folder,
|
||||
path=Path("a"),
|
||||
fields=[
|
||||
TextField(type_key=_FieldID.AUTHOR.name, value="Author McAuthorson", position=0),
|
||||
TextField(type_key=_FieldID.DESCRIPTION.name, value="test description", position=2),
|
||||
TextField(type_key=FieldID.AUTHOR.name, value="Author McAuthorson", position=0),
|
||||
TextField(type_key=FieldID.DESCRIPTION.name, value="test description", position=2),
|
||||
],
|
||||
)
|
||||
b = Entry(
|
||||
folder=folder,
|
||||
path=Path("b"),
|
||||
fields=[TextField(type_key=_FieldID.NOTES.name, value="test note", position=1)],
|
||||
fields=[TextField(type_key=FieldID.NOTES.name, value="test note", position=1)],
|
||||
)
|
||||
ids = library.add_entries([a, b])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user