mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-01 23:59:10 +00:00
rudimentary search field integration
This commit is contained in:
@@ -2,6 +2,9 @@ import enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.query_lang import AST as Query # noqa: N811
|
||||
from src.core.query_lang import Parser
|
||||
|
||||
|
||||
class TagColor(enum.IntEnum):
|
||||
DEFAULT = 1
|
||||
@@ -84,36 +87,27 @@ class FilterState:
|
||||
# a generic query to be parsed
|
||||
query: str | None = None
|
||||
|
||||
ast: Query = None
|
||||
|
||||
def __post_init__(self):
|
||||
# strip values automatically
|
||||
if query := (self.query and self.query.strip()):
|
||||
# parse the value
|
||||
if ":" in query:
|
||||
kind, _, value = query.partition(":")
|
||||
value = value.replace('"', "")
|
||||
else:
|
||||
# default to tag search
|
||||
kind, value = "tag", query
|
||||
|
||||
if kind == "tag_id":
|
||||
self.tag_id = int(value)
|
||||
elif kind == "tag":
|
||||
self.tag = value
|
||||
elif kind == "path":
|
||||
self.path = value
|
||||
elif kind == "name":
|
||||
self.name = value
|
||||
elif kind == "id":
|
||||
self.id = int(self.id) if str(self.id).isnumeric() else self.id
|
||||
elif kind == "filetype":
|
||||
self.filetype = value
|
||||
elif kind == "mediatype":
|
||||
self.mediatype = value
|
||||
query = None
|
||||
|
||||
if self.query:
|
||||
query = self.query
|
||||
elif self.tag:
|
||||
query = self.tag.strip()
|
||||
self.tag = None
|
||||
elif self.tag_id:
|
||||
query = f"tag_id:{self.tag_id}"
|
||||
self.tag_id = None
|
||||
elif self.path:
|
||||
query = f"path:'{str(self.path).strip()}'"
|
||||
|
||||
if query:
|
||||
self.ast = Parser(query).parse()
|
||||
else:
|
||||
self.tag = self.tag and self.tag.strip()
|
||||
self.tag_id = int(self.tag_id) if str(self.tag_id).isnumeric() else self.tag_id
|
||||
self.path = self.path and str(self.path).strip()
|
||||
self.name = self.name and self.name.strip()
|
||||
self.id = int(self.id) if str(self.id).isnumeric() else self.id
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from .fields import (
|
||||
)
|
||||
from .joins import TagField, TagSubtag
|
||||
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType
|
||||
from .visitors import SQLBoolExpressionBuilder
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@@ -417,7 +418,13 @@ class Library:
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
statement = select(Entry)
|
||||
|
||||
if search.tag:
|
||||
if search.ast:
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
.join(TagBoxField.tags)
|
||||
.where(SQLBoolExpressionBuilder().visit(search.ast))
|
||||
)
|
||||
elif search.tag:
|
||||
SubtagAlias = aliased(Tag) # noqa: N806
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
@@ -439,7 +446,6 @@ class Library:
|
||||
.join(TagBoxField.tags)
|
||||
.where(Tag.id == search.tag_id)
|
||||
)
|
||||
|
||||
elif search.id:
|
||||
statement = statement.where(Entry.id == search.id)
|
||||
elif search.name:
|
||||
|
||||
50
tagstudio/src/core/library/alchemy/visitors.py
Normal file
50
tagstudio/src/core/library/alchemy/visitors.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.sql.expression import ColumnExpressionArgument
|
||||
from src.core.media_types import MediaCategories
|
||||
from src.core.query_lang import BaseVisitor
|
||||
from src.core.query_lang.ast import ANDList, Constraint, ConstraintType, ORList, Property
|
||||
|
||||
from .models import Entry, Tag, TagAlias
|
||||
|
||||
|
||||
class SQLBoolExpressionBuilder(BaseVisitor):
|
||||
def visit_or_list(self, node: ORList) -> ColumnExpressionArgument:
|
||||
return or_(*[self.visit(element) for element in node.elements])
|
||||
|
||||
def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
|
||||
return and_(*[self.visit(term) for term in node.terms])
|
||||
|
||||
def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
|
||||
if len(node.properties) != 0:
|
||||
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG
|
||||
|
||||
if node.type == ConstraintType.Tag:
|
||||
return or_(
|
||||
Tag.name.ilike(node.value),
|
||||
Tag.shorthand.ilike(node.value),
|
||||
TagAlias.name.ilike(node.value),
|
||||
aliased(Tag).name.ilike(node.value),
|
||||
)
|
||||
elif node.type == ConstraintType.TagID:
|
||||
return Tag.id == int(node.value)
|
||||
elif node.type == ConstraintType.Path:
|
||||
return Entry.path.ilike(node.value.replace("*", "%")) # TODO TSQLANG this is broken
|
||||
elif node.type == ConstraintType.MediaType:
|
||||
extensions: set[str] = set[str]()
|
||||
for media_cat in MediaCategories.ALL_CATEGORIES:
|
||||
if node.value == media_cat.name:
|
||||
extensions = extensions | media_cat.extensions
|
||||
break
|
||||
return Entry.suffix.in_(
|
||||
map(lambda x: x.replace(".", ""), extensions)
|
||||
) # TODO audio doesn't work on mp3 files (might be my library)
|
||||
elif node.type == ConstraintType.FileType:
|
||||
return Entry.suffix.ilike(
|
||||
node.value
|
||||
) # TODO TSQLANG this is broken for mp3, but works for png (might be my library)
|
||||
|
||||
raise NotImplementedError("This type of constraint is not implemented yet")
|
||||
|
||||
def visit_property(self, node: Property) -> None:
|
||||
return
|
||||
@@ -65,6 +65,7 @@ from src.core.constants import (
|
||||
)
|
||||
from src.core.driver import DriverMixin
|
||||
from src.core.enums import LibraryPrefs, MacroID, SettingItems
|
||||
from src.core.library.alchemy import Library
|
||||
from src.core.library.alchemy.enums import (
|
||||
FieldTypeEnum,
|
||||
FilterState,
|
||||
@@ -130,6 +131,8 @@ class QtDriver(DriverMixin, QObject):
|
||||
|
||||
preview_panel: PreviewPanel
|
||||
|
||||
lib: Library
|
||||
|
||||
def __init__(self, backend, args):
|
||||
super().__init__()
|
||||
# prevent recursive badges update when multiple items selected
|
||||
@@ -1100,8 +1103,12 @@ class QtDriver(DriverMixin, QObject):
|
||||
if filter:
|
||||
self.filter = dataclasses.replace(self.filter, **dataclasses.asdict(filter))
|
||||
|
||||
# inform user about running search
|
||||
self.main_window.statusbar.showMessage(f'Searching Library: "{self.filter.summary}"')
|
||||
self.main_window.statusbar.repaint()
|
||||
|
||||
# search the library
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
results = self.lib.search_library(self.filter)
|
||||
@@ -1109,6 +1116,8 @@ class QtDriver(DriverMixin, QObject):
|
||||
logger.info("items to render", count=len(results))
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# inform user about completed search
|
||||
if self.filter.summary:
|
||||
# fmt: off
|
||||
self.main_window.statusbar.showMessage(
|
||||
|
||||
Reference in New Issue
Block a user