mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-31 23:29:10 +00:00
feat: make search results more ergonomic (#498)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, UTC
|
||||
import shutil
|
||||
from os import makedirs
|
||||
@@ -87,6 +88,33 @@ def get_default_tags() -> tuple[Tag, ...]:
|
||||
return archive_tag, favorite_tag
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SearchResult:
|
||||
"""Wrapper for search results.
|
||||
|
||||
:param total_count: total number of items for given query, might be different than len(items)
|
||||
:param items: items for current page (size matches filter.page_size)
|
||||
"""
|
||||
|
||||
total_count: int
|
||||
items: list[Entry]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Boolean evaluation for the wrapper.
|
||||
|
||||
:return: True if there are items in the result.
|
||||
"""
|
||||
return self.total_count > 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the total number of items in the result."""
|
||||
return len(self.items)
|
||||
|
||||
def __getitem__(self, index: int) -> Entry:
|
||||
"""Allow to access items via index directly on the wrapper."""
|
||||
return self.items[index]
|
||||
|
||||
|
||||
class Library:
|
||||
"""Class for the Library object, and all CRUD operations made upon it."""
|
||||
|
||||
@@ -325,7 +353,7 @@ class Library:
|
||||
def search_library(
|
||||
self,
|
||||
search: FilterState,
|
||||
) -> tuple[int, list[Entry]]:
|
||||
) -> SearchResult:
|
||||
"""Filter library by search query.
|
||||
|
||||
:return: number of entries matching the query and one page of results.
|
||||
@@ -401,11 +429,14 @@ class Library:
|
||||
),
|
||||
)
|
||||
|
||||
entries_ = list(session.scalars(statement).unique())
|
||||
res = SearchResult(
|
||||
total_count=count_all,
|
||||
items=list(session.scalars(statement).unique()),
|
||||
)
|
||||
|
||||
session.expunge_all()
|
||||
|
||||
return count_all, entries_
|
||||
return res
|
||||
|
||||
def search_tags(
|
||||
self,
|
||||
|
||||
@@ -50,15 +50,15 @@ class DupeRegistry:
|
||||
# The file is not in the library directory
|
||||
continue
|
||||
|
||||
_, entries = self.library.search_library(
|
||||
results = self.library.search_library(
|
||||
FilterState(path=path_relative),
|
||||
)
|
||||
|
||||
if not entries:
|
||||
if not results:
|
||||
# file not in library
|
||||
continue
|
||||
|
||||
files.append(entries[0])
|
||||
files.append(results[0])
|
||||
|
||||
if not len(files) > 1:
|
||||
# only one file in the group, nothing to do
|
||||
|
||||
@@ -1009,26 +1009,26 @@ class QtDriver(QObject):
|
||||
self.main_window.statusbar.repaint()
|
||||
start_time = time.time()
|
||||
|
||||
query_count, page_items = self.lib.search_library(self.filter)
|
||||
results = self.lib.search_library(self.filter)
|
||||
|
||||
logger.info("items to render", count=len(page_items))
|
||||
logger.info("items to render", count=len(results))
|
||||
|
||||
end_time = time.time()
|
||||
if self.filter.summary:
|
||||
self.main_window.statusbar.showMessage(
|
||||
f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
|
||||
f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
|
||||
)
|
||||
else:
|
||||
self.main_window.statusbar.showMessage(
|
||||
f"{query_count} Results ({format_timespan(end_time - start_time)})"
|
||||
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
|
||||
)
|
||||
|
||||
# update page content
|
||||
self.frame_content = list(page_items)
|
||||
self.frame_content = results.items
|
||||
self.update_thumbs()
|
||||
|
||||
# update pagination
|
||||
self.pages_count = math.ceil(query_count / self.filter.page_size)
|
||||
self.pages_count = math.ceil(results.total_count / self.filter.page_size)
|
||||
self.main_window.pagination.update_buttons(
|
||||
self.pages_count, self.filter.page_index, emit=False
|
||||
)
|
||||
|
||||
@@ -487,7 +487,7 @@ class ItemThumb(FlowWidget):
|
||||
# update the entry
|
||||
self.driver.frame_content[idx] = self.lib.search_library(
|
||||
FilterState(id=entry.id)
|
||||
)[1][0]
|
||||
).items[0]
|
||||
|
||||
self.driver.update_badges(update_items)
|
||||
|
||||
|
||||
@@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"):
|
||||
for grid_idx in driver.selected:
|
||||
entry = driver.frame_content[grid_idx]
|
||||
# reload entry
|
||||
_, entries = driver.lib.search_library(FilterState(id=entry.id))
|
||||
results = driver.lib.search_library(FilterState(id=entry.id))
|
||||
logger.info(
|
||||
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
|
||||
"found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id
|
||||
)
|
||||
assert entries, f"Entry not found: {entry.id}"
|
||||
driver.frame_content[grid_idx] = entries[0]
|
||||
assert results, f"Entry not found: {entry.id}"
|
||||
driver.frame_content[grid_idx] = next(results)
|
||||
|
||||
|
||||
class PreviewPanel(QWidget):
|
||||
@@ -499,11 +499,14 @@ class PreviewPanel(QWidget):
|
||||
# TODO - Entry reload is maybe not necessary
|
||||
for grid_idx in self.driver.selected:
|
||||
entry = self.driver.frame_content[grid_idx]
|
||||
_, entries = self.lib.search_library(FilterState(id=entry.id))
|
||||
results = self.lib.search_library(FilterState(id=entry.id))
|
||||
logger.info(
|
||||
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
|
||||
"found item",
|
||||
entries=len(results.items),
|
||||
grid_idx=grid_idx,
|
||||
lookup_id=entry.id,
|
||||
)
|
||||
self.driver.frame_content[grid_idx] = entries[0]
|
||||
self.driver.frame_content[grid_idx] = results[0]
|
||||
|
||||
if len(self.driver.selected) == 1:
|
||||
# 1 Selected Entry
|
||||
|
||||
@@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library):
|
||||
assert list(registry.fix_missing_files()) == [1, 2]
|
||||
|
||||
# `bar.md` should be relinked to new correct path
|
||||
_, entries = library.search_library(FilterState(path="bar.md"))
|
||||
assert entries[0].path == pathlib.Path("bar.md")
|
||||
results = library.search_library(FilterState(path="bar.md"))
|
||||
assert results[0].path == pathlib.Path("bar.md")
|
||||
|
||||
@@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full):
|
||||
assert library.entries_count == 2
|
||||
tag = list(entry_full.tags)[0]
|
||||
|
||||
query_count, items = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(
|
||||
tag=tag.name,
|
||||
),
|
||||
)
|
||||
|
||||
assert query_count == 1
|
||||
assert len(items) == 1
|
||||
assert results.total_count == 1
|
||||
assert len(results) == 1
|
||||
|
||||
entry = items[0]
|
||||
entry = results[0]
|
||||
assert {x.name for x in entry.tags} == {
|
||||
"foo",
|
||||
}
|
||||
@@ -94,9 +94,9 @@ def test_tag_search(library):
|
||||
|
||||
def test_get_entry(library, entry_min):
|
||||
assert entry_min.id
|
||||
cnt, entries = library.search_library(FilterState(id=entry_min.id))
|
||||
assert len(entries) == cnt == 1
|
||||
assert entries[0].tags
|
||||
results = library.search_library(FilterState(id=entry_min.id))
|
||||
assert len(results) == results.total_count == 1
|
||||
assert results[0].tags
|
||||
|
||||
|
||||
def test_entries_count(library):
|
||||
@@ -105,14 +105,14 @@ def test_entries_count(library):
|
||||
for x in range(10)
|
||||
]
|
||||
library.add_entries(entries)
|
||||
matches, page = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(
|
||||
page_size=5,
|
||||
)
|
||||
)
|
||||
|
||||
assert matches == 12
|
||||
assert len(page) == 5
|
||||
assert results.total_count == 12
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
def test_add_field_to_entry(library):
|
||||
@@ -146,8 +146,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
|
||||
library.add_field_tag(entry_full, tag, tag_field.type_key)
|
||||
|
||||
# Then
|
||||
_, entries = library.search_library(FilterState(id=entry_full.id))
|
||||
tag_field = entries[0].tag_box_fields[0]
|
||||
results = library.search_library(FilterState(id=entry_full.id))
|
||||
tag_field = results[0].tag_box_fields[0]
|
||||
assert [x.name for x in tag_field.tags if x.name == tag_name]
|
||||
|
||||
|
||||
@@ -179,15 +179,15 @@ def test_search_filter_extensions(library, is_exclude):
|
||||
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])
|
||||
|
||||
# When
|
||||
query_count, items = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(),
|
||||
)
|
||||
|
||||
# Then
|
||||
assert query_count == 1
|
||||
assert len(items) == 1
|
||||
assert results.total_count == 1
|
||||
assert len(results) == 1
|
||||
|
||||
entry = items[0]
|
||||
entry = results[0]
|
||||
assert (entry.path.suffix == ".txt") == is_exclude
|
||||
|
||||
|
||||
@@ -200,15 +200,15 @@ def test_search_library_case_insensitive(library):
|
||||
tag = list(entry.tags)[0]
|
||||
|
||||
# When
|
||||
query_count, items = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(tag=tag.name.upper()),
|
||||
)
|
||||
|
||||
# Then
|
||||
assert query_count == 1
|
||||
assert len(items) == 1
|
||||
assert results.total_count == 1
|
||||
assert len(results) == 1
|
||||
|
||||
assert items[0].id == entry.id
|
||||
assert results[0].id == entry.id
|
||||
|
||||
|
||||
def test_preferences(library):
|
||||
@@ -231,11 +231,11 @@ def test_save_windows_path(library, generate_tag):
|
||||
# library.add_tag(tag)
|
||||
library.add_field_tag(entry, tag, create_field=True)
|
||||
|
||||
_, found = library.search_library(FilterState(tag=tag_name))
|
||||
assert found
|
||||
results = library.search_library(FilterState(tag=tag_name))
|
||||
assert results
|
||||
|
||||
# path should be saved in posix format
|
||||
assert str(found[0].path) == "foo/bar.txt"
|
||||
assert str(results[0].path) == "foo/bar.txt"
|
||||
|
||||
|
||||
def test_remove_entry_field(library, entry_full):
|
||||
@@ -312,13 +312,13 @@ def test_mirror_entry_fields(library, entry_full):
|
||||
|
||||
entry_id = library.add_entries([target_entry])[0]
|
||||
|
||||
_, entries = library.search_library(FilterState(id=entry_id))
|
||||
new_entry = entries[0]
|
||||
results = library.search_library(FilterState(id=entry_id))
|
||||
new_entry = results[0]
|
||||
|
||||
library.mirror_entry_fields(new_entry, entry_full)
|
||||
|
||||
_, entries = library.search_library(FilterState(id=entry_id))
|
||||
entry = entries[0]
|
||||
results = library.search_library(FilterState(id=entry_id))
|
||||
entry = results[0]
|
||||
|
||||
assert len(entry.fields) == 4
|
||||
assert {x.type_key for x in entry.fields} == {
|
||||
@@ -350,13 +350,11 @@ def test_remove_tag_from_field(library, entry_full):
|
||||
],
|
||||
)
|
||||
def test_search_file_name(library, query_name, has_result):
|
||||
res_count, items = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(name=query_name),
|
||||
)
|
||||
|
||||
assert (
|
||||
res_count == has_result
|
||||
), f"mismatch with query: {query_name}, result: {res_count}"
|
||||
assert results.total_count == has_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -369,13 +367,11 @@ def test_search_file_name(library, query_name, has_result):
|
||||
],
|
||||
)
|
||||
def test_search_entry_id(library, query_name, has_result):
|
||||
res_count, items = library.search_library(
|
||||
results = library.search_library(
|
||||
FilterState(id=query_name),
|
||||
)
|
||||
|
||||
assert (
|
||||
res_count == has_result
|
||||
), f"mismatch with query: {query_name}, result: {res_count}"
|
||||
assert results.total_count == has_result
|
||||
|
||||
|
||||
def test_update_field_order(library, entry_full):
|
||||
|
||||
Reference in New Issue
Block a user