feat: implement query language (#606)
* add files
* fix: term was parsing ANDList instead of ORList
* make mypy happy
* ruff format
* add missing todo
* add more constraint types
* add parent property to AST
* add BaseVisitor class
* make mypy happy
* add __init__.py
* Revert "make mypy happy"
This reverts commit 926d0dd2e79d06203e84e2f83c06c7fe5b33de23.
* refactoring and fixes
* rudimentary search field integration
* fix: check for None properly
* fix: Entries without Tags are now searchable
* make mypy happy
* Revert "fix: Entries without Tags are now searchable"
This reverts commit 19b40af7480b0c068b81b642b51536a9ec96d030.
* fix: changed joins to outerjoins and added missing outerjoin
* use query lang instead of tag_id FilterState
* add todos
* fix: remove unnecessary line that broke search when searching for exact name
* fix tag search
* refactoring
* fix: path now uses GLOB operator for proper GLOBs
* refactoring: remove FilterState.id and implement Library.get_entry_full as replacement
* fix: use default value notation instead of if None statement in __post_init__
* remove obsolete Search Mode UI and related code
* ruff fixes
* remove obsolete tests
* fix: item_thumb didn't query entries correctly
* fix: search_library now correctly returns the number of *unique* entries
* make mypy happy
* implement NOT
* remove obsolete filename search
* remove summary as it is not applicable anymore
* finish refactoring of FilterState
* implement special:untagged
* fix: make mypy happy
* Revert changes to search_tags in favor of changes from #604
* fix: also port test changes
* fix: remove unnecessary import
* fix: remove unused dataclass
* fix: AND now works correctly with tags
* simplify structure of parsed AST
* add performance logging
* perf: Improve performance of search by reducing number of required joins from 4 to 1
* perf: double NOT is now optimized out of the AST
* fix: bug where pages would show fewer than the configured number of entries
* Revert "add performance logging"
This reverts commit c3c7d7546d.
* fix: tag_id search was broken
* somewhat adapt the existing autocompletion to this PR
* perf: Use Relational Division Queries to improve Query Execution Time
* fix: raise Exception so as to not fail silently
* fix: Parser bug broke parentheses
* little bit of clean up
* remove unnecessary comment
* add library for testing search
* feat: add basic tests
* fix: AND queries containing just one tag were broken
* chore: remove debug code
* feat: more tests
* refactor: more consistent name for variable
Co-authored-by: Travis Abendshien <46939827+CyanVoxel@users.noreply.github.com>
* fix: ruff check complaint over double import
---------
Co-authored-by: Travis Abendshien <46939827+CyanVoxel@users.noreply.github.com>
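
A hedged sketch (not part of the original PR text) of the query syntax this PR introduces, inferred from the parser and the tests in the diff below; FilterState.from_search_query is the entry point added here:

    # Example queries accepted by the new language (see the tests below for expected result counts).
    from src.core.library.alchemy.enums import FilterState

    FilterState.from_search_query("circle square")           # bare terms are tag constraints, AND is implicit
    FilterState.from_search_query("orange OR green")         # boolean operators are case-insensitive
    FilterState.from_search_query("circle and not square")   # NOT negates a single term
    FilterState.from_search_query("filetype:png and (tag:square or green)")  # typed constraints and grouping
    FilterState.from_search_query("path:*comp* special:untagged")            # GLOB paths and special constraints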
@@ -1,7 +1,10 @@
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.query_lang import AST as Query # noqa: N811
|
||||
from src.core.query_lang import Constraint, ConstraintType, Parser
|
||||
|
||||
|
||||
class TagColor(enum.IntEnum):
|
||||
DEFAULT = 1
|
||||
@@ -50,13 +53,6 @@ class TagColor(enum.IntEnum):
|
||||
return TagColor.DEFAULT
|
||||
|
||||
|
||||
class SearchMode(enum.IntEnum):
|
||||
"""Operational modes for item searching."""
|
||||
|
||||
AND = 0
|
||||
OR = 1
|
||||
|
||||
|
||||
class ItemType(enum.Enum):
|
||||
ENTRY = 0
|
||||
COLLATION = 1
|
||||
@@ -68,71 +64,12 @@ class FilterState:
|
||||
"""Represent a state of the Library grid view."""
|
||||
|
||||
# these should remain
|
||||
page_index: int | None = None
|
||||
page_size: int | None = None
|
||||
search_mode: SearchMode = SearchMode.AND # TODO - actually implement this
|
||||
page_index: int | None = 0
|
||||
page_size: int | None = 500
|
||||
|
||||
# these should be erased on update
|
||||
# tag name
|
||||
tag: str | None = None
|
||||
# tag ID
|
||||
tag_id: int | None = None
|
||||
|
||||
# entry id
|
||||
id: int | None = None
|
||||
# whole path
|
||||
path: Path | str | None = None
|
||||
# file name
|
||||
name: str | None = None
|
||||
# file type
|
||||
filetype: str | None = None
|
||||
mediatype: str | None = None
|
||||
|
||||
# a generic query to be parsed
|
||||
query: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
# strip values automatically
|
||||
if query := (self.query and self.query.strip()):
|
||||
# parse the value
|
||||
if ":" in query:
|
||||
kind, _, value = query.partition(":")
|
||||
value = value.replace('"', "")
|
||||
else:
|
||||
# default to tag search
|
||||
kind, value = "tag", query
|
||||
|
||||
if kind == "tag_id":
|
||||
self.tag_id = int(value)
|
||||
elif kind == "tag":
|
||||
self.tag = value
|
||||
elif kind == "path":
|
||||
self.path = value
|
||||
elif kind == "name":
|
||||
self.name = value
|
||||
elif kind == "id":
|
||||
self.id = int(self.id) if str(self.id).isnumeric() else self.id
|
||||
elif kind == "filetype":
|
||||
self.filetype = value
|
||||
elif kind == "mediatype":
|
||||
self.mediatype = value
|
||||
|
||||
else:
|
||||
self.tag = self.tag and self.tag.strip()
|
||||
self.tag_id = int(self.tag_id) if str(self.tag_id).isnumeric() else self.tag_id
|
||||
self.path = self.path and str(self.path).strip()
|
||||
self.name = self.name and self.name.strip()
|
||||
self.id = int(self.id) if str(self.id).isnumeric() else self.id
|
||||
|
||||
if self.page_index is None:
|
||||
self.page_index = 0
|
||||
if self.page_size is None:
|
||||
self.page_size = 500
|
||||
|
||||
@property
|
||||
def summary(self):
|
||||
"""Show query summary."""
|
||||
return self.query or self.tag or self.name or self.tag_id or self.path or self.id
|
||||
# Abstract Syntax Tree of the current Search Query
|
||||
ast: Query = None
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
@@ -142,6 +79,37 @@ class FilterState:
|
||||
def offset(self):
|
||||
return self.page_size * self.page_index
|
||||
|
||||
@classmethod
|
||||
def show_all(cls) -> "FilterState":
|
||||
return FilterState()
|
||||
|
||||
@classmethod
|
||||
def from_search_query(cls, search_query: str) -> "FilterState":
|
||||
return cls(ast=Parser(search_query).parse())
|
||||
|
||||
@classmethod
|
||||
def from_tag_id(cls, tag_id: int | str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.TagID, str(tag_id), []))
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, path: Path | str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.Path, str(path).strip(), []))
|
||||
|
||||
@classmethod
|
||||
def from_mediatype(cls, mediatype: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.MediaType, mediatype, []))
|
||||
|
||||
@classmethod
|
||||
def from_filetype(cls, filetype: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.FileType, filetype, []))
|
||||
|
||||
@classmethod
|
||||
def from_tag_name(cls, tag_name: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.Tag, tag_name, []))
|
||||
|
||||
def with_page_size(self, page_size: int) -> "FilterState":
|
||||
return replace(self, page_size=page_size)
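
A minimal usage sketch (not in the diff) of the new FilterState constructors, assuming an already opened Library instance named lib:

    state = FilterState.from_search_query("tag:square or filetype:png").with_page_size(100)
    results = lib.search_library(state)  # returns a SearchResult
    print(results.total_count, len(results.items))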
|
||||
|
||||
|
||||
class FieldTypeEnum(enum.Enum):
|
||||
TEXT_LINE = "Text Line"
|
||||
|
||||
@@ -28,7 +28,6 @@ from sqlalchemy import (
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import (
|
||||
Session,
|
||||
aliased,
|
||||
contains_eager,
|
||||
make_transient,
|
||||
selectinload,
|
||||
@@ -42,7 +41,6 @@ from ...constants import (
|
||||
TS_FOLDER_NAME,
|
||||
)
|
||||
from ...enums import LibraryPrefs
|
||||
from ...media_types import MediaCategories
|
||||
from .db import make_tables
|
||||
from .enums import FieldTypeEnum, FilterState, TagColor
|
||||
from .fields import (
|
||||
@@ -54,6 +52,7 @@ from .fields import (
|
||||
)
|
||||
from .joins import TagField, TagSubtag
|
||||
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType
|
||||
from .visitors import SQLBoolExpressionBuilder
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@@ -402,6 +401,29 @@ class Library:
|
||||
make_transient(entry)
|
||||
return entry
|
||||
|
||||
def get_entry_full(self, entry_id: int) -> Entry | None:
|
||||
"""Load entry an join with all joins and all tags."""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(Entry).where(Entry.id == entry_id)
|
||||
statement = (
|
||||
statement.outerjoin(Entry.text_fields)
|
||||
.outerjoin(Entry.datetime_fields)
|
||||
.outerjoin(Entry.tag_box_fields)
|
||||
)
|
||||
statement = statement.options(
|
||||
selectinload(Entry.text_fields),
|
||||
selectinload(Entry.datetime_fields),
|
||||
selectinload(Entry.tag_box_fields)
|
||||
.joinedload(TagBoxField.tags)
|
||||
.options(selectinload(Tag.aliases), selectinload(Tag.subtags)),
|
||||
)
|
||||
entry = session.scalar(statement)
|
||||
if not entry:
|
||||
return None
|
||||
session.expunge(entry)
|
||||
make_transient(entry)
|
||||
return entry
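
For illustration (not in the diff): get_entry_full replaces the removed FilterState(id=...) searches, e.g.

    entry = lib.get_entry_full(entry_id)  # lib and entry_id are placeholders
    if entry is not None:
        print(entry.path, [tag.name for tag in entry.tags])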
|
||||
|
||||
@property
|
||||
def entries_count(self) -> int:
|
||||
with Session(self.engine) as session:
|
||||
@@ -518,63 +540,18 @@ class Library:
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
statement = select(Entry)
|
||||
|
||||
if search.tag:
|
||||
SubtagAlias = aliased(Tag) # noqa: N806
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
.join(TagBoxField.tags)
|
||||
.outerjoin(Tag.aliases)
|
||||
.outerjoin(SubtagAlias, Tag.subtags)
|
||||
.where(
|
||||
or_(
|
||||
Tag.name.ilike(search.tag),
|
||||
Tag.shorthand.ilike(search.tag),
|
||||
TagAlias.name.ilike(search.tag),
|
||||
SubtagAlias.name.ilike(search.tag),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif search.tag_id:
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
.join(TagBoxField.tags)
|
||||
.where(Tag.id == search.tag_id)
|
||||
)
|
||||
|
||||
elif search.id:
|
||||
statement = statement.where(Entry.id == search.id)
|
||||
elif search.name:
|
||||
statement = select(Entry).where(
|
||||
and_(
|
||||
Entry.path.ilike(f"%{search.name}%"),
|
||||
# don't match directory names (i.e. with a following slash)
|
||||
~Entry.path.ilike(f"%{search.name}%/%"),
|
||||
)
|
||||
)
|
||||
elif search.path:
|
||||
search_str = str(search.path).replace("*", "%")
|
||||
statement = statement.where(Entry.path.ilike(search_str))
|
||||
elif search.filetype:
|
||||
statement = statement.where(Entry.suffix.ilike(f"{search.filetype}"))
|
||||
elif search.mediatype:
|
||||
extensions: set[str] = set[str]()
|
||||
for media_cat in MediaCategories.ALL_CATEGORIES:
|
||||
if search.mediatype == media_cat.name:
|
||||
extensions = extensions | media_cat.extensions
|
||||
break
|
||||
# just need to map it to search db - suffixes do not have '.'
|
||||
statement = statement.where(
|
||||
Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions))
|
||||
if search.ast:
|
||||
statement = statement.outerjoin(Entry.tag_box_fields).where(
|
||||
SQLBoolExpressionBuilder(self).visit(search.ast)
|
||||
)
|
||||
|
||||
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
|
||||
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
|
||||
|
||||
if not search.id: # if `id` is set, we don't need to filter by extensions
|
||||
if extensions and is_exclude_list:
|
||||
statement = statement.where(Entry.suffix.notin_(extensions))
|
||||
elif extensions:
|
||||
statement = statement.where(Entry.suffix.in_(extensions))
|
||||
if extensions and is_exclude_list:
|
||||
statement = statement.where(Entry.suffix.notin_(extensions))
|
||||
elif extensions:
|
||||
statement = statement.where(Entry.suffix.in_(extensions))
|
||||
|
||||
statement = statement.options(
|
||||
selectinload(Entry.text_fields),
|
||||
@@ -584,6 +561,8 @@ class Library:
|
||||
.options(selectinload(Tag.aliases), selectinload(Tag.subtags)),
|
||||
)
|
||||
|
||||
statement = statement.distinct(Entry.id)
|
||||
|
||||
query_count = select(func.count()).select_from(statement.alias("entries"))
|
||||
count_all: int = session.execute(query_count).scalar()
|
||||
|
||||
@@ -597,7 +576,7 @@ class Library:
|
||||
|
||||
res = SearchResult(
|
||||
total_count=count_all,
|
||||
items=list(session.scalars(statement).unique()),
|
||||
items=list(session.scalars(statement)),
|
||||
)
|
||||
|
||||
session.expunge_all()
|
||||
|
||||
new file: tagstudio/src/core/library/alchemy/visitors.py (125 lines)
@@ -0,0 +1,125 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import and_, distinct, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from src.core.media_types import MediaCategories
|
||||
from src.core.query_lang import BaseVisitor
|
||||
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
|
||||
|
||||
from .joins import TagField
|
||||
from .models import Entry, Tag, TagAlias, TagBoxField
|
||||
|
||||
# workaround to have autocompletion in the Editor
|
||||
if TYPE_CHECKING:
|
||||
from .library import Library
|
||||
else:
|
||||
Library = None # don't import .library because of circular imports
|
||||
|
||||
|
||||
class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def __init__(self, lib: Library) -> None:
|
||||
super().__init__()
|
||||
self.lib = lib
|
||||
|
||||
def visit_or_list(self, node: ORList) -> ColumnExpressionArgument:
|
||||
return or_(*[self.visit(element) for element in node.elements])
|
||||
|
||||
def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
|
||||
tag_ids: list[int] = []
|
||||
bool_expressions: list[ColumnExpressionArgument] = []
|
||||
|
||||
# Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately
|
||||
for term in node.terms:
|
||||
if isinstance(term, Constraint) and len(term.properties) == 0:
|
||||
match term.type:
|
||||
case ConstraintType.TagID:
|
||||
tag_ids.append(int(term.value))
|
||||
continue
|
||||
case ConstraintType.Tag:
|
||||
if len(ids := self.__get_tag_ids(term.value)) == 1:
|
||||
tag_ids.append(ids[0])
|
||||
continue
|
||||
|
||||
bool_expressions.append(self.__entry_satisfies_ast(term))
|
||||
|
||||
# If there are at least two tag ids use a relational division query
|
||||
# to efficiently check all of them
|
||||
if len(tag_ids) > 1:
|
||||
bool_expressions.append(self.__entry_has_all_tags(tag_ids))
|
||||
# If there is just one tag id, check the normal way
|
||||
elif len(tag_ids) == 1:
|
||||
bool_expressions.append(
|
||||
self.__entry_satisfies_expression(TagField.tag_id == tag_ids[0])
|
||||
)
|
||||
|
||||
return and_(*bool_expressions)
|
||||
|
||||
def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
|
||||
if len(node.properties) != 0:
|
||||
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
|
||||
|
||||
if node.type == ConstraintType.Tag:
|
||||
return TagBoxField.tags.any(Tag.id.in_(self.__get_tag_ids(node.value)))
|
||||
elif node.type == ConstraintType.TagID:
|
||||
return TagBoxField.tags.any(Tag.id == int(node.value))
|
||||
elif node.type == ConstraintType.Path:
|
||||
return Entry.path.op("GLOB")(node.value)
|
||||
elif node.type == ConstraintType.MediaType:
|
||||
extensions: set[str] = set[str]()
|
||||
for media_cat in MediaCategories.ALL_CATEGORIES:
|
||||
if node.value == media_cat.name:
|
||||
extensions = extensions | media_cat.extensions
|
||||
break
|
||||
return Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions))
|
||||
elif node.type == ConstraintType.FileType:
|
||||
return Entry.suffix.ilike(node.value)
|
||||
elif node.type == ConstraintType.Special: # noqa: SIM102 unnecessary once there is a second special constraint
|
||||
if node.value.lower() == "untagged":
|
||||
return ~Entry.id.in_(
|
||||
select(Entry.id).join(Entry.tag_box_fields).join(TagBoxField.tags)
|
||||
)
|
||||
|
||||
# raise exception if Constraint stays unhandled
|
||||
raise NotImplementedError("This type of constraint is not implemented yet")
|
||||
|
||||
def visit_property(self, node: Property) -> None:
|
||||
raise NotImplementedError("This should never be reached!")
|
||||
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
|
||||
return ~self.__entry_satisfies_ast(node.child)
|
||||
|
||||
def __get_tag_ids(self, tag_name: str) -> list[int]:
|
||||
"""Given a tag name find the ids of all tags that this name could refer to."""
|
||||
with Session(self.lib.engine, expire_on_commit=False) as session:
|
||||
return list(
|
||||
session.scalars(
|
||||
select(Tag.id)
|
||||
.where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
|
||||
.union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
|
||||
)
|
||||
)
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
# Relational Division Query
|
||||
return Entry.id.in_(
|
||||
select(Entry.id)
|
||||
.outerjoin(TagBoxField)
|
||||
.outerjoin(TagField)
|
||||
.where(TagField.tag_id.in_(tag_ids))
|
||||
.group_by(Entry.id)
|
||||
.having(func.count(distinct(TagField.tag_id)) == len(tag_ids))
|
||||
)
|
||||
|
||||
def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the partial query."""
|
||||
return self.__entry_satisfies_expression(self.visit(partial_query))
|
||||
|
||||
def __entry_satisfies_expression(
|
||||
self, expr: ColumnExpressionArgument
|
||||
) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry satisfies the column expression."""
|
||||
return Entry.id.in_(
|
||||
select(Entry.id).outerjoin(Entry.tag_box_fields).outerjoin(TagField).where(expr)
|
||||
)
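
For reference, a hedged sketch of the SQL shape the relational-division helper above is aiming for; the table and column names in the comment are illustrative, not taken from the actual schema:

    # Roughly:
    #   SELECT entries.id
    #   FROM entries
    #   LEFT JOIN tag_box_fields ON ...
    #   LEFT JOIN tag_fields ON ...
    #   WHERE tag_fields.tag_id IN (:tag_ids)
    #   GROUP BY entries.id
    #   HAVING COUNT(DISTINCT tag_fields.tag_id) = :number_of_tag_ids
    # i.e. an entry qualifies only if it is linked to every requested tag id,
    # replacing the previous one-join-per-tag approach with a single grouped join.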
new file: tagstudio/src/core/query_lang/__init__.py (11 lines)
@@ -0,0 +1,11 @@
|
||||
from src.core.query_lang.ast import ( # noqa
|
||||
AST,
|
||||
ANDList,
|
||||
BaseVisitor,
|
||||
Constraint,
|
||||
ConstraintType,
|
||||
ORList,
|
||||
Property,
|
||||
)
|
||||
from src.core.query_lang.parser import Parser # noqa
|
||||
from src.core.query_lang.util import ParsingError # noqa
|
||||
new file: tagstudio/src/core/query_lang/ast.py (126 lines)
@@ -0,0 +1,126 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
|
||||
class ConstraintType(Enum):
|
||||
Tag = 0
|
||||
TagID = 1
|
||||
MediaType = 2
|
||||
FileType = 3
|
||||
Path = 4
|
||||
Special = 5
|
||||
|
||||
@staticmethod
|
||||
def from_string(text: str) -> "ConstraintType":
|
||||
return {
|
||||
"tag": ConstraintType.Tag,
|
||||
"tag_id": ConstraintType.TagID,
|
||||
"mediatype": ConstraintType.MediaType,
|
||||
"filetype": ConstraintType.FileType,
|
||||
"path": ConstraintType.Path,
|
||||
"special": ConstraintType.Special,
|
||||
}.get(text.lower(), None)
|
||||
|
||||
|
||||
class AST:
|
||||
parent: "AST" = None
|
||||
|
||||
def __str__(self):
|
||||
class_name = self.__class__.__name__
|
||||
fields = vars(self) # Get all instance variables as a dictionary
|
||||
field_str = ", ".join(f"{key}={value}" for key, value in fields.items())
|
||||
return f"{class_name}({field_str})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class ANDList(AST):
|
||||
terms: list[AST]
|
||||
|
||||
def __init__(self, terms: list[AST]) -> None:
|
||||
super().__init__()
|
||||
for term in terms:
|
||||
term.parent = self
|
||||
self.terms = terms
|
||||
|
||||
|
||||
class ORList(AST):
|
||||
elements: list[AST]
|
||||
|
||||
def __init__(self, elements: list[AST]) -> None:
|
||||
super().__init__()
|
||||
for element in elements:
|
||||
element.parent = self
|
||||
self.elements = elements
|
||||
|
||||
|
||||
class Constraint(AST):
|
||||
type: ConstraintType
|
||||
value: str
|
||||
properties: list["Property"]
|
||||
|
||||
def __init__(self, type: ConstraintType, value: str, properties: list["Property"]) -> None:
|
||||
super().__init__()
|
||||
for prop in properties:
|
||||
prop.parent = self
|
||||
self.type = type
|
||||
self.value = value
|
||||
self.properties = properties
|
||||
|
||||
|
||||
class Property(AST):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
def __init__(self, key: str, value: str) -> None:
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
|
||||
class Not(AST):
|
||||
child: AST
|
||||
|
||||
def __init__(self, child: AST) -> None:
|
||||
super().__init__()
|
||||
self.child = child
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseVisitor(ABC, Generic[T]):
|
||||
def visit(self, node: AST) -> T:
|
||||
if isinstance(node, ANDList):
|
||||
return self.visit_and_list(node)
|
||||
elif isinstance(node, ORList):
|
||||
return self.visit_or_list(node)
|
||||
elif isinstance(node, Constraint):
|
||||
return self.visit_constraint(node)
|
||||
elif isinstance(node, Property):
|
||||
return self.visit_property(node)
|
||||
elif isinstance(node, Not):
|
||||
return self.visit_not(node)
|
||||
raise Exception(f"Unknown Node Type of {node}") # pragma: nocover
|
||||
|
||||
@abstractmethod
|
||||
def visit_and_list(self, node: ANDList) -> T:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
@abstractmethod
|
||||
def visit_or_list(self, node: ORList) -> T:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint(self, node: Constraint) -> T:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
@abstractmethod
|
||||
def visit_property(self, node: Property) -> T:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
@abstractmethod
|
||||
def visit_not(self, node: Not) -> T:
|
||||
raise NotImplementedError() # pragma: nocover
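
A hedged sketch (not part of this file) of how BaseVisitor is meant to be subclassed; this toy visitor collects the tag names referenced by a query AST:

    class TagNameCollector(BaseVisitor[list[str]]):
        """Example only: gather the values of all Tag constraints in an AST."""

        def visit_and_list(self, node: ANDList) -> list[str]:
            return [name for term in node.terms for name in self.visit(term)]

        def visit_or_list(self, node: ORList) -> list[str]:
            return [name for element in node.elements for name in self.visit(element)]

        def visit_constraint(self, node: Constraint) -> list[str]:
            return [node.value] if node.type == ConstraintType.Tag else []

        def visit_property(self, node: Property) -> list[str]:
            return []

        def visit_not(self, node: Not) -> list[str]:
            return self.visit(node.child)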
new file: tagstudio/src/core/query_lang/parser.py (120 lines)
@@ -0,0 +1,120 @@
|
||||
from .ast import AST, ANDList, Constraint, Not, ORList, Property
|
||||
from .tokenizer import ConstraintType, Token, Tokenizer, TokenType
|
||||
from .util import ParsingError
|
||||
|
||||
|
||||
class Parser:
|
||||
text: str
|
||||
tokenizer: Tokenizer
|
||||
next_token: Token
|
||||
|
||||
last_constraint_type: ConstraintType = ConstraintType.Tag
|
||||
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
self.tokenizer = Tokenizer(self.text)
|
||||
self.next_token = self.tokenizer.get_next_token()
|
||||
|
||||
def parse(self) -> AST:
|
||||
if self.next_token.type == TokenType.EOF:
|
||||
return ORList([])
|
||||
out = self.__or_list()
|
||||
if self.next_token.type != TokenType.EOF:
|
||||
raise ParsingError(self.next_token.start, self.next_token.end, "Syntax Error")
|
||||
return out
|
||||
|
||||
def __or_list(self) -> AST:
|
||||
terms = [self.__and_list()]
|
||||
|
||||
while self.__is_next_or():
|
||||
self.__eat(TokenType.ULITERAL)
|
||||
terms.append(self.__and_list())
|
||||
|
||||
return ORList(terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
def __is_next_or(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR"
|
||||
|
||||
def __and_list(self) -> AST:
|
||||
elements = [self.__term()]
|
||||
while (
|
||||
self.next_token.type
|
||||
in [
|
||||
TokenType.QLITERAL,
|
||||
TokenType.ULITERAL,
|
||||
TokenType.CONSTRAINTTYPE,
|
||||
TokenType.RBRACKETO,
|
||||
]
|
||||
and not self.__is_next_or()
|
||||
):
|
||||
self.__skip_and()
|
||||
elements.append(self.__term())
|
||||
return ANDList(elements) if len(elements) > 1 else elements[0]
|
||||
|
||||
def __skip_and(self) -> None:
|
||||
if self.__is_next_and():
|
||||
self.__eat(TokenType.ULITERAL)
|
||||
|
||||
if self.__is_next_and():
|
||||
raise self.__syntax_error("Unexpected AND")
|
||||
|
||||
def __is_next_and(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND"
|
||||
|
||||
def __term(self) -> AST:
|
||||
if self.__is_next_not():
|
||||
self.__eat(TokenType.ULITERAL)
|
||||
term = self.__term()
|
||||
if isinstance(term, Not): # instead of Not(Not(child)) return child
|
||||
return term.child
|
||||
return Not(term)
|
||||
if self.next_token.type == TokenType.RBRACKETO:
|
||||
self.__eat(TokenType.RBRACKETO)
|
||||
out = self.__or_list()
|
||||
self.__eat(TokenType.RBRACKETC)
|
||||
return out
|
||||
else:
|
||||
return self.__constraint()
|
||||
|
||||
def __is_next_not(self) -> bool:
|
||||
return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT"
|
||||
|
||||
def __constraint(self) -> Constraint:
|
||||
if self.next_token.type == TokenType.CONSTRAINTTYPE:
|
||||
self.last_constraint_type = self.__eat(TokenType.CONSTRAINTTYPE).value
|
||||
|
||||
value = self.__literal()
|
||||
|
||||
properties = []
|
||||
if self.next_token.type == TokenType.SBRACKETO:
|
||||
self.__eat(TokenType.SBRACKETO)
|
||||
properties.append(self.__property())
|
||||
|
||||
while self.next_token.type == TokenType.COMMA:
|
||||
self.__eat(TokenType.COMMA)
|
||||
properties.append(self.__property())
|
||||
|
||||
self.__eat(TokenType.SBRACKETC)
|
||||
|
||||
return Constraint(self.last_constraint_type, value, properties)
|
||||
|
||||
def __property(self) -> Property:
|
||||
key = self.__eat(TokenType.ULITERAL).value
|
||||
self.__eat(TokenType.EQUALS)
|
||||
value = self.__literal()
|
||||
return Property(key, value)
|
||||
|
||||
def __literal(self) -> str:
|
||||
if self.next_token.type in [TokenType.QLITERAL, TokenType.ULITERAL]:
|
||||
return self.__eat(self.next_token.type).value
|
||||
raise self.__syntax_error()
|
||||
|
||||
def __eat(self, type: TokenType) -> Token:
|
||||
if self.next_token.type != type:
|
||||
raise self.__syntax_error(f"expected {type} found {self.next_token.type}")
|
||||
out = self.next_token
|
||||
self.next_token = self.tokenizer.get_next_token()
|
||||
return out
|
||||
|
||||
def __syntax_error(self, msg: str = "Syntax Error") -> ParsingError:
|
||||
return ParsingError(self.next_token.start, self.next_token.end, msg)
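
A quick usage sketch (not in the diff). The grammar implied by the methods above is roughly: or_list = and_list ("OR" and_list)*, and_list = term (["AND"] term)*, term = "NOT" term | "(" or_list ")" | constraint, constraint = [CONSTRAINTTYPE] literal ["[" property ("," property)* "]"]; single-element AND/OR lists collapse to their only child.

    from src.core.query_lang import Parser

    ast = Parser("filetype:png and (tag:square or green)").parse()
    # -> ANDList(terms=[Constraint(FileType, "png", ...), ORList(elements=[...])])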
new file: tagstudio/src/core/query_lang/tokenizer.py (147 lines)
@@ -0,0 +1,147 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .ast import ConstraintType
|
||||
from .util import ParsingError
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
EOF = -1
|
||||
QLITERAL = 0 # Quoted Literal
|
||||
ULITERAL = 1 # Unquoted Literal (does not contain ":", " ", "[", "]", "(", ")", "=", ",")
|
||||
RBRACKETO = 2 # Round Bracket Open
|
||||
RBRACKETC = 3 # Round Bracket Close
|
||||
SBRACKETO = 4 # Square Bracket Open
|
||||
SBRACKETC = 5 # Square Bracket Close
|
||||
CONSTRAINTTYPE = 6
|
||||
COLON = 10
|
||||
COMMA = 11
|
||||
EQUALS = 12
|
||||
|
||||
|
||||
class Token:
|
||||
type: TokenType
|
||||
value: Any
|
||||
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def __init__(self, type: TokenType, value: Any, start: int = None, end: int = None) -> None:
|
||||
self.type = type
|
||||
self.value = value
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
@staticmethod
|
||||
def from_type(type: TokenType, pos: int = None) -> "Token":
|
||||
return Token(type, None, pos, pos)
|
||||
|
||||
@staticmethod
|
||||
def EOF() -> "Token": # noqa: N802
|
||||
return Token.from_type(TokenType.EOF)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Token({self.type}, {self.value}, {self.start}, {self.end})" # pragma: nocover
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__() # pragma: nocover
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
text: str
|
||||
pos: int
|
||||
current_char: str
|
||||
|
||||
ESCAPABLE_CHARS = ["\\", '"', '"']
|
||||
NOT_IN_ULITERAL = [":", " ", "[", "]", "(", ")", "=", ","]
|
||||
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
self.pos = 0
|
||||
self.current_char = self.text[self.pos] if len(text) > 0 else None
|
||||
|
||||
def get_next_token(self) -> Token:
|
||||
self.__skip_whitespace()
|
||||
if self.current_char is None:
|
||||
return Token.EOF()
|
||||
|
||||
if self.current_char in ("'", '"'):
|
||||
return self.__quoted_string()
|
||||
elif self.current_char == "(":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.RBRACKETO, self.pos - 1)
|
||||
elif self.current_char == ")":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.RBRACKETC, self.pos - 1)
|
||||
elif self.current_char == "[":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.SBRACKETO, self.pos - 1)
|
||||
elif self.current_char == "]":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.SBRACKETC, self.pos - 1)
|
||||
elif self.current_char == ",":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.COMMA, self.pos - 1)
|
||||
elif self.current_char == "=":
|
||||
self.__advance()
|
||||
return Token.from_type(TokenType.EQUALS, self.pos - 1)
|
||||
else:
|
||||
return self.__unquoted_string_or_constraint_type()
|
||||
|
||||
def __unquoted_string_or_constraint_type(self) -> Token:
|
||||
out = ""
|
||||
|
||||
start = self.pos
|
||||
|
||||
while self.current_char not in self.NOT_IN_ULITERAL and self.current_char is not None:
|
||||
out += self.current_char
|
||||
self.__advance()
|
||||
|
||||
end = self.pos - 1
|
||||
|
||||
if self.current_char == ":":
|
||||
if len(out) == 0:
|
||||
raise ParsingError(self.pos, self.pos)
|
||||
self.__advance()
|
||||
constraint_type = ConstraintType.from_string(out)
|
||||
if constraint_type is None:
|
||||
raise ParsingError(start, end, f'Invalid ConstraintType "{out}"')
|
||||
return Token(TokenType.CONSTRAINTTYPE, constraint_type, start, end)
|
||||
else:
|
||||
return Token(TokenType.ULITERAL, out, start, end)
|
||||
|
||||
def __quoted_string(self) -> Token:
|
||||
start = self.pos
|
||||
quote = self.current_char
|
||||
self.__advance()
|
||||
escape = False
|
||||
out = ""
|
||||
|
||||
while escape or self.current_char != quote:
|
||||
if escape:
|
||||
escape = False
|
||||
if self.current_char not in Tokenizer.ESCAPABLE_CHARS:
|
||||
out += "\\"
|
||||
else:
|
||||
out += self.current_char
|
||||
self.__advance()
|
||||
continue
|
||||
if self.current_char == "\\":
|
||||
escape = True
|
||||
else:
|
||||
out += self.current_char
|
||||
self.__advance()
|
||||
end = self.pos
|
||||
self.__advance()
|
||||
return Token(TokenType.QLITERAL, out, start, end)
|
||||
|
||||
def __advance(self) -> None:
|
||||
if self.pos < len(self.text) - 1:
|
||||
self.pos += 1
|
||||
self.current_char = self.text[self.pos]
|
||||
else:
|
||||
self.current_char = None
|
||||
|
||||
def __skip_whitespace(self) -> None:
|
||||
while self.current_char is not None and self.current_char.isspace():
|
||||
self.__advance()
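
For illustration (not in the diff), the token stream produced for a small query:

    tok = Tokenizer('tag:"red square" or circle')
    # Expected sequence (positions omitted):
    #   CONSTRAINTTYPE(ConstraintType.Tag), QLITERAL("red square"),
    #   ULITERAL("or"), ULITERAL("circle"), EOF
    token = tok.get_next_token()
    while token.type != TokenType.EOF:
        print(token)
        token = tok.get_next_token()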
new file: tagstudio/src/core/query_lang/util.py (15 lines)
@@ -0,0 +1,15 @@
|
||||
class ParsingError(BaseException):
|
||||
start: int
|
||||
end: int
|
||||
msg: str
|
||||
|
||||
def __init__(self, start: int, end: int, msg: str = "Syntax Error") -> None:
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Syntax Error {self.start}->{self.end}: {self.msg}" # pragma: nocover
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__() # pragma: nocover
|
||||
@@ -50,7 +50,7 @@ class DupeRegistry:
|
||||
continue
|
||||
|
||||
results = self.library.search_library(
|
||||
FilterState(path=path_relative),
|
||||
FilterState.from_path(path_relative),
|
||||
)
|
||||
|
||||
if not results:
|
||||
|
||||
@@ -74,14 +74,6 @@ class Ui_MainWindow(QMainWindow):
|
||||
spacerItem = QSpacerItem(40, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
|
||||
self.horizontalLayout_3.addItem(spacerItem)
|
||||
|
||||
# Search type selector
|
||||
self.comboBox_2 = QComboBox(self.centralwidget)
|
||||
self.comboBox_2.setMinimumSize(QSize(165, 0))
|
||||
self.comboBox_2.setObjectName("comboBox_2")
|
||||
self.comboBox_2.addItem("")
|
||||
self.comboBox_2.addItem("")
|
||||
self.horizontalLayout_3.addWidget(self.comboBox_2)
|
||||
|
||||
# Thumbnail Size placeholder
|
||||
self.thumb_size_combobox = QComboBox(self.centralwidget)
|
||||
self.thumb_size_combobox.setObjectName(u"thumbSizeComboBox")
|
||||
@@ -214,9 +206,6 @@ class Ui_MainWindow(QMainWindow):
|
||||
self.searchButton.setText(
|
||||
QCoreApplication.translate("MainWindow", u"Search", None))
|
||||
|
||||
# Search type selector
|
||||
self.comboBox_2.setItemText(0, QCoreApplication.translate("MainWindow", "And (Includes All Tags)"))
|
||||
self.comboBox_2.setItemText(1, QCoreApplication.translate("MainWindow", "Or (Includes Any Tag)"))
|
||||
self.thumb_size_combobox.setCurrentText("")
|
||||
|
||||
# Thumbnail size selector
|
||||
|
||||
@@ -65,14 +65,14 @@ from src.core.constants import (
|
||||
)
|
||||
from src.core.driver import DriverMixin
|
||||
from src.core.enums import LibraryPrefs, MacroID, SettingItems
|
||||
from src.core.library.alchemy import Library
|
||||
from src.core.library.alchemy.enums import (
|
||||
FieldTypeEnum,
|
||||
FilterState,
|
||||
ItemType,
|
||||
SearchMode,
|
||||
)
|
||||
from src.core.library.alchemy.fields import _FieldID
|
||||
from src.core.library.alchemy.library import Entry, Library, LibraryStatus
|
||||
from src.core.library.alchemy.library import Entry, LibraryStatus
|
||||
from src.core.media_types import MediaCategories
|
||||
from src.core.ts_core import TagStudioCore
|
||||
from src.core.utils.refresh_dir import RefreshDirTracker
|
||||
@@ -140,7 +140,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
self.rm: ResourceManager = ResourceManager()
|
||||
self.args = args
|
||||
self.frame_content = []
|
||||
self.filter = FilterState()
|
||||
self.filter = FilterState.show_all()
|
||||
self.pages_count = 0
|
||||
|
||||
self.scrollbar_pos = 0
|
||||
@@ -468,7 +468,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
]
|
||||
self.item_thumbs: list[ItemThumb] = []
|
||||
self.thumb_renderers: list[ThumbRenderer] = []
|
||||
self.filter = FilterState()
|
||||
self.filter = FilterState.show_all()
|
||||
self.init_library_window()
|
||||
self.migration_modal: JsonMigrationModal = None
|
||||
|
||||
@@ -510,18 +510,17 @@ class QtDriver(DriverMixin, QObject):
|
||||
# Search Button
|
||||
search_button: QPushButton = self.main_window.searchButton
|
||||
search_button.clicked.connect(
|
||||
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
|
||||
lambda: self.filter_items(
|
||||
FilterState.from_search_query(self.main_window.searchField.text())
|
||||
)
|
||||
)
|
||||
# Search Field
|
||||
search_field: QLineEdit = self.main_window.searchField
|
||||
search_field.returnPressed.connect(
|
||||
# TODO - parse search field for filters
|
||||
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
|
||||
)
|
||||
# Search Type Selector
|
||||
search_type_selector: QComboBox = self.main_window.comboBox_2
|
||||
search_type_selector.currentIndexChanged.connect(
|
||||
lambda: self.set_search_type(SearchMode(search_type_selector.currentIndex()))
|
||||
lambda: self.filter_items(
|
||||
FilterState.from_search_query(self.main_window.searchField.text())
|
||||
)
|
||||
)
|
||||
# Thumbnail Size ComboBox
|
||||
thumb_size_combobox: QComboBox = self.main_window.thumb_size_combobox
|
||||
@@ -963,11 +962,20 @@ class QtDriver(DriverMixin, QObject):
|
||||
self.autofill_action.setDisabled(not self.selected)
|
||||
|
||||
def update_completions_list(self, text: str) -> None:
|
||||
matches = re.search(r"(mediatype|filetype|path|tag):(\"?[A-Za-z0-9\ \t]+\"?)?", text)
|
||||
matches = re.search(
|
||||
r"((?:.* )?)(mediatype|filetype|path|tag|tag_id):(\"?[A-Za-z0-9\ \t]+\"?)?", text
|
||||
)
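# A hedged illustration (not in the diff) of the three capture groups of the updated
# completion regex; group one is whatever precedes the last typed constraint and is
# re-prefixed onto every suggestion built below. For example:
#   m = re.search(r"((?:.* )?)(mediatype|filetype|path|tag|tag_id):(\"?[A-Za-z0-9\ \t]+\"?)?",
#                 "red circle and tag:squ")
#   m.groups() -> ("red circle and ", "tag", "squ")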
|
||||
|
||||
completion_list: list[str] = []
|
||||
if len(text) < 3:
|
||||
completion_list = ["mediatype:", "filetype:", "path:", "tag:"]
|
||||
completion_list = [
|
||||
"mediatype:",
|
||||
"filetype:",
|
||||
"path:",
|
||||
"tag:",
|
||||
"tag_id:",
|
||||
"special:untagged",
|
||||
]
|
||||
self.main_window.searchFieldCompletionList.setStringList(completion_list)
|
||||
|
||||
if not matches:
|
||||
@@ -975,26 +983,28 @@ class QtDriver(DriverMixin, QObject):
|
||||
|
||||
query_type: str
|
||||
query_value: str | None
|
||||
query_type, query_value = matches.groups()
|
||||
prefix, query_type, query_value = matches.groups()
|
||||
|
||||
if not query_value:
|
||||
return
|
||||
|
||||
if query_type == "tag":
|
||||
completion_list = list(map(lambda x: "tag:" + x.name, self.lib.tags))
|
||||
completion_list = list(map(lambda x: prefix + "tag:" + x.name, self.lib.tags))
|
||||
elif query_type == "tag_id":
|
||||
completion_list = list(map(lambda x: prefix + "tag_id:" + str(x.id), self.lib.tags))
|
||||
elif query_type == "path":
|
||||
completion_list = list(map(lambda x: "path:" + x, self.lib.get_paths()))
|
||||
completion_list = list(map(lambda x: prefix + "path:" + x, self.lib.get_paths()))
|
||||
elif query_type == "mediatype":
|
||||
single_word_completions = map(
|
||||
lambda x: "mediatype:" + x.name,
|
||||
lambda x: prefix + "mediatype:" + x.name,
|
||||
filter(lambda y: " " not in y.name, MediaCategories.ALL_CATEGORIES),
|
||||
)
|
||||
single_word_completions_quoted = map(
|
||||
lambda x: 'mediatype:"' + x.name + '"',
|
||||
lambda x: prefix + 'mediatype:"' + x.name + '"',
|
||||
filter(lambda y: " " not in y.name, MediaCategories.ALL_CATEGORIES),
|
||||
)
|
||||
multi_word_completions = map(
|
||||
lambda x: 'mediatype:"' + x.name + '"',
|
||||
lambda x: prefix + 'mediatype:"' + x.name + '"',
|
||||
filter(lambda y: " " in y.name, MediaCategories.ALL_CATEGORIES),
|
||||
)
|
||||
|
||||
@@ -1008,7 +1018,9 @@ class QtDriver(DriverMixin, QObject):
|
||||
extensions_list: set[str] = set()
|
||||
for media_cat in MediaCategories.ALL_CATEGORIES:
|
||||
extensions_list = extensions_list | media_cat.extensions
|
||||
completion_list = list(map(lambda x: "filetype:" + x.replace(".", ""), extensions_list))
|
||||
completion_list = list(
|
||||
map(lambda x: prefix + "filetype:" + x.replace(".", ""), extensions_list)
|
||||
)
|
||||
|
||||
update_completion_list: bool = (
|
||||
completion_list != self.main_window.searchFieldCompletionList.stringList()
|
||||
@@ -1125,8 +1137,12 @@ class QtDriver(DriverMixin, QObject):
|
||||
if filter:
|
||||
self.filter = dataclasses.replace(self.filter, **dataclasses.asdict(filter))
|
||||
|
||||
self.main_window.statusbar.showMessage(f'Searching Library: "{self.filter.summary}"')
|
||||
# inform user about running search
|
||||
self.main_window.statusbar.showMessage("Searching Library...")
|
||||
self.main_window.statusbar.repaint()
|
||||
|
||||
# search the library
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
results = self.lib.search_library(self.filter)
|
||||
@@ -1134,17 +1150,11 @@ class QtDriver(DriverMixin, QObject):
|
||||
logger.info("items to render", count=len(results))
|
||||
|
||||
end_time = time.time()
|
||||
if self.filter.summary:
|
||||
# fmt: off
|
||||
self.main_window.statusbar.showMessage(
|
||||
f"{results.total_count} Results Found for \"{self.filter.summary}\""
|
||||
f" ({format_timespan(end_time - start_time)})"
|
||||
)
|
||||
# fmt: on
|
||||
else:
|
||||
self.main_window.statusbar.showMessage(
|
||||
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
|
||||
)
|
||||
|
||||
# inform user about completed search
|
||||
self.main_window.statusbar.showMessage(
|
||||
f"{results.total_count} Results Found ({format_timespan(end_time - start_time)})"
|
||||
)
|
||||
|
||||
# update page content
|
||||
self.frame_content = results.items
|
||||
@@ -1156,14 +1166,6 @@ class QtDriver(DriverMixin, QObject):
|
||||
self.pages_count, self.filter.page_index, emit=False
|
||||
)
|
||||
|
||||
def set_search_type(self, mode: SearchMode = SearchMode.AND):
|
||||
self.filter_items(
|
||||
FilterState(
|
||||
search_mode=mode,
|
||||
path=self.main_window.searchField.text(),
|
||||
)
|
||||
)
|
||||
|
||||
def remove_recent_library(self, item_key: str):
|
||||
self.settings.beginGroup(SettingItems.LIBS_LIST)
|
||||
self.settings.remove(item_key)
|
||||
|
||||
@@ -25,7 +25,6 @@ from src.core.constants import (
|
||||
TAG_FAVORITE,
|
||||
)
|
||||
from src.core.library import Entry, ItemType, Library
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
from src.core.library.alchemy.fields import _FieldID
|
||||
from src.core.media_types import MediaCategories, MediaType
|
||||
from src.qt.flowlayout import FlowWidget
|
||||
@@ -453,9 +452,7 @@ class ItemThumb(FlowWidget):
|
||||
entry, toggle_value, tag_id, _FieldID.TAGS_META.name, create_field=True
|
||||
)
|
||||
# update the entry
|
||||
self.driver.frame_content[idx] = self.lib.search_library(
|
||||
FilterState(id=entry.id)
|
||||
).items[0]
|
||||
self.driver.frame_content[idx] = self.lib.get_entry_full(entry.id)
|
||||
|
||||
self.driver.update_badges(update_items)
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from src.core.constants import (
|
||||
TS_FOLDER_NAME,
|
||||
)
|
||||
from src.core.enums import SettingItems, Theme
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
from src.core.library.alchemy.fields import (
|
||||
BaseField,
|
||||
DatetimeField,
|
||||
@@ -295,14 +294,13 @@ class PreviewPanel(QWidget):
|
||||
def update_selected_entry(self, driver: "QtDriver"):
|
||||
for grid_idx in driver.selected:
|
||||
entry = driver.frame_content[grid_idx]
|
||||
results = self.lib.search_library(FilterState(id=entry.id))
|
||||
result = self.lib.get_entry_full(entry.id)
|
||||
logger.info(
|
||||
"found item",
|
||||
entries=len(results.items),
|
||||
grid_idx=grid_idx,
|
||||
lookup_id=entry.id,
|
||||
)
|
||||
self.driver.frame_content[grid_idx] = results[0]
|
||||
self.driver.frame_content[grid_idx] = result
|
||||
|
||||
def remove_field_prompt(self, name: str) -> str:
|
||||
return f'Are you sure you want to remove field "{name}"?'
|
||||
@@ -564,14 +562,13 @@ class PreviewPanel(QWidget):
|
||||
# TODO - Entry reload is maybe not necessary
|
||||
for grid_idx in self.driver.selected:
|
||||
entry = self.driver.frame_content[grid_idx]
|
||||
results = self.lib.search_library(FilterState(id=entry.id))
|
||||
result = self.lib.get_entry_full(entry.id)
|
||||
logger.info(
|
||||
"found item",
|
||||
entries=len(results.items),
|
||||
grid_idx=grid_idx,
|
||||
lookup_id=entry.id,
|
||||
)
|
||||
self.driver.frame_content[grid_idx] = results[0]
|
||||
self.driver.frame_content[grid_idx] = result
|
||||
|
||||
if len(self.driver.selected) == 1:
|
||||
# 1 Selected Entry
|
||||
|
||||
@@ -100,7 +100,7 @@ class TagBoxWidget(FieldWidget):
|
||||
tag_widget.on_click.connect(
|
||||
lambda tag_id=tag.id: (
|
||||
self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"),
|
||||
self.driver.filter_items(FilterState(tag_id=tag_id)),
|
||||
self.driver.filter_items(FilterState.from_tag_id(tag_id)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -117,13 +117,20 @@ def library(request):
|
||||
yield lib
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_library() -> Library:
|
||||
lib = Library()
|
||||
lib.open_library(pathlib.Path(CWD / "fixtures" / "search_library"))
|
||||
return lib
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entry_min(library):
|
||||
yield next(library.get_entries())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entry_full(library):
|
||||
def entry_full(library: Library):
|
||||
yield next(library.get_entries(with_joins=True))
|
||||
|
||||
|
||||
|
||||
new vendored test fixture files under tagstudio/tests/fixtures/search_library/:
.TagStudio/ts_library.json (1 line)
.TagStudio/ts_library.sqlite (binary)
comp colors shapes/: r_circle_b_square.png, r_circle_g_square.png, r_circle_o_square.png, r_circle_r_square.png, r_circle_y_square.png (binary images)
inherit colors shapes/: blue.jpg, blue_circle.jpg, blue_ellipse.png, blue_square.jpg, circle.png, ellipse.png, green.png, green_circle.png, green_ellipse.png, green_square.png, orange.png, orange_circle.png, orange_ellipse.png, orange_square.png, red.jpg, red_circle.jpg, red_ellipse.png, red_square.jpg, shape.png, square.png, yellow.png, yellow_circle.png, yellow_ellipse.png, yellow_square.png (binary images)
@@ -26,5 +26,5 @@ def test_refresh_missing_files(library: Library):
|
||||
assert list(registry.fix_missing_files()) == [1, 2]
|
||||
|
||||
# `bar.md` should be relinked to new correct path
|
||||
results = library.search_library(FilterState(path="bar.md"))
|
||||
results = library.search_library(FilterState.from_path("bar.md"))
|
||||
assert results[0].path == pathlib.Path("bar.md")
|
||||
|
||||
@@ -79,7 +79,7 @@ def test_library_state_update(qt_driver):
|
||||
assert len(qt_driver.frame_content) == 2
|
||||
|
||||
# filter by tag
|
||||
state = FilterState(tag="foo", page_size=10)
|
||||
state = FilterState.from_tag_name("foo").with_page_size(10)
|
||||
qt_driver.filter_items(state)
|
||||
assert qt_driver.filter.page_size == 10
|
||||
assert len(qt_driver.frame_content) == 1
|
||||
@@ -94,7 +94,7 @@ def test_library_state_update(qt_driver):
|
||||
assert list(entry.tags)[0].name == "foo"
|
||||
|
||||
# When state property is changed, previous one is overwritten
|
||||
state = FilterState(path="*bar.md")
|
||||
state = FilterState.from_path("*bar.md")
|
||||
qt_driver.filter_items(state)
|
||||
assert len(qt_driver.frame_content) == 1
|
||||
entry = qt_driver.frame_content[0]
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import pytest
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
|
||||
|
||||
def test_filter_state_query():
|
||||
# Given
|
||||
query = "tag:foo"
|
||||
state = FilterState(query=query)
|
||||
|
||||
# When
|
||||
assert state.tag == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["attribute", "comparator"],
|
||||
[
|
||||
("tag", str),
|
||||
("tag_id", int),
|
||||
("path", str),
|
||||
("name", str),
|
||||
("id", int),
|
||||
],
|
||||
)
|
||||
def test_filter_state_attrs_compare(attribute, comparator):
|
||||
# When
|
||||
state = FilterState(**{attribute: "2"})
|
||||
|
||||
# Then
|
||||
# compare the attribute value
|
||||
assert getattr(state, attribute) == comparator("2")
|
||||
|
||||
# Then
|
||||
for prop in ("tag", "tag_id", "path", "name", "id"):
|
||||
if prop == attribute:
|
||||
continue
|
||||
assert not getattr(state, prop)
|
||||
@@ -115,7 +115,7 @@ def test_library_search(library, generate_tag, entry_full):
|
||||
tag = list(entry_full.tags)[0]
|
||||
|
||||
results = library.search_library(
|
||||
FilterState(tag=tag.name),
|
||||
FilterState.from_tag_name(tag.name),
|
||||
)
|
||||
|
||||
assert results.total_count == 1
|
||||
@@ -141,11 +141,11 @@ def test_tag_search(library):
|
||||
assert not library.search_tags(tag.name * 2)
|
||||
|
||||
|
||||
def test_get_entry(library, entry_min):
|
||||
def test_get_entry(library: Library, entry_min):
|
||||
assert entry_min.id
|
||||
results = library.search_library(FilterState(id=entry_min.id))
|
||||
assert len(results) == results.total_count == 1
|
||||
assert results[0].tags
|
||||
result = library.get_entry_full(entry_min.id)
|
||||
assert result
|
||||
assert result.tags
|
||||
|
||||
|
||||
def test_entries_count(library):
|
||||
@@ -153,11 +153,7 @@ def test_entries_count(library):
|
||||
new_ids = library.add_entries(entries)
|
||||
assert len(new_ids) == 10
|
||||
|
||||
results = library.search_library(
|
||||
FilterState(
|
||||
page_size=5,
|
||||
)
|
||||
)
|
||||
results = library.search_library(FilterState.show_all().with_page_size(5))
|
||||
|
||||
assert results.total_count == 12
|
||||
assert len(results) == 5
|
||||
@@ -184,7 +180,7 @@ def test_add_field_to_entry(library):
|
||||
assert len(entry.tag_box_fields) == 3
|
||||
|
||||
|
||||
def test_add_field_tag(library, entry_full, generate_tag):
|
||||
def test_add_field_tag(library: Library, entry_full, generate_tag):
|
||||
# Given
|
||||
tag_name = "xxx"
|
||||
tag = generate_tag(tag_name)
|
||||
@@ -194,8 +190,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
|
||||
library.add_field_tag(entry_full, tag, tag_field.type_key)
|
||||
|
||||
# Then
|
||||
results = library.search_library(FilterState(id=entry_full.id))
|
||||
tag_field = results[0].tag_box_fields[0]
|
||||
result = library.get_entry_full(entry_full.id)
|
||||
tag_field = result.tag_box_fields[0]
|
||||
assert [x.name for x in tag_field.tags if x.name == tag_name]
|
||||
|
||||
|
||||
@@ -228,7 +224,7 @@ def test_search_filter_extensions(library, is_exclude):
|
||||
|
||||
# When
|
||||
results = library.search_library(
|
||||
FilterState(),
|
||||
FilterState.show_all(),
|
||||
)
|
||||
|
||||
# Then
|
||||
@@ -249,7 +245,7 @@ def test_search_library_case_insensitive(library):
|
||||
|
||||
# When
|
||||
results = library.search_library(
|
||||
FilterState(tag=tag.name.upper()),
|
||||
FilterState.from_tag_name(tag.name.upper()),
|
||||
)
|
||||
|
||||
# Then
|
||||
@@ -323,7 +319,8 @@ def test_update_entry_with_multiple_identical_fields(library, entry_full):
|
||||
assert entry.text_fields[1].value == "new value"
|
||||
|
||||
|
||||
def test_mirror_entry_fields(library, entry_full):
|
||||
def test_mirror_entry_fields(library: Library, entry_full):
|
||||
# new entry
|
||||
target_entry = Entry(
|
||||
folder=library.folder,
|
||||
path=Path("xxx"),
|
||||
@@ -336,16 +333,19 @@ def test_mirror_entry_fields(library, entry_full):
|
||||
],
|
||||
)
|
||||
|
||||
# insert new entry and get id
|
||||
entry_id = library.add_entries([target_entry])[0]
|
||||
|
||||
results = library.search_library(FilterState(id=entry_id))
|
||||
new_entry = results[0]
|
||||
# get new entry from library
|
||||
new_entry = library.get_entry_full(entry_id)
|
||||
|
||||
# mirror fields onto new entry
|
||||
library.mirror_entry_fields(new_entry, entry_full)
|
||||
|
||||
results = library.search_library(FilterState(id=entry_id))
|
||||
entry = results[0]
|
||||
# get new entry from library again
|
||||
entry = library.get_entry_full(entry_id)
|
||||
|
||||
# make sure fields are there after getting it from the library again
|
||||
assert len(entry.fields) == 4
|
||||
assert {x.type_key for x in entry.fields} == {
|
||||
_FieldID.TITLE.name,
|
||||
@@ -367,22 +367,6 @@ def test_remove_tag_from_field(library, entry_full):
|
||||
assert removed_tag not in [tag.name for tag in field.tags]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query_name", "has_result"],
|
||||
[
|
||||
("foo", 1), # filename substring
|
||||
("bar", 1), # filename substring
|
||||
("one", 0), # path, should not match
|
||||
],
|
||||
)
|
||||
def test_search_file_name(library, query_name, has_result):
|
||||
results = library.search_library(
|
||||
FilterState(name=query_name),
|
||||
)
|
||||
|
||||
assert results.total_count == has_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query_name", "has_result"],
|
||||
[
|
||||
@@ -392,12 +376,10 @@ def test_search_file_name(library, query_name, has_result):
|
||||
(222, 0),
|
||||
],
|
||||
)
|
||||
def test_search_entry_id(library, query_name, has_result):
|
||||
results = library.search_library(
|
||||
FilterState(id=query_name),
|
||||
)
|
||||
def test_search_entry_id(library: Library, query_name: int, has_result):
|
||||
result = library.get_entry(query_name)
|
||||
|
||||
assert results.total_count == has_result
|
||||
assert (result is not None) == has_result
|
||||
|
||||
|
||||
def test_update_field_order(library, entry_full):
|
||||
@@ -446,36 +428,36 @@ def test_library_prefs_multiple_identical_vals():
|
||||
|
||||
|
||||
def test_path_search_glob_after(library: Library):
|
||||
results = library.search_library(FilterState(path="foo*"))
|
||||
results = library.search_library(FilterState.from_path("foo*"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
def test_path_search_glob_in_front(library: Library):
|
||||
results = library.search_library(FilterState(path="*bar.md"))
|
||||
results = library.search_library(FilterState.from_path("*bar.md"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
def test_path_search_glob_both_sides(library: Library):
|
||||
results = library.search_library(FilterState(path="*one/two*"))
|
||||
results = library.search_library(FilterState.from_path("*one/two*"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("md", 1), ("txt", 1), ("png", 0)])
|
||||
def test_filetype_search(library, filetype, num_of_filetype):
|
||||
results = library.search_library(FilterState(filetype=filetype))
|
||||
results = library.search_library(FilterState.from_filetype(filetype))
|
||||
assert len(results.items) == num_of_filetype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("png", 2), ("apng", 1), ("ng", 0)])
|
||||
def test_filetype_return_one_filetype(file_mediatypes_library, filetype, num_of_filetype):
|
||||
results = file_mediatypes_library.search_library(FilterState(filetype=filetype))
|
||||
results = file_mediatypes_library.search_library(FilterState.from_filetype(filetype))
|
||||
assert len(results.items) == num_of_filetype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["mediatype", "num_of_mediatype"], [("plaintext", 2), ("image", 0)])
|
||||
def test_mediatype_search(library, mediatype, num_of_mediatype):
|
||||
results = library.search_library(FilterState(mediatype=mediatype))
|
||||
results = library.search_library(FilterState.from_mediatype(mediatype))
|
||||
assert len(results.items) == num_of_mediatype
|
||||
|
||||
new file: tagstudio/tests/test_search.py (124 lines)
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
from src.core.library.alchemy.library import Library
|
||||
from src.core.query_lang.util import ParsingError
|
||||
|
||||
|
||||
def verify_count(lib: Library, query: str, count: int):
|
||||
results = lib.search_library(FilterState.from_search_query(query))
|
||||
assert results.total_count == count
|
||||
assert len(results.items) == count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query", "count"],
|
||||
[
|
||||
("", 29),
|
||||
("path:*", 29),
|
||||
("path:*inherit*", 24),
|
||||
("path:*comp*", 5),
|
||||
("special:untagged", 1),
|
||||
("filetype:png", 23),
|
||||
("filetype:jpg", 6),
|
||||
("filetype:'jpg'", 6),
|
||||
("tag_id:1011", 5),
|
||||
("tag_id:1038", 11),
|
||||
("doesnt exist", 0),
|
||||
("archived", 0),
|
||||
("favorite", 0),
|
||||
("tag:favorite", 0),
|
||||
("circle", 11),
|
||||
("tag:square", 11),
|
||||
("green", 5),
|
||||
("orange", 5),
|
||||
("tag:orange", 5),
|
||||
],
|
||||
)
|
||||
def test_single_constraint(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query", "count"],
|
||||
[
|
||||
("circle aND square", 5),
|
||||
("circle square", 5),
|
||||
("green AND square", 2),
|
||||
("green square", 2),
|
||||
("orange AnD square", 2),
|
||||
("orange square", 2),
|
||||
("orange and filetype:png", 5),
|
||||
("square and filetype:jpg", 2),
|
||||
("orange filetype:png", 5),
|
||||
("green path:*inherit*", 4),
|
||||
],
|
||||
)
|
||||
def test_and(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query", "count"],
|
||||
[
|
||||
("square or circle", 17),
|
||||
("orange or green", 10),
|
||||
("orange Or circle", 14),
|
||||
("orange oR square", 14),
|
||||
("square OR green", 14),
|
||||
("circle or green", 14),
|
||||
("green or circle", 14),
|
||||
("filetype:jpg or tag:orange", 11),
|
||||
("red or filetype:png", 25),
|
||||
("filetype:jpg or path:*comp*", 11),
|
||||
],
|
||||
)
|
||||
def test_or(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query", "count"],
|
||||
[
|
||||
("not unexistant", 29),
|
||||
("not path:*", 0),
|
||||
("not not path:*", 29),
|
||||
("not special:untagged", 28),
|
||||
("not filetype:png", 6),
|
||||
("not filetype:jpg", 23),
|
||||
("not tag_id:1011", 24),
|
||||
("not tag_id:1038", 18),
|
||||
("not green", 24),
|
||||
("tag:favorite", 0),
|
||||
("not circle", 18),
|
||||
("not tag:square", 18),
|
||||
("circle and not square", 6),
|
||||
("not circle and square", 6),
|
||||
("special:untagged or not filetype:jpg", 24),
|
||||
("not square or green", 20),
|
||||
],
|
||||
)
|
||||
def test_not(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["query", "count"],
|
||||
[
|
||||
("(tag_id:1041)", 11),
|
||||
("(((tag_id:1041)))", 11),
|
||||
("not (not tag_id:1041)", 11),
|
||||
("((circle) and (not square))", 6),
|
||||
("(not ((square) OR (green)))", 15),
|
||||
("filetype:png and (tag:square or green)", 12),
|
||||
],
|
||||
)
|
||||
def test_parentheses(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_query", ["asd AND", "asd AND AND", "tag:(", "(asd", "asd[]", "asd]", ":", "tag: :"]
|
||||
)
|
||||
def test_syntax(search_library: Library, invalid_query: str):
|
||||
with pytest.raises(ParsingError) as e_info: # noqa: F841
|
||||
search_library.search_library(FilterState.from_search_query(invalid_query))
|
||||