refactoring

This commit is contained in:
Jann Stute
2024-11-28 13:34:56 +01:00
parent 0d4afd47c8
commit eda7f5200f
10 changed files with 73 additions and 94 deletions

View File

@@ -1,9 +1,9 @@
import enum
from dataclasses import dataclass
from dataclasses import dataclass, replace
from pathlib import Path
from src.core.query_lang import AST as Query # noqa: N811
from src.core.query_lang import Parser
from src.core.query_lang import Constraint, ConstraintType, Parser
class TagColor(enum.IntEnum):
@@ -77,27 +77,17 @@ class FilterState:
# these should remain
page_index: int | None = None
page_size: int | None = None
search_mode: SearchMode = SearchMode.AND # TODO - actually implement this
search_mode: SearchMode = SearchMode.AND # TODO this can be removed?
# these should be erased on update
# tag name
tag: str | None = None
# tag ID
tag_id: int | None = None
# entry id
id: int | None = None
# whole path
path: Path | str | None = None
# file name
name: str | None = None
# file type
filetype: str | None = None
mediatype: str | None = None
# a generic query to be parsed
query: str | None = None
# Abstract Syntax Tree Of the current Search Query
ast: Query = None
def __post_init__(self):
@@ -105,34 +95,24 @@ class FilterState:
query = None
if self.query is not None:
query = self.query
elif self.tag is not None:
query = self.tag.strip()
self.tag = None
elif self.tag_id is not None:
query = f"tag_id:{self.tag_id}"
self.tag_id = None
elif self.path is not None:
if self.path is not None:
query = f"path:'{str(self.path).strip()}'"
self.query = query
if query:
if query is not None:
self.ast = Parser(query).parse()
else:
self.name = self.name and self.name.strip()
self.id = int(self.id) if str(self.id).isnumeric() else self.id
if self.page_index is None:
if self.page_index is None: # TODO QTLANG can this just be a default value?
self.page_index = 0
if self.page_size is None:
if self.page_size is None: # TODO QTLANG can this just be a default value?
self.page_size = 500
@property
def summary(self):
"""Show query summary."""
return self.query or self.tag or self.name or self.tag_id or self.path or self.id
return self.name or self.path or self.id
@property
def limit(self):
@@ -142,6 +122,37 @@ class FilterState:
def offset(self):
return self.page_size * self.page_index
@classmethod
def show_all(cls) -> "FilterState":
return FilterState()
@classmethod
def from_search_query(cls, search_query: str) -> "FilterState":
return cls(ast=Parser(search_query).parse())
@classmethod
def from_tag_id(cls, tag_id: int | str) -> "FilterState":
return cls(ast=Constraint(ConstraintType.TagID, str(tag_id), []))
@classmethod
def from_path(cls, path: Path | str) -> "FilterState":
return cls(ast=Constraint(ConstraintType.Path, str(path).strip(), []))
@classmethod
def from_mediatype(cls, mediatype: str) -> "FilterState":
return cls(ast=Constraint(ConstraintType.MediaType, mediatype, []))
@classmethod
def from_filetype(cls, filetype: str) -> "FilterState":
return cls(ast=Constraint(ConstraintType.FileType, filetype, []))
@classmethod
def from_tag_name(cls, tag_name: str) -> "FilterState":
return cls(ast=Constraint(ConstraintType.Tag, tag_name, []))
def with_page_size(self, page_size: int) -> "FilterState":
return replace(self, page_size=page_size)
class FieldTypeEnum(enum.Enum):
TEXT_LINE = "Text Line"

View File

@@ -425,28 +425,6 @@ class Library:
.outerjoin(TagAlias)
.where(SQLBoolExpressionBuilder().visit(search.ast))
)
elif search.tag:
SubtagAlias = aliased(Tag) # noqa: N806
statement = (
statement.join(Entry.tag_box_fields)
.join(TagBoxField.tags)
.outerjoin(Tag.aliases)
.outerjoin(SubtagAlias, Tag.subtags)
.where(
or_(
Tag.name.ilike(search.tag),
Tag.shorthand.ilike(search.tag),
TagAlias.name.ilike(search.tag),
SubtagAlias.name.ilike(search.tag),
)
)
)
elif search.tag_id:
statement = (
statement.join(Entry.tag_box_fields)
.join(TagBoxField.tags)
.where(Tag.id == search.tag_id)
)
elif search.id:
statement = statement.where(Entry.id == search.id)
elif search.name:
@@ -460,18 +438,6 @@ class Library:
elif search.path:
search_str = str(search.path).replace("*", "%")
statement = statement.where(Entry.path.ilike(search_str))
elif search.filetype:
statement = statement.where(Entry.suffix.ilike(f"{search.filetype}"))
elif search.mediatype:
extensions: set[str] = set[str]()
for media_cat in MediaCategories.ALL_CATEGORIES:
if search.mediatype == media_cat.name:
extensions = extensions | media_cat.extensions
break
# just need to map it to search db - suffixes do not have '.'
statement = statement.where(
Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions))
)
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
@@ -490,7 +456,9 @@ class Library:
.options(selectinload(Tag.aliases), selectinload(Tag.subtags)),
)
query_count = select(func.count()).select_from(statement.alias("entries"))
query_count = select(func.count()).select_from(
statement.alias("entries")
) # TODO this should count the number of *unique* results
count_all: int = session.execute(query_count).scalar()
statement = statement.limit(search.limit).offset(search.offset)

View File

@@ -50,7 +50,7 @@ class DupeRegistry:
continue
results = self.library.search_library(
FilterState(path=path_relative),
FilterState.from_path(path_relative),
)
if not results:

View File

@@ -141,7 +141,7 @@ class QtDriver(DriverMixin, QObject):
self.rm: ResourceManager = ResourceManager()
self.args = args
self.frame_content = []
self.filter = FilterState()
self.filter = FilterState.show_all()
self.pages_count = 0
self.scrollbar_pos = 0
@@ -469,7 +469,7 @@ class QtDriver(DriverMixin, QObject):
]
self.item_thumbs: list[ItemThumb] = []
self.thumb_renderers: list[ThumbRenderer] = []
self.filter = FilterState()
self.filter = FilterState.show_all()
self.init_library_window()
path_result = self.evaluate_path(self.args.open)
@@ -510,13 +510,17 @@ class QtDriver(DriverMixin, QObject):
# Search Button
search_button: QPushButton = self.main_window.searchButton
search_button.clicked.connect(
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
lambda: self.filter_items(
FilterState.from_search_query(self.main_window.searchField.text())
)
)
# Search Field
search_field: QLineEdit = self.main_window.searchField
search_field.returnPressed.connect(
# TODO - parse search field for filters
lambda: self.filter_items(FilterState(query=self.main_window.searchField.text()))
lambda: self.filter_items(
FilterState.from_search_query(self.main_window.searchField.text())
)
)
# Search Type Selector
search_type_selector: QComboBox = self.main_window.comboBox_2
@@ -1142,7 +1146,7 @@ class QtDriver(DriverMixin, QObject):
def set_search_type(self, mode: SearchMode = SearchMode.AND):
self.filter_items(
FilterState(
FilterState( # TODO TSQLANG deal with this
search_mode=mode,
path=self.main_window.searchField.text(),
)

View File

@@ -455,7 +455,7 @@ class ItemThumb(FlowWidget):
# update the entry
self.driver.frame_content[idx] = self.lib.search_library(
FilterState(id=entry.id) # TODO TSQLANG don't search, get entry directly by id
).items[0]
).items[0] # self.lib.get_entry(entry.id)
self.driver.update_badges(update_items)

View File

@@ -99,7 +99,7 @@ class TagBoxWidget(FieldWidget):
tag_widget.on_click.connect(
lambda tag_id=tag.id: (
self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"),
self.driver.filter_items(FilterState(query=f"tag_id:{tag_id}")),
self.driver.filter_items(FilterState.from_tag_id(tag_id)),
)
)

View File

@@ -26,5 +26,5 @@ def test_refresh_missing_files(library: Library):
assert list(registry.fix_missing_files()) == [1, 2]
# `bar.md` should be relinked to new correct path
results = library.search_library(FilterState(path="bar.md"))
results = library.search_library(FilterState.from_path("bar.md"))
assert results[0].path == pathlib.Path("bar.md")

View File

@@ -79,7 +79,7 @@ def test_library_state_update(qt_driver):
assert len(qt_driver.frame_content) == 2
# filter by tag
state = FilterState(tag="foo", page_size=10)
state = FilterState.from_tag_name("foo").with_page_size(10)
qt_driver.filter_items(state)
assert qt_driver.filter.page_size == 10
assert len(qt_driver.frame_content) == 1
@@ -94,7 +94,7 @@ def test_library_state_update(qt_driver):
assert list(entry.tags)[0].name == "foo"
# When state property is changed, previous one is overwritten
state = FilterState(path="*bar.md")
state = FilterState.from_path("*bar.md")
qt_driver.filter_items(state)
assert len(qt_driver.frame_content) == 1
entry = qt_driver.frame_content[0]

View File

@@ -2,13 +2,13 @@ import pytest
from src.core.library.alchemy.enums import FilterState
def test_filter_state_query():
def test_filter_state_query(): # TODO TSQLANG can this test be removed?
# Given
query = "tag:foo"
state = FilterState(query=query)
state = FilterState.from_search_query(query)
# When
assert state.tag == "foo"
# assert state.tag == "foo"
@pytest.mark.parametrize(
@@ -21,7 +21,7 @@ def test_filter_state_query():
("id", int),
],
)
def test_filter_state_attrs_compare(attribute, comparator):
def test_filter_state_attrs_compare(attribute, comparator): # TODO TSQLANG rework this test
# When
state = FilterState(**{attribute: "2"})

View File

@@ -115,7 +115,7 @@ def test_library_search(library, generate_tag, entry_full):
tag = list(entry_full.tags)[0]
results = library.search_library(
FilterState(tag=tag.name),
FilterState.from_tag_name(tag.name),
)
assert results.total_count == 1
@@ -159,11 +159,7 @@ def test_entries_count(library):
new_ids = library.add_entries(entries)
assert len(new_ids) == 10
results = library.search_library(
FilterState(
page_size=5,
)
)
results = library.search_library(FilterState.show_all().with_page_size(5))
assert results.total_count == 12
assert len(results) == 5
@@ -234,7 +230,7 @@ def test_search_filter_extensions(library, is_exclude):
# When
results = library.search_library(
FilterState(),
FilterState.show_all(),
)
# Then
@@ -255,7 +251,7 @@ def test_search_library_case_insensitive(library):
# When
results = library.search_library(
FilterState(tag=tag.name.upper()),
FilterState.from_tag_name(tag.name.upper()),
)
# Then
@@ -285,7 +281,7 @@ def test_save_windows_path(library, generate_tag):
# library.add_tag(tag)
library.add_field_tag(entry, tag, create_field=True)
results = library.search_library(FilterState(tag=tag_name))
results = library.search_library(FilterState.from_tag_name(tag_name))
assert results
# path should be saved in posix format
@@ -474,36 +470,36 @@ def test_library_prefs_multiple_identical_vals():
def test_path_search_glob_after(library: Library):
results = library.search_library(FilterState(path="foo*"))
results = library.search_library(FilterState.from_path("foo*"))
assert results.total_count == 1
assert len(results.items) == 1
def test_path_search_glob_in_front(library: Library):
results = library.search_library(FilterState(path="*bar.md"))
results = library.search_library(FilterState.from_path("*bar.md"))
assert results.total_count == 1
assert len(results.items) == 1
def test_path_search_glob_both_sides(library: Library):
results = library.search_library(FilterState(path="*one/two*"))
results = library.search_library(FilterState.from_path("*one/two*"))
assert results.total_count == 1
assert len(results.items) == 1
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("md", 1), ("txt", 1), ("png", 0)])
def test_filetype_search(library, filetype, num_of_filetype):
results = library.search_library(FilterState(filetype=filetype))
results = library.search_library(FilterState.from_filetype(filetype))
assert len(results.items) == num_of_filetype
@pytest.mark.parametrize(["filetype", "num_of_filetype"], [("png", 2), ("apng", 1), ("ng", 0)])
def test_filetype_return_one_filetype(file_mediatypes_library, filetype, num_of_filetype):
results = file_mediatypes_library.search_library(FilterState(filetype=filetype))
results = file_mediatypes_library.search_library(FilterState.from_filetype(filetype))
assert len(results.items) == num_of_filetype
@pytest.mark.parametrize(["mediatype", "num_of_mediatype"], [("plaintext", 2), ("image", 0)])
def test_mediatype_search(library, mediatype, num_of_mediatype):
results = library.search_library(FilterState(mediatype=mediatype))
results = library.search_library(FilterState.from_mediatype(mediatype))
assert len(results.items) == num_of_mediatype