feat: make search results more ergonomic (#498)

This commit is contained in:
yed
2024-09-13 07:34:27 +07:00
committed by GitHub
parent a8fdae8ebc
commit c15963868e
7 changed files with 87 additions and 57 deletions

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime, UTC
import shutil
from os import makedirs
@@ -87,6 +88,33 @@ def get_default_tags() -> tuple[Tag, ...]:
return archive_tag, favorite_tag
@dataclass(frozen=True)
class SearchResult:
"""Wrapper for search results.
:param total_count: total number of items for given query, might be different than len(items)
:param items: items for current page (size matches filter.page_size)
"""
total_count: int
items: list[Entry]
def __bool__(self) -> bool:
"""Boolean evaluation for the wrapper.
:return: True if there are items in the result.
"""
return self.total_count > 0
def __len__(self) -> int:
"""Return the total number of items in the result."""
return len(self.items)
def __getitem__(self, index: int) -> Entry:
"""Allow to access items via index directly on the wrapper."""
return self.items[index]
class Library:
"""Class for the Library object, and all CRUD operations made upon it."""
@@ -325,7 +353,7 @@ class Library:
def search_library(
self,
search: FilterState,
) -> tuple[int, list[Entry]]:
) -> SearchResult:
"""Filter library by search query.
:return: number of entries matching the query and one page of results.
@@ -401,11 +429,14 @@ class Library:
),
)
entries_ = list(session.scalars(statement).unique())
res = SearchResult(
total_count=count_all,
items=list(session.scalars(statement).unique()),
)
session.expunge_all()
return count_all, entries_
return res
def search_tags(
self,

View File

@@ -50,15 +50,15 @@ class DupeRegistry:
# The file is not in the library directory
continue
_, entries = self.library.search_library(
results = self.library.search_library(
FilterState(path=path_relative),
)
if not entries:
if not results:
# file not in library
continue
files.append(entries[0])
files.append(results[0])
if not len(files) > 1:
# only one file in the group, nothing to do

View File

@@ -1009,26 +1009,26 @@ class QtDriver(QObject):
self.main_window.statusbar.repaint()
start_time = time.time()
query_count, page_items = self.lib.search_library(self.filter)
results = self.lib.search_library(self.filter)
logger.info("items to render", count=len(page_items))
logger.info("items to render", count=len(results))
end_time = time.time()
if self.filter.summary:
self.main_window.statusbar.showMessage(
f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
)
else:
self.main_window.statusbar.showMessage(
f"{query_count} Results ({format_timespan(end_time - start_time)})"
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
)
# update page content
self.frame_content = list(page_items)
self.frame_content = results.items
self.update_thumbs()
# update pagination
self.pages_count = math.ceil(query_count / self.filter.page_size)
self.pages_count = math.ceil(results.total_count / self.filter.page_size)
self.main_window.pagination.update_buttons(
self.pages_count, self.filter.page_index, emit=False
)

View File

@@ -487,7 +487,7 @@ class ItemThumb(FlowWidget):
# update the entry
self.driver.frame_content[idx] = self.lib.search_library(
FilterState(id=entry.id)
)[1][0]
).items[0]
self.driver.update_badges(update_items)

View File

@@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"):
for grid_idx in driver.selected:
entry = driver.frame_content[grid_idx]
# reload entry
_, entries = driver.lib.search_library(FilterState(id=entry.id))
results = driver.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id
)
assert entries, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = entries[0]
assert results, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = next(results)
class PreviewPanel(QWidget):
@@ -499,11 +499,14 @@ class PreviewPanel(QWidget):
# TODO - Entry reload is maybe not necessary
for grid_idx in self.driver.selected:
entry = self.driver.frame_content[grid_idx]
_, entries = self.lib.search_library(FilterState(id=entry.id))
results = self.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item",
entries=len(results.items),
grid_idx=grid_idx,
lookup_id=entry.id,
)
self.driver.frame_content[grid_idx] = entries[0]
self.driver.frame_content[grid_idx] = results[0]
if len(self.driver.selected) == 1:
# 1 Selected Entry

View File

@@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library):
assert list(registry.fix_missing_files()) == [1, 2]
# `bar.md` should be relinked to new correct path
_, entries = library.search_library(FilterState(path="bar.md"))
assert entries[0].path == pathlib.Path("bar.md")
results = library.search_library(FilterState(path="bar.md"))
assert results[0].path == pathlib.Path("bar.md")

View File

@@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full):
assert library.entries_count == 2
tag = list(entry_full.tags)[0]
query_count, items = library.search_library(
results = library.search_library(
FilterState(
tag=tag.name,
),
)
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1
entry = items[0]
entry = results[0]
assert {x.name for x in entry.tags} == {
"foo",
}
@@ -94,9 +94,9 @@ def test_tag_search(library):
def test_get_entry(library, entry_min):
assert entry_min.id
cnt, entries = library.search_library(FilterState(id=entry_min.id))
assert len(entries) == cnt == 1
assert entries[0].tags
results = library.search_library(FilterState(id=entry_min.id))
assert len(results) == results.total_count == 1
assert results[0].tags
def test_entries_count(library):
@@ -105,14 +105,14 @@ def test_entries_count(library):
for x in range(10)
]
library.add_entries(entries)
matches, page = library.search_library(
results = library.search_library(
FilterState(
page_size=5,
)
)
assert matches == 12
assert len(page) == 5
assert results.total_count == 12
assert len(results) == 5
def test_add_field_to_entry(library):
@@ -146,8 +146,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
library.add_field_tag(entry_full, tag, tag_field.type_key)
# Then
_, entries = library.search_library(FilterState(id=entry_full.id))
tag_field = entries[0].tag_box_fields[0]
results = library.search_library(FilterState(id=entry_full.id))
tag_field = results[0].tag_box_fields[0]
assert [x.name for x in tag_field.tags if x.name == tag_name]
@@ -179,15 +179,15 @@ def test_search_filter_extensions(library, is_exclude):
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])
# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(),
)
# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1
entry = items[0]
entry = results[0]
assert (entry.path.suffix == ".txt") == is_exclude
@@ -200,15 +200,15 @@ def test_search_library_case_insensitive(library):
tag = list(entry.tags)[0]
# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(tag=tag.name.upper()),
)
# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1
assert items[0].id == entry.id
assert results[0].id == entry.id
def test_preferences(library):
@@ -231,11 +231,11 @@ def test_save_windows_path(library, generate_tag):
# library.add_tag(tag)
library.add_field_tag(entry, tag, create_field=True)
_, found = library.search_library(FilterState(tag=tag_name))
assert found
results = library.search_library(FilterState(tag=tag_name))
assert results
# path should be saved in posix format
assert str(found[0].path) == "foo/bar.txt"
assert str(results[0].path) == "foo/bar.txt"
def test_remove_entry_field(library, entry_full):
@@ -312,13 +312,13 @@ def test_mirror_entry_fields(library, entry_full):
entry_id = library.add_entries([target_entry])[0]
_, entries = library.search_library(FilterState(id=entry_id))
new_entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
new_entry = results[0]
library.mirror_entry_fields(new_entry, entry_full)
_, entries = library.search_library(FilterState(id=entry_id))
entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
entry = results[0]
assert len(entry.fields) == 4
assert {x.type_key for x in entry.fields} == {
@@ -350,13 +350,11 @@ def test_remove_tag_from_field(library, entry_full):
],
)
def test_search_file_name(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(name=query_name),
)
assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result
@pytest.mark.parametrize(
@@ -369,13 +367,11 @@ def test_search_file_name(library, query_name, has_result):
],
)
def test_search_entry_id(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(id=query_name),
)
assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result
def test_update_field_order(library, entry_full):