mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-02 08:09:13 +00:00
refactoring
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.query_lang import AST as Query # noqa: N811
|
||||
from src.core.query_lang import Parser
|
||||
from src.core.query_lang import Constraint, ConstraintType, Parser
|
||||
|
||||
|
||||
class TagColor(enum.IntEnum):
|
||||
@@ -77,27 +77,17 @@ class FilterState:
|
||||
# these should remain
|
||||
page_index: int | None = None
|
||||
page_size: int | None = None
|
||||
search_mode: SearchMode = SearchMode.AND # TODO - actually implement this
|
||||
search_mode: SearchMode = SearchMode.AND # TODO this can be removed?
|
||||
|
||||
# these should be erased on update
|
||||
# tag name
|
||||
tag: str | None = None
|
||||
# tag ID
|
||||
tag_id: int | None = None
|
||||
|
||||
# entry id
|
||||
id: int | None = None
|
||||
# whole path
|
||||
path: Path | str | None = None
|
||||
# file name
|
||||
name: str | None = None
|
||||
# file type
|
||||
filetype: str | None = None
|
||||
mediatype: str | None = None
|
||||
|
||||
# a generic query to be parsed
|
||||
query: str | None = None
|
||||
|
||||
# Abstract Syntax Tree Of the current Search Query
|
||||
ast: Query = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -105,34 +95,24 @@ class FilterState:
|
||||
|
||||
query = None
|
||||
|
||||
if self.query is not None:
|
||||
query = self.query
|
||||
elif self.tag is not None:
|
||||
query = self.tag.strip()
|
||||
self.tag = None
|
||||
elif self.tag_id is not None:
|
||||
query = f"tag_id:{self.tag_id}"
|
||||
self.tag_id = None
|
||||
elif self.path is not None:
|
||||
if self.path is not None:
|
||||
query = f"path:'{str(self.path).strip()}'"
|
||||
|
||||
self.query = query
|
||||
|
||||
if query:
|
||||
if query is not None:
|
||||
self.ast = Parser(query).parse()
|
||||
else:
|
||||
self.name = self.name and self.name.strip()
|
||||
self.id = int(self.id) if str(self.id).isnumeric() else self.id
|
||||
|
||||
if self.page_index is None:
|
||||
if self.page_index is None: # TODO QTLANG can this just be a default value?
|
||||
self.page_index = 0
|
||||
if self.page_size is None:
|
||||
if self.page_size is None: # TODO QTLANG can this just be a default value?
|
||||
self.page_size = 500
|
||||
|
||||
@property
|
||||
def summary(self):
|
||||
"""Show query summary."""
|
||||
return self.query or self.tag or self.name or self.tag_id or self.path or self.id
|
||||
return self.name or self.path or self.id
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
@@ -142,6 +122,37 @@ class FilterState:
|
||||
def offset(self):
|
||||
return self.page_size * self.page_index
|
||||
|
||||
@classmethod
|
||||
def show_all(cls) -> "FilterState":
|
||||
return FilterState()
|
||||
|
||||
@classmethod
|
||||
def from_search_query(cls, search_query: str) -> "FilterState":
|
||||
return cls(ast=Parser(search_query).parse())
|
||||
|
||||
@classmethod
|
||||
def from_tag_id(cls, tag_id: int | str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.TagID, str(tag_id), []))
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, path: Path | str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.Path, str(path).strip(), []))
|
||||
|
||||
@classmethod
|
||||
def from_mediatype(cls, mediatype: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.MediaType, mediatype, []))
|
||||
|
||||
@classmethod
|
||||
def from_filetype(cls, filetype: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.FileType, filetype, []))
|
||||
|
||||
@classmethod
|
||||
def from_tag_name(cls, tag_name: str) -> "FilterState":
|
||||
return cls(ast=Constraint(ConstraintType.Tag, tag_name, []))
|
||||
|
||||
def with_page_size(self, page_size: int) -> "FilterState":
|
||||
return replace(self, page_size=page_size)
|
||||
|
||||
|
||||
class FieldTypeEnum(enum.Enum):
|
||||
TEXT_LINE = "Text Line"
|
||||
|
||||
@@ -425,28 +425,6 @@ class Library:
|
||||
.outerjoin(TagAlias)
|
||||
.where(SQLBoolExpressionBuilder().visit(search.ast))
|
||||
)
|
||||
elif search.tag:
|
||||
SubtagAlias = aliased(Tag) # noqa: N806
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
.join(TagBoxField.tags)
|
||||
.outerjoin(Tag.aliases)
|
||||
.outerjoin(SubtagAlias, Tag.subtags)
|
||||
.where(
|
||||
or_(
|
||||
Tag.name.ilike(search.tag),
|
||||
Tag.shorthand.ilike(search.tag),
|
||||
TagAlias.name.ilike(search.tag),
|
||||
SubtagAlias.name.ilike(search.tag),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif search.tag_id:
|
||||
statement = (
|
||||
statement.join(Entry.tag_box_fields)
|
||||
.join(TagBoxField.tags)
|
||||
.where(Tag.id == search.tag_id)
|
||||
)
|
||||
elif search.id:
|
||||
statement = statement.where(Entry.id == search.id)
|
||||
elif search.name:
|
||||
@@ -460,18 +438,6 @@ class Library:
|
||||
elif search.path:
|
||||
search_str = str(search.path).replace("*", "%")
|
||||
statement = statement.where(Entry.path.ilike(search_str))
|
||||
elif search.filetype:
|
||||
statement = statement.where(Entry.suffix.ilike(f"{search.filetype}"))
|
||||
elif search.mediatype:
|
||||
extensions: set[str] = set[str]()
|
||||
for media_cat in MediaCategories.ALL_CATEGORIES:
|
||||
if search.mediatype == media_cat.name:
|
||||
extensions = extensions | media_cat.extensions
|
||||
break
|
||||
# just need to map it to search db - suffixes do not have '.'
|
||||
statement = statement.where(
|
||||
Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions))
|
||||
)
|
||||
|
||||
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
|
||||
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
|
||||
@@ -490,7 +456,9 @@ class Library:
|
||||
.options(selectinload(Tag.aliases), selectinload(Tag.subtags)),
|
||||
)
|
||||
|
||||
query_count = select(func.count()).select_from(statement.alias("entries"))
|
||||
query_count = select(func.count()).select_from(
|
||||
statement.alias("entries")
|
||||
) # TODO this should count the number of *unique* results
|
||||
count_all: int = session.execute(query_count).scalar()
|
||||
|
||||
statement = statement.limit(search.limit).offset(search.offset)
|
||||
|
||||
@@ -50,7 +50,7 @@ class DupeRegistry:
|
||||
continue
|
||||
|
||||
results = self.library.search_library(
|
||||
FilterState(path=path_relative),
|
||||
FilterState.from_path(path_relative),
|
||||
)
|
||||
|
||||
if not results:
|
||||
|
||||
@@ -141,7 +141,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
self.rm: ResourceManager = ResourceManager()
|
||||
self.args = args
|
||||
self.frame_content = []
|
||||
self.filter = FilterState()
|
||||
self.filter = FilterState.show_all()
|
||||
self.pages_count = 0
|
||||
|
||||
self.scrollbar_pos = 0
|
||||
@@ -469,7 +469,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
]
|
||||
self.item_thumbs: list[ItemThumb] = []
|
||||
self.thumb_renderers: list[ThumbRenderer] = []
|
||||
self.filter = FilterState()
|
||||
self.filter = FilterState.show_all()
|
||||
self.init_library_window()
|
||||
|
||||
path_result = self.evaluate_path(self.args.open)
|
||||
@@ -510,13 +510,17 @@ class QtDriver(DriverMixin, QObject):
|
||||
# Search Button
|
||||
search_button: QPushButton = self.main_window.searchButton
|
||||
search_button.clicked.connect(
|
||||
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
|
||||
lambda: self.filter_items(
|
||||
FilterState.from_search_query(self.main_window.searchField.text())
|
||||
)
|
||||
)
|
||||
# Search Field
|
||||
search_field: QLineEdit = self.main_window.searchField
|
||||
search_field.returnPressed.connect(
|
||||
# TODO - parse search field for filters
|
||||
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
|
||||
lambda: self.filter_items(
|
||||
FilterState.from_search_query(self.main_window.searchField.text())
|
||||
)
|
||||
)
|
||||
# Search Type Selector
|
||||
search_type_selector: QComboBox = self.main_window.comboBox_2
|
||||
@@ -1142,7 +1146,7 @@ class QtDriver(DriverMixin, QObject):
|
||||
|
||||
def set_search_type(self, mode: SearchMode = SearchMode.AND):
|
||||
self.filter_items(
|
||||
FilterState(
|
||||
FilterState( # TODO TSQLANG deal with this
|
||||
search_mode=mode,
|
||||
path=self.main_window.searchField.text(),
|
||||
)
|
||||
|
||||
@@ -455,7 +455,7 @@ class ItemThumb(FlowWidget):
|
||||
# update the entry
|
||||
self.driver.frame_content[idx] = self.lib.search_library(
|
||||
FilterState(id=entry.id) # TODO TSQLANG don't search, get entry directly by id
|
||||
).items[0]
|
||||
).items[0] # self.lib.get_entry(entry.id)
|
||||
|
||||
self.driver.update_badges(update_items)
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ class TagBoxWidget(FieldWidget):
|
||||
tag_widget.on_click.connect(
|
||||
lambda tag_id=tag.id: (
|
||||
self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"),
|
||||
self.driver.filter_items(FilterState(query=f"tag_id:{tag_id}")),
|
||||
self.driver.filter_items(FilterState.from_tag_id(tag_id)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -26,5 +26,5 @@ def test_refresh_missing_files(library: Library):
|
||||
assert list(registry.fix_missing_files()) == [1, 2]
|
||||
|
||||
# `bar.md` should be relinked to new correct path
|
||||
results = library.search_library(FilterState(path="bar.md"))
|
||||
results = library.search_library(FilterState.from_path("bar.md"))
|
||||
assert results[0].path == pathlib.Path("bar.md")
|
||||
|
||||
@@ -79,7 +79,7 @@ def test_library_state_update(qt_driver):
|
||||
assert len(qt_driver.frame_content) == 2
|
||||
|
||||
# filter by tag
|
||||
state = FilterState(tag="foo", page_size=10)
|
||||
state = FilterState.from_tag_name("foo").with_page_size(10)
|
||||
qt_driver.filter_items(state)
|
||||
assert qt_driver.filter.page_size == 10
|
||||
assert len(qt_driver.frame_content) == 1
|
||||
@@ -94,7 +94,7 @@ def test_library_state_update(qt_driver):
|
||||
assert list(entry.tags)[0].name == "foo"
|
||||
|
||||
# When state property is changed, previous one is overwritten
|
||||
state = FilterState(path="*bar.md")
|
||||
state = FilterState.from_path("*bar.md")
|
||||
qt_driver.filter_items(state)
|
||||
assert len(qt_driver.frame_content) == 1
|
||||
entry = qt_driver.frame_content[0]
|
||||
|
||||
@@ -2,13 +2,13 @@ import pytest
|
||||
from src.core.library.alchemy.enums import FilterState
|
||||
|
||||
|
||||
def test_filter_state_query():
|
||||
def test_filter_state_query(): # TODO TSQLANG can this test be removed?
|
||||
# Given
|
||||
query = "tag:foo"
|
||||
state = FilterState(query=query)
|
||||
state = FilterState.from_search_query(query)
|
||||
|
||||
# When
|
||||
assert state.tag == "foo"
|
||||
# assert state.tag == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -21,7 +21,7 @@ def test_filter_state_query():
|
||||
("id", int),
|
||||
],
|
||||
)
|
||||
def test_filter_state_attrs_compare(attribute, comparator):
|
||||
def test_filter_state_attrs_compare(attribute, comparator): # TODO TSQLANG rework this test
|
||||
# When
|
||||
state = FilterState(**{attribute: "2"})
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ def test_library_search(library, generate_tag, entry_full):
|
||||
tag = list(entry_full.tags)[0]
|
||||
|
||||
results = library.search_library(
|
||||
FilterState(tag=tag.name),
|
||||
FilterState.from_tag_name(tag.name),
|
||||
)
|
||||
|
||||
assert results.total_count == 1
|
||||
@@ -159,11 +159,7 @@ def test_entries_count(library):
|
||||
new_ids = library.add_entries(entries)
|
||||
assert len(new_ids) == 10
|
||||
|
||||
results = library.search_library(
|
||||
FilterState(
|
||||
page_size=5,
|
||||
)
|
||||
)
|
||||
results = library.search_library(FilterState.show_all().with_page_size(5))
|
||||
|
||||
assert results.total_count == 12
|
||||
assert len(results) == 5
|
||||
@@ -234,7 +230,7 @@ def test_search_filter_extensions(library, is_exclude):
|
||||
|
||||
# When
|
||||
results = library.search_library(
|
||||
FilterState(),
|
||||
FilterState.show_all(),
|
||||
)
|
||||
|
||||
# Then
|
||||
@@ -255,7 +251,7 @@ def test_search_library_case_insensitive(library):
|
||||
|
||||
# When
|
||||
results = library.search_library(
|
||||
FilterState(tag=tag.name.upper()),
|
||||
FilterState.from_tag_name(tag.name.upper()),
|
||||
)
|
||||
|
||||
# Then
|
||||
@@ -285,7 +281,7 @@ def test_save_windows_path(library, generate_tag):
|
||||
# library.add_tag(tag)
|
||||
library.add_field_tag(entry, tag, create_field=True)
|
||||
|
||||
results = library.search_library(FilterState(tag=tag_name))
|
||||
results = library.search_library(FilterState.from_tag_name(tag_name))
|
||||
assert results
|
||||
|
||||
# path should be saved in posix format
|
||||
@@ -474,36 +470,36 @@ def test_library_prefs_multiple_identical_vals():
|
||||
|
||||
|
||||
def test_path_search_glob_after(library: Library):
|
||||
results = library.search_library(FilterState(path="foo*"))
|
||||
results = library.search_library(FilterState.from_path("foo*"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
def test_path_search_glob_in_front(library: Library):
|
||||
results = library.search_library(FilterState(path="*bar.md"))
|
||||
results = library.search_library(FilterState.from_path("*bar.md"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
def test_path_search_glob_both_sides(library: Library):
|
||||
results = library.search_library(FilterState(path="*one/two*"))
|
||||
results = library.search_library(FilterState.from_path("*one/two*"))
|
||||
assert results.total_count == 1
|
||||
assert len(results.items) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("md", 1), ("txt", 1), ("png", 0)])
|
||||
def test_filetype_search(library, filetype, num_of_filetype):
|
||||
results = library.search_library(FilterState(filetype=filetype))
|
||||
results = library.search_library(FilterState.from_filetype(filetype))
|
||||
assert len(results.items) == num_of_filetype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("png", 2), ("apng", 1), ("ng", 0)])
|
||||
def test_filetype_return_one_filetype(file_mediatypes_library, filetype, num_of_filetype):
|
||||
results = file_mediatypes_library.search_library(FilterState(filetype=filetype))
|
||||
results = file_mediatypes_library.search_library(FilterState.from_filetype(filetype))
|
||||
assert len(results.items) == num_of_filetype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["mediatype", "num_of_mediatype"], [("plaintext", 2), ("image", 0)])
|
||||
def test_mediatype_search(library, mediatype, num_of_mediatype):
|
||||
results = library.search_library(FilterState(mediatype=mediatype))
|
||||
results = library.search_library(FilterState.from_mediatype(mediatype))
|
||||
assert len(results.items) == num_of_mediatype
|
||||
|
||||
Reference in New Issue
Block a user