refactor: fix most pyright issues in library/alchemy/ (#1103)

* refactor: fix most pyright issues in library/alchemy/

* chore: implement review feedback
This commit is contained in:
Travis Abendshien
2025-09-06 14:20:05 -07:00
committed by GitHub
parent f49cb4fade
commit cee64a8c31
15 changed files with 214 additions and 143 deletions

View File

@@ -85,7 +85,11 @@ ignore_errors = true
qt_api = "pyside6"
[tool.pyright]
ignore = ["src/tagstudio/qt/previews/vendored/pydub/", ".venv/**"]
ignore = [
".venv/**",
"src/tagstudio/core/library/json/",
"src/tagstudio/qt/previews/vendored/pydub/",
]
include = ["src/tagstudio", "tests"]
reportAny = false
reportIgnoreCommentWithoutRule = false

View File

@@ -23,3 +23,14 @@ WITH RECURSIVE ChildTags AS (
)
SELECT * FROM ChildTags;
""")
TAG_CHILDREN_ID_QUERY = text("""
WITH RECURSIVE ChildTags AS (
SELECT :tag_id AS tag_id
UNION
SELECT tp.child_id AS tag_id
FROM tag_parents tp
INNER JOIN ChildTags c ON tp.parent_id = c.tag_id
)
SELECT tag_id FROM ChildTags;
""")

View File

@@ -4,6 +4,7 @@
from pathlib import Path
from typing import override
import structlog
from sqlalchemy import Dialect, Engine, String, TypeDecorator, create_engine, text
@@ -19,12 +20,14 @@ class PathType(TypeDecorator):
impl = String
cache_ok = True
def process_bind_param(self, value: Path, dialect: Dialect):
@override
def process_bind_param(self, value: Path | None, dialect: Dialect):
if value is not None:
return Path(value).as_posix()
return None
def process_result_value(self, value: str, dialect: Dialect):
@override
def process_result_value(self, value: str | None, dialect: Dialect):
if value is not None:
return Path(value)
return None

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
@@ -32,7 +32,7 @@ class BaseField(Base):
@declared_attr
def type(self) -> Mapped[ValueType]:
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore # pyright: ignore[reportArgumentType]
@declared_attr
def entry_id(self) -> Mapped[int]:
@@ -40,19 +40,20 @@ class BaseField(Base):
@declared_attr
def entry(self) -> Mapped[Entry]:
return relationship(foreign_keys=[self.entry_id]) # type: ignore
return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType]
@declared_attr
def position(self) -> Mapped[int]:
return mapped_column(default=0)
@override
def __hash__(self):
return hash(self.__key())
def __key(self):
def __key(self): # pyright: ignore[reportUnknownParameterType]
raise NotImplementedError
value: Any
value: Any # pyright: ignore
class BooleanField(BaseField):
@@ -63,7 +64,8 @@ class BooleanField(BaseField):
def __key(self):
return (self.type, self.value)
def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, BooleanField):
return self.__key() == value.__key()
raise NotImplementedError
@@ -74,10 +76,11 @@ class TextField(BaseField):
value: Mapped[str | None]
def __key(self) -> tuple:
def __key(self) -> tuple[ValueType, str | None]:
return self.type, self.value
def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, TextField):
return self.__key() == value.__key()
elif isinstance(value, DatetimeField):
@@ -93,7 +96,8 @@ class DatetimeField(BaseField):
def __key(self):
return (self.type, self.value)
def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, DatetimeField):
return self.__key() == value.__key()
raise NotImplementedError
@@ -107,7 +111,7 @@ class DefaultField:
is_default: bool = field(default=False)
class _FieldID(Enum):
class FieldID(Enum):
"""Only for bootstrapping content of DB table."""
TITLE = DefaultField(id=0, name="Title", type=FieldTypeEnum.TEXT_LINE, is_default=True)

View File

@@ -2,12 +2,16 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
# NOTE: This file contains necessary use of deprecated first-party code until that
# code is removed in a future version (prefs).
# pyright: reportDeprecated=false
import re
import shutil
import time
import unicodedata
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, MutableSequence
from dataclasses import dataclass
from datetime import UTC, datetime
from os import makedirs
@@ -18,7 +22,7 @@ from warnings import catch_warnings
import sqlalchemy
import structlog
from humanfriendly import format_timespan
from humanfriendly import format_timespan # pyright: ignore[reportUnknownVariableType]
from sqlalchemy import (
URL,
ColumnExpressionArgument,
@@ -40,6 +44,7 @@ from sqlalchemy import (
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
InstanceState,
Session,
contains_eager,
joinedload,
@@ -47,6 +52,7 @@ from sqlalchemy.orm import (
noload,
selectinload,
)
from typing_extensions import deprecated
from tagstudio.core.constants import (
BACKUP_FOLDER_NAME,
@@ -81,8 +87,8 @@ from tagstudio.core.library.alchemy.enums import (
from tagstudio.core.library.alchemy.fields import (
BaseField,
DatetimeField,
FieldID,
TextField,
_FieldID,
)
from tagstudio.core.library.alchemy.joins import TagEntry, TagParent
from tagstudio.core.library.alchemy.models import (
@@ -210,9 +216,9 @@ class Library:
"""Class for the Library object, and all CRUD operations made upon it."""
library_dir: Path | None = None
storage_path: Path | str | None
storage_path: Path | str | None = None
engine: Engine | None = None
folder: Folder | None
folder: Folder | None = None
included_files: set[Path] = set()
def __init__(self) -> None:
@@ -224,7 +230,7 @@ class Library:
def close(self):
if self.engine:
self.engine.dispose()
self.library_dir: Path | None = None
self.library_dir = None
self.storage_path = None
self.folder = None
self.included_files = set()
@@ -300,8 +306,8 @@ class Library:
]
)
for entry in json_lib.entries:
for field in entry.fields:
for k, v in field.items():
for field in entry.fields: # pyright: ignore[reportUnknownVariableType]
for k, v in field.items(): # pyright: ignore[reportUnknownVariableType]
# Old tag fields get added as tags
if k in LEGACY_TAG_FIELD_IDS:
self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v)
@@ -319,8 +325,8 @@ class Library:
end_time = time.time()
logger.info(f"Library Converted! ({format_timespan(end_time - start_time)})")
def get_field_name_from_id(self, field_id: int) -> _FieldID:
for f in _FieldID:
def get_field_name_from_id(self, field_id: int) -> FieldID | None:
for f in FieldID:
if field_id == f.value.id:
return f
return None
@@ -482,7 +488,7 @@ class Library:
except IntegrityError:
session.rollback()
for field in _FieldID:
for field in FieldID:
try:
session.add(
ValueType(
@@ -562,7 +568,7 @@ class Library:
# Repair "Description" fields with a TEXT_LINE key instead of a TEXT_BOX key.
desc_stmd = (
update(ValueType)
.where(ValueType.key == _FieldID.DESCRIPTION.name)
.where(ValueType.key == FieldID.DESCRIPTION.name)
.values(type=FieldTypeEnum.TEXT_BOX.name)
)
session.execute(desc_stmd)
@@ -697,8 +703,8 @@ class Library:
logger.error("[ERROR][Library] Could not generate '.ts_ignore' file!", error=e)
# Load legacy extension data
extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore[reportAssignmentType]
is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore[reportAssignmentType]
extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore
is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore
# Copy extensions to '.ts_ignore' file
if ts_ignore.exists():
@@ -720,12 +726,6 @@ class Library:
)
return [x.as_field for x in types]
def delete_item(self, item):
logger.info("deleting item", item=item)
with Session(self.engine) as session:
session.delete(item)
session.commit()
def get_entry(self, entry_id: int) -> Entry | None:
"""Load entry without joins."""
with Session(self.engine) as session:
@@ -794,7 +794,7 @@ class Library:
entries = dict((e.id, e) for e in session.scalars(statement))
return [entries[id] for id in entry_ids]
def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
def get_entries_full(self, entry_ids: MutableSequence[int]) -> Iterator[Entry]:
"""Load entry and join with all joins and all tags."""
with Session(self.engine) as session:
statement = select(Entry).where(Entry.id.in_(set(entry_ids)))
@@ -864,7 +864,7 @@ class Library:
@property
def entries_count(self) -> int:
with Session(self.engine) as session:
return session.scalar(select(func.count(Entry.id)))
return unwrap(session.scalar(select(func.count(Entry.id))))
def all_entries(self, with_joins: bool = False) -> Iterator[Entry]:
"""Load entries without joins."""
@@ -906,7 +906,7 @@ class Library:
return list(tags_list)
def verify_ts_folder(self, library_dir: Path) -> bool:
def verify_ts_folder(self, library_dir: Path | None) -> bool:
"""Verify/create folders required by TagStudio.
Returns:
@@ -960,7 +960,7 @@ class Library:
with Session(self.engine) as session:
return session.query(exists().where(Entry.path == path)).scalar()
def get_paths(self, glob: str | None = None, limit: int = -1) -> list[str]:
def get_paths(self, limit: int = -1) -> list[str]:
path_strings: list[str] = []
with Session(self.engine) as session:
if limit > 0:
@@ -1020,7 +1020,7 @@ class Library:
ids = []
count = 0
for row in rows:
id, count = row._tuple()
id, count = row._tuple() # pyright: ignore[reportPrivateUsage]
ids.append(id)
end_time = time.time()
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")
@@ -1109,17 +1109,13 @@ class Library:
session.commit()
return True
def remove_tag(self, tag: Tag):
def remove_tag(self, tag_id: int):
with Session(self.engine, expire_on_commit=False) as session:
try:
child_tags = session.scalars(
select(TagParent).where(TagParent.child_id == tag.id)
select(TagParent).where(TagParent.child_id == tag_id)
).all()
tags_query = select(Tag).options(
selectinload(Tag.parent_tags), selectinload(Tag.aliases)
)
tag = session.scalar(tags_query.where(Tag.id == tag.id))
aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id))
aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag_id))
for alias in aliases or []:
session.delete(alias)
@@ -1130,24 +1126,18 @@ class Library:
disam_stmt = (
update(Tag)
.where(Tag.disambiguation_id == tag.id)
.where(Tag.disambiguation_id == tag_id)
.values(disambiguation_id=None)
)
session.execute(disam_stmt)
session.flush()
session.delete(tag)
session.query(Tag).filter_by(id=tag_id).delete()
session.commit()
session.expunge(tag)
return tag
except IntegrityError as e:
logger.error(e)
session.rollback()
return None
def update_field_position(
self,
field_class: type[BaseField],
@@ -1247,7 +1237,7 @@ class Library:
def get_value_type(self, field_key: str) -> ValueType:
with Session(self.engine) as session:
field = session.scalar(select(ValueType).where(ValueType.key == field_key))
field = unwrap(session.scalar(select(ValueType).where(ValueType.key == field_key)))
session.expunge(field)
return field
@@ -1256,7 +1246,7 @@ class Library:
entry_id: int,
*,
field: ValueType | None = None,
field_id: _FieldID | str | None = None,
field_id: FieldID | str | None = None,
value: str | datetime | None = None,
) -> bool:
logger.info(
@@ -1270,9 +1260,9 @@ class Library:
assert bool(field) != (field_id is not None)
if not field:
if isinstance(field_id, _FieldID):
if isinstance(field_id, FieldID):
field_id = field_id.name
field = self.get_value_type(field_id)
field = self.get_value_type(unwrap(field_id))
field_model: TextField | DatetimeField
if field.type in (FieldTypeEnum.TEXT_LINE, FieldTypeEnum.TEXT_BOX):
@@ -1407,9 +1397,9 @@ class Library:
def add_tag(
self,
tag: Tag,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
parent_ids: MutableSequence[int] | None = None,
alias_names: MutableSequence[str] | None = None,
alias_ids: MutableSequence[int] | None = None,
) -> Tag | None:
with Session(self.engine, expire_on_commit=False) as session:
try:
@@ -1432,7 +1422,7 @@ class Library:
return None
def add_tags_to_entries(
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int]
) -> int:
"""Add one or more tags to one or more entries.
@@ -1461,7 +1451,7 @@ class Library:
return total_added
def remove_tags_from_entries(
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int]
) -> bool:
"""Remove one or more tags from one or more entries."""
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
@@ -1619,9 +1609,9 @@ class Library:
# When calling session.add with this tag instance sqlalchemy will
# attempt to create TagParents that already exist.
state = inspect(tag)
# Prevent sqlalchemy from thinking any fields are different from what's commited
# commited_state contains original values for fields that have changed.
state: InstanceState[Tag] = inspect(tag)
# Prevent sqlalchemy from thinking any fields are different from what's committed
# committed_state contains original values for fields that have changed.
# empty when no fields have changed
state.committed_state.clear()
@@ -1679,9 +1669,9 @@ class Library:
def update_tag(
self,
tag: Tag,
parent_ids: list[int] | set[int] | None = None,
alias_names: list[str] | set[str] | None = None,
alias_ids: list[int] | set[int] | None = None,
parent_ids: MutableSequence[int] | None = None,
alias_names: MutableSequence[str] | None = None,
alias_ids: MutableSequence[int] | None = None,
) -> None:
"""Edit a Tag in the Library."""
self.add_tag(tag, parent_ids, alias_names, alias_ids)
@@ -1735,7 +1725,13 @@ class Library:
else:
self.add_color(new_color_group)
def update_aliases(self, tag, alias_ids, alias_names, session):
def update_aliases(
self,
tag: Tag,
alias_ids: MutableSequence[int],
alias_names: MutableSequence[str],
session: Session,
):
prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all()
for alias in prev_aliases:
@@ -1749,7 +1745,7 @@ class Library:
alias = TagAlias(alias_name, tag.id)
session.add(alias)
def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session):
def update_parent_tags(self, tag: Tag, parent_ids: MutableSequence[int], session: Session):
if tag.id in parent_ids:
parent_ids.remove(tag.id)
@@ -1822,10 +1818,11 @@ class Library:
# by older TagStudio versions.
engine = sqlalchemy.inspect(self.engine)
if engine and engine.has_table("Preferences"):
pref = session.scalar(
select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY)
pref = unwrap(
session.scalar(
select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY)
)
)
assert pref is not None
pref.value = value # pyright: ignore
session.add(pref)
session.commit()
@@ -1833,15 +1830,19 @@ class Library:
logger.error("[Library][ERROR] Couldn't add default tag color namespaces", error=e)
session.rollback()
def prefs(self, key: str | LibraryPrefs):
# TODO: Remove this once the 'preferences' table is removed.
@deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.")
def prefs(self, key: str | LibraryPrefs): # pyright: ignore[reportUnknownParameterType]
# load given item from Preferences table
with Session(self.engine) as session:
if isinstance(key, LibraryPrefs):
return session.scalar(select(Preferences).where(Preferences.key == key.name)).value
return session.scalar(select(Preferences).where(Preferences.key == key.name)).value # pyright: ignore
else:
return session.scalar(select(Preferences).where(Preferences.key == key)).value
return session.scalar(select(Preferences).where(Preferences.key == key)).value # pyright: ignore
def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None:
# TODO: Remove this once the 'preferences' table is removed.
@deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.")
def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None: # pyright: ignore[reportExplicitAny]
# set given item in Preferences table
with Session(self.engine) as session:
# load existing preference and update value
@@ -1873,7 +1874,7 @@ class Library:
# assign the field to all entries
for entry in entries:
for field_key, field in fields.items():
for field_key, field in fields.items(): # pyright: ignore[reportUnknownVariableType]
if field_key not in existing_fields:
self.add_field_to_entry(
entry_id=entry.id,

View File

@@ -4,6 +4,7 @@
from datetime import datetime as dt
from pathlib import Path
from typing import override
from sqlalchemy import JSON, ForeignKey, ForeignKeyConstraint, Integer, event
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -150,25 +151,28 @@ class Tag(Base):
self.id = id # pyright: ignore[reportAttributeAccessIssue]
super().__init__()
@override
def __str__(self) -> str:
return f"<Tag ID: {self.id} Name: {self.name}>"
@override
def __repr__(self) -> str:
return self.__str__()
@override
def __hash__(self) -> int:
return hash(self.id)
def __lt__(self, other) -> bool:
def __lt__(self, other: "Tag") -> bool:
return self.name < other.name
def __le__(self, other) -> bool:
def __le__(self, other: "Tag") -> bool:
return self.name <= other.name
def __gt__(self, other) -> bool:
def __gt__(self, other: "Tag") -> bool:
return self.name > other.name
def __ge__(self, other) -> bool:
def __ge__(self, other: "Tag") -> bool:
return self.name >= other.name
@@ -233,6 +237,7 @@ class Entry(Base):
date_modified: dt | None = None,
date_added: dt | None = None,
) -> None:
super().__init__()
self.path = path
self.folder = folder
self.id = id # pyright: ignore[reportAttributeAccessIssue]
@@ -280,8 +285,8 @@ class ValueType(Base):
key: Mapped[str] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False)
type: Mapped[FieldTypeEnum] = mapped_column(default=FieldTypeEnum.TEXT_LINE)
is_default: Mapped[bool]
position: Mapped[int]
is_default: Mapped[bool] # pyright: ignore[reportUninitializedInstanceVariable]
position: Mapped[int] # pyright: ignore[reportUninitializedInstanceVariable]
# add relations to other tables
text_fields: Mapped[list[TextField]] = relationship("TextField", back_populates="type")
@@ -306,7 +311,7 @@ class ValueType(Base):
@event.listens_for(ValueType, "before_insert")
def slugify_field_key(mapper, connection, target):
def slugify_field_key(mapper, connection, target): # pyright: ignore
"""Slugify the field key before inserting into the database."""
if not target.key:
from tagstudio.core.library.alchemy.library import slugify

View File

@@ -3,13 +3,14 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, override
import structlog
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.sql.operators import ilike_op
from tagstudio.core.library.alchemy.constants import TAG_CHILDREN_ID_QUERY
from tagstudio.core.library.alchemy.joins import TagEntry
from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias
from tagstudio.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
@@ -32,17 +33,6 @@ else:
logger = structlog.get_logger(__name__)
TAG_CHILDREN_ID_QUERY = text("""
WITH RECURSIVE ChildTags AS (
SELECT :tag_id AS tag_id
UNION
SELECT tp.child_id AS tag_id
FROM tag_parents tp
INNER JOIN ChildTags c ON tp.parent_id = c.tag_id
)
SELECT tag_id FROM ChildTags;
""")
def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
for s in FILETYPE_EQUIVALENTS:
@@ -56,19 +46,22 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
super().__init__()
self.lib = lib
def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
@override
def visit_or_list(self, node: ORList) -> ColumnElement[bool]: # type: ignore
tag_ids, bool_expressions = self.__separate_tags(node.elements, only_single=False)
if len(tag_ids) > 0:
bool_expressions.append(self.__entry_has_any_tags(tag_ids))
return or_(*bool_expressions)
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
@override
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: # type: ignore
tag_ids, bool_expressions = self.__separate_tags(node.terms, only_single=True)
if len(tag_ids) > 0:
bool_expressions.append(self.__entry_has_all_tags(tag_ids))
return and_(*bool_expressions)
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
@override
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: # type: ignore
"""Returns a Boolean Expression that is true, if the Entry satisfies the constraint."""
if len(node.properties) != 0:
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
@@ -119,10 +112,12 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
# raise exception if Constraint stays unhandled
raise NotImplementedError("This type of constraint is not implemented yet")
def visit_property(self, node: Property) -> ColumnElement[bool]:
@override
def visit_property(self, node: Property) -> ColumnElement[bool]: # type: ignore
raise NotImplementedError("This should never be reached!")
def visit_not(self, node: Not) -> ColumnElement[bool]:
@override
def visit_not(self, node: Not) -> ColumnElement[bool]: # type: ignore
return ~self.visit(node.child)
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
@@ -143,7 +138,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
)
if not include_children:
return tag_ids
outp = []
outp: list[int] = []
for tag_id in tag_ids:
outp.extend(list(session.scalars(TAG_CHILDREN_ID_QUERY, {"tag_id": tag_id})))
return outp
@@ -174,6 +169,14 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
elif len(ids) == 1:
tag_ids.append(ids[0])
continue
case ConstraintType.FileType:
pass
case ConstraintType.Path:
pass
case ConstraintType.Special:
pass
case _:
raise NotImplementedError(f"Unhandled constraint: '{term.type}'")
bool_expressions.append(self.visit(term))
return tag_ids, bool_expressions

View File

@@ -1,6 +1,11 @@
# Copyright (C) 2025
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, TypeVar, Union
from typing import Generic, TypeVar, override
class ConstraintType(Enum):
@@ -12,7 +17,7 @@ class ConstraintType(Enum):
Special = 5
@staticmethod
def from_string(text: str) -> Union["ConstraintType", None]:
def from_string(text: str) -> "ConstraintType | None":
return {
"tag": ConstraintType.Tag,
"tag_id": ConstraintType.TagID,
@@ -24,14 +29,16 @@ class ConstraintType(Enum):
class AST:
parent: Union["AST", None] = None
parent: "AST | None" = None
@override
def __str__(self):
class_name = self.__class__.__name__
fields = vars(self) # Get all instance variables as a dictionary
field_str = ", ".join(f"{key}={value}" for key, value in fields.items())
return f"{class_name}({field_str})"
@override
def __repr__(self) -> str:
return self.__str__()

View File

@@ -1,3 +1,8 @@
# Copyright (C) 2025
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from tagstudio.core.query_lang.ast import (
AST,
ANDList,
@@ -27,7 +32,7 @@ class Parser:
if self.next_token.type == TokenType.EOF:
return ORList([])
out = self.__or_list()
if self.next_token.type != TokenType.EOF:
if self.next_token.type != TokenType.EOF: # pyright: ignore[reportUnnecessaryComparison]
raise ParsingError(self.next_token.start, self.next_token.end, "Syntax Error")
return out
@@ -41,7 +46,7 @@ class Parser:
return ORList(terms) if len(terms) > 1 else terms[0]
def __is_next_or(self) -> bool:
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR"
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR" # pyright: ignore
def __and_list(self) -> AST:
elements = [self.__term()]
@@ -67,7 +72,7 @@ class Parser:
raise self.__syntax_error("Unexpected AND")
def __is_next_and(self) -> bool:
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND"
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND" # pyright: ignore
def __term(self) -> AST:
if self.__is_next_not():
@@ -85,11 +90,14 @@ class Parser:
return self.__constraint()
def __is_next_not(self) -> bool:
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT"
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT" # pyright: ignore
def __constraint(self) -> Constraint:
if self.next_token.type == TokenType.CONSTRAINTTYPE:
self.last_constraint_type = self.__eat(TokenType.CONSTRAINTTYPE).value
constraint = self.__eat(TokenType.CONSTRAINTTYPE).value
if not isinstance(constraint, ConstraintType):
raise self.__syntax_error()
self.last_constraint_type = constraint
value = self.__literal()
@@ -98,7 +106,7 @@ class Parser:
self.__eat(TokenType.SBRACKETO)
properties.append(self.__property())
while self.next_token.type == TokenType.COMMA:
while self.next_token.type == TokenType.COMMA: # pyright: ignore[reportUnnecessaryComparison]
self.__eat(TokenType.COMMA)
properties.append(self.__property())
@@ -110,11 +118,16 @@ class Parser:
key = self.__eat(TokenType.ULITERAL).value
self.__eat(TokenType.EQUALS)
value = self.__literal()
if not isinstance(key, str):
raise self.__syntax_error()
return Property(key, value)
def __literal(self) -> str:
if self.next_token.type in [TokenType.QLITERAL, TokenType.ULITERAL]:
return self.__eat(self.next_token.type).value
literal = self.__eat(self.next_token.type).value
if not isinstance(literal, str):
raise self.__syntax_error()
return literal
raise self.__syntax_error()
def __eat(self, type: TokenType) -> Token:

View File

@@ -1,5 +1,10 @@
# Copyright (C) 2025
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from enum import Enum
from typing import Any
from typing import override
from tagstudio.core.query_lang.ast import ConstraintType
from tagstudio.core.query_lang.util import ParsingError
@@ -21,12 +26,14 @@ class TokenType(Enum):
class Token:
type: TokenType
value: Any
value: str | ConstraintType | None
start: int
end: int
def __init__(self, type: TokenType, value: Any, start: int, end: int) -> None:
def __init__(
self, type: TokenType, value: str | ConstraintType | None, start: int, end: int
) -> None:
self.type = type
self.value = value
self.start = start
@@ -40,9 +47,11 @@ class Token:
def EOF(pos: int) -> "Token": # noqa: N802
return Token.from_type(TokenType.EOF, pos)
@override
def __str__(self) -> str:
return f"Token({self.type}, {self.value}, {self.start}, {self.end})" # pragma: nocover
@override
def __repr__(self) -> str:
return self.__str__() # pragma: nocover

View File

@@ -1,15 +1,26 @@
# Copyright (C) 2025
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from typing import override
class ParsingError(BaseException):
start: int
end: int
msg: str
def __init__(self, start: int, end: int, msg: str = "Syntax Error") -> None:
super().__init__()
self.start = start
self.end = end
self.msg = msg
@override
def __str__(self) -> str:
return f"Syntax Error {self.start}->{self.end}: {self.msg}" # pragma: nocover
@override
def __repr__(self) -> str:
return self.__str__() # pragma: nocover

View File

@@ -10,7 +10,7 @@ from pathlib import Path
import structlog
from tagstudio.core.constants import TS_FOLDER_NAME
from tagstudio.core.library.alchemy.fields import _FieldID
from tagstudio.core.library.alchemy.fields import FieldID
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Entry
@@ -46,27 +46,27 @@ class TagStudioCore:
return {}
if source == "twitter":
info[_FieldID.DESCRIPTION] = json_dump["content"].strip()
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
info[FieldID.DESCRIPTION] = json_dump["content"].strip()
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "instagram":
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "artstation":
info[_FieldID.TITLE] = json_dump["title"].strip()
info[_FieldID.ARTIST] = json_dump["user"]["full_name"].strip()
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.TAGS] = json_dump["tags"]
info[FieldID.TITLE] = json_dump["title"].strip()
info[FieldID.ARTIST] = json_dump["user"]["full_name"].strip()
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
info[FieldID.TAGS] = json_dump["tags"]
# info["tags"] = [x for x in json_dump["mediums"]["name"]]
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "newgrounds":
# info["title"] = json_dump["title"]
# info["artist"] = json_dump["artist"]
# info["description"] = json_dump["description"]
info[_FieldID.TAGS] = json_dump["tags"]
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
info[_FieldID.ARTIST] = json_dump["user"].strip()
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.SOURCE] = json_dump["post_url"].strip()
info[FieldID.TAGS] = json_dump["tags"]
info[FieldID.DATE_PUBLISHED] = json_dump["date"]
info[FieldID.ARTIST] = json_dump["user"].strip()
info[FieldID.DESCRIPTION] = json_dump["description"].strip()
info[FieldID.SOURCE] = json_dump["post_url"].strip()
except Exception:
logger.exception("Error handling sidecar file.", path=_filepath)

View File

@@ -71,5 +71,5 @@ class TagDatabasePanel(TagSearchPanel):
if result != QMessageBox.Ok: # type: ignore
return
self.lib.remove_tag(tag)
self.lib.remove_tag(tag.id)
self.update_tags()

View File

@@ -55,7 +55,7 @@ from tagstudio.core.library.alchemy.enums import (
ItemType,
SortingModeEnum,
)
from tagstudio.core.library.alchemy.fields import _FieldID
from tagstudio.core.library.alchemy.fields import FieldID
from tagstudio.core.library.alchemy.library import Library, LibraryStatus
from tagstudio.core.library.alchemy.models import Entry
from tagstudio.core.library.ignore import Ignore
@@ -1129,7 +1129,7 @@ class QtDriver(DriverMixin, QObject):
elif name == MacroID.BUILD_URL:
url = TagStudioCore.build_url(entry, source)
if url is not None:
self.lib.add_field_to_entry(entry.id, field_id=_FieldID.SOURCE, value=url)
self.lib.add_field_to_entry(entry.id, field_id=FieldID.SOURCE, value=url)
elif name == MacroID.MATCH:
TagStudioCore.match_conditions(self.lib, entry.id)
elif name == MacroID.CLEAN_URL:

View File

@@ -13,8 +13,8 @@ import structlog
from tagstudio.core.enums import DefaultEnum, LibraryPrefs
from tagstudio.core.library.alchemy.enums import BrowsingState
from tagstudio.core.library.alchemy.fields import (
FieldID, # pyright: ignore[reportPrivateUsage]
TextField,
_FieldID, # pyright: ignore[reportPrivateUsage]
)
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Entry, Tag
@@ -174,7 +174,7 @@ def test_remove_tag(library: Library, generate_tag: Callable[..., Tag]):
tag_count = len(library.tags)
library.remove_tag(tag)
library.remove_tag(tag.id)
assert len(library.tags) == tag_count - 1
@@ -270,7 +270,7 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry):
path=Path("xxx"),
fields=[
TextField(
type_key=_FieldID.NOTES.name,
type_key=FieldID.NOTES.name,
value="notes",
position=0,
)
@@ -292,8 +292,8 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry):
# make sure fields are there after getting it from the library again
assert len(entry.fields) == 2
assert {x.type_key for x in entry.fields} == {
_FieldID.TITLE.name,
_FieldID.NOTES.name,
FieldID.TITLE.name,
FieldID.NOTES.name,
}
@@ -308,14 +308,14 @@ def test_merge_entries(library: Library):
folder=folder,
path=Path("a"),
fields=[
TextField(type_key=_FieldID.AUTHOR.name, value="Author McAuthorson", position=0),
TextField(type_key=_FieldID.DESCRIPTION.name, value="test description", position=2),
TextField(type_key=FieldID.AUTHOR.name, value="Author McAuthorson", position=0),
TextField(type_key=FieldID.DESCRIPTION.name, value="test description", position=2),
],
)
b = Entry(
folder=folder,
path=Path("b"),
fields=[TextField(type_key=_FieldID.NOTES.name, value="test note", position=1)],
fields=[TextField(type_key=FieldID.NOTES.name, value="test note", position=1)],
)
ids = library.add_entries([a, b])