From 25f421bca4a7afcd2796525aeef0ba0767b0d89b Mon Sep 17 00:00:00 2001 From: Travis Abendshien <46939827+CyanVoxel@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:29:01 -0800 Subject: [PATCH] refactor: move macro processing to macro_parser.py --- src/tagstudio/core/macro_parser.py | 112 ++++++++++++++++++++++++----- src/tagstudio/qt/ts_qt.py | 82 ++------------------- 2 files changed, 98 insertions(+), 96 deletions(-) diff --git a/src/tagstudio/core/macro_parser.py b/src/tagstudio/core/macro_parser.py index ddd69450..15752412 100644 --- a/src/tagstudio/core/macro_parser.py +++ b/src/tagstudio/core/macro_parser.py @@ -5,7 +5,7 @@ import json from enum import StrEnum from pathlib import Path -from typing import Any, override +from typing import TYPE_CHECKING, Any, override import structlog import toml @@ -13,6 +13,11 @@ from wcmatch import glob from tagstudio.core.library.alchemy.fields import FieldID +if TYPE_CHECKING: + from tagstudio.core.library.alchemy.library import Library + from tagstudio.core.library.alchemy.models import Tag + + logger = structlog.get_logger(__name__) SCHEMA_VERSION = "schema_version" @@ -62,12 +67,12 @@ class OnMissing(StrEnum): SKIP = "skip" -class DataResult: +class Instruction: def __init__(self) -> None: pass -class FieldResult(DataResult): +class AddFieldInstruction(Instruction): def __init__(self, content, name: FieldID, field_type: str) -> None: super().__init__() self.content = content @@ -79,7 +84,7 @@ class FieldResult(DataResult): return str(self.content) -class TagResult(DataResult): +class AddTagInstruction(Instruction): def __init__( self, tag_strings: list[str], @@ -103,14 +108,14 @@ class TagResult(DataResult): def parse_macro_file( macro_path: Path, filepath: Path, -) -> list[DataResult]: +) -> list[Instruction]: """Parse a macro file and return a list of actions for TagStudio to perform. Args: macro_path (Path): The full path of the macro file. filepath (Path): The filepath associated with Entry being operated upon. """ - results: list[DataResult] = [] + results: list[Instruction] = [] logger.info("[MacroParser] Parsing Macro", macro_path=macro_path, filepath=filepath) if not macro_path.exists(): @@ -191,20 +196,20 @@ def parse_macro_file( logger.info(f'[MacroParser] [{table_key}] "{ACTION}": {action}') if action == Actions.IMPORT_DATA: - results.extend(import_data(table, table_key, filepath)) + results.extend(_import_data(table, table_key, filepath)) elif action == Actions.ADD_DATA: - results.extend(add_data(table)) + results.extend(_add_data(table)) logger.info(results) return results -def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[DataResult]: +def _import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[Instruction]: """Process an import_data instruction and return a list of DataResults. Importing data refers to importing data from a source external to TagStudio or any macro. """ - results: list[DataResult] = [] + results: list[Instruction] = [] source_format: str = str(table.get(SOURCE_FORMAT, "")) if not source_format: @@ -287,7 +292,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D logger.warning(f"[MacroParser] [{table_key}] Empty template, skipping") continue for k in json_dump: - template = fill_template(template, json_dump, k) + template = _fill_template(template, json_dump, k) logger.info(f"[MacroParser] [{table_key}] Template filled!", template=template) content_value = template @@ -345,7 +350,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D logger.info("[MacroParser] Found tags", tag_strings=tag_strings) results.append( - TagResult( + AddTagInstruction( tag_strings=tag_strings, use_context=use_context, strict=strict, @@ -356,7 +361,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D elif ts_type in (TEXT_LINE, TEXT_BOX, DATETIME): results.append( - FieldResult(content=content_value, name=name, field_type=ts_type) + AddFieldInstruction(content=content_value, name=name, field_type=ts_type) ) else: logger.error('[MacroParser] [{table_key}] Unknown "{TS_TYPE}"', type=ts_type) @@ -364,12 +369,12 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D return results -def add_data(table: dict[str, Any]) -> list[DataResult]: +def _add_data(table: dict[str, Any]) -> list[Instruction]: """Process an add_data instruction and return a list of DataResults. Adding data refers to adding data defined inside a TagStudio macro, not from an external source. """ - results: list[DataResult] = [] + results: list[Instruction] = [] logger.error(table) for table_value in table.values(): objects: list[dict[str, Any] | str] = [] @@ -385,7 +390,7 @@ def add_data(table: dict[str, Any]) -> list[DataResult]: tag_strings: list[str] = obj.get(VALUE, []) logger.error(tag_strings) results.append( - TagResult( + AddTagInstruction( tag_strings=tag_strings, use_context=False, ) @@ -400,12 +405,14 @@ def add_data(table: dict[str, Any]) -> list[DataResult]: continue content_value: str = str(obj.get(VALUE, "")) - results.append(FieldResult(content=content_value, name=name, field_type=ts_type)) + results.append( + AddFieldInstruction(content=content_value, name=name, field_type=ts_type) + ) return results -def fill_template( +def _fill_template( template: str, table: dict[str, Any], table_key: str, template_key: str = "" ) -> str: """Replaces placeholder keys in a string with the value from that table. @@ -427,7 +434,74 @@ def fill_template( for v in value: normalized_key: str = f"{key}[{str(v)}]" template.replace(f"{{{normalized_key}}}", f"{{{str(v)}}}") - template = fill_template(template, value, str(v), normalized_key) + template = _fill_template(template, value, str(v), normalized_key) value = str(value) return template.replace(f"{{{key}}}", f"{value}") + + +def exec_instructions(library: "Library", entry_id: int, results: list[Instruction]) -> None: + for result in results: + if isinstance(result, AddTagInstruction): + _exec_add_tag(library, entry_id, result) + elif isinstance(result, AddFieldInstruction): + _exec_add_field(library, entry_id, result) + + +def _exec_add_tag(library: "Library", entry_id: int, result: AddTagInstruction): + tag_ids: set[int] = set() + for string in result.tag_strings: + if not string.strip(): + continue + string = string.replace("_", " ") + base_and_parent = string.split("(") + parent = "" + base = base_and_parent[0].strip(" ") + parent_results: list[int] = [] + if len(base_and_parent) > 1: + parent = base_and_parent[1].split(")")[0] + r: list[set[Tag]] = library.search_tags(name=parent, limit=-1) + if len(r) > 0: + parent_results = [t.id for t in r[0]] + # NOTE: The following code overlaps with update_tags() in tag_search.py + # Sort and prioritize the results + tag_results: list[set[Tag]] = library.search_tags(name=base, limit=-1) + results_0 = list(tag_results[0]) + results_0.sort(key=lambda tag: tag.name.lower()) + results_1 = list(tag_results[1]) + results_1.sort(key=lambda tag: tag.name.lower()) + raw_results = list(results_0 + results_1) + priority_results: set[Tag] = set() + + for tag in raw_results: + if tag.name.lower().startswith(base.strip().lower()): + priority_results.add(tag) + all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [ + r for r in raw_results if r not in priority_results + ] + + if parent and parent_results: + filtered_parents: list[Tag] = [] + for tag in all_results: + for p_id in tag.parent_ids: + if p_id in parent_results: + filtered_parents.append(tag) + break + all_results = [t for t in all_results if t in filtered_parents] + + final_tag: Tag | None = None + if len(all_results) > 0: + final_tag = all_results[0] + if final_tag: + tag_ids.add(final_tag.id) + + if not tag_ids: + return + + library.add_tags_to_entries(entry_id, tag_ids) + + +def _exec_add_field(library: "Library", entry_id: int, result: AddFieldInstruction): + library.add_field_to_entry( + entry_id, field_id=result.name, value=result.content, skip_on_exists=True + ) diff --git a/src/tagstudio/qt/ts_qt.py b/src/tagstudio/qt/ts_qt.py index 80ece066..20e46af7 100644 --- a/src/tagstudio/qt/ts_qt.py +++ b/src/tagstudio/qt/ts_qt.py @@ -62,13 +62,12 @@ from tagstudio.core.library.alchemy.enums import ( SortingModeEnum, ) from tagstudio.core.library.alchemy.library import Library, LibraryStatus -from tagstudio.core.library.alchemy.models import Entry, Tag +from tagstudio.core.library.alchemy.models import Entry from tagstudio.core.library.ignore import Ignore from tagstudio.core.library.refresh import RefreshTracker from tagstudio.core.macro_parser import ( - DataResult, - FieldResult, - TagResult, + Instruction, + exec_instructions, parse_macro_file, ) from tagstudio.core.media_types import MediaCategories @@ -1139,79 +1138,8 @@ class QtDriver(DriverMixin, QObject): entry_id=entry.id, ) - results: list[DataResult] = parse_macro_file(macro_path, full_path) - for result in results: - if isinstance(result, TagResult): - tag_ids: set[int] = set() - for string in result.tag_strings: - if not string.strip(): - continue - string = string.replace("_", " ") - base_and_parent = string.split("(") - parent = "" - base = base_and_parent[0].strip(" ") - parent_results: list[int] = [] - if len(base_and_parent) > 1: - parent = base_and_parent[1].split(")")[0] - r: list[set[Tag]] = self.lib.search_tags(name=parent, limit=-1) - if len(r) > 0: - parent_results = [t.id for t in r[0]] - logger.warning("split", string=string, base=base, parent=parent) - # NOTE: The following code overlaps with update_tags() in tag_search.py - # Sort and prioritize the results - tag_results: list[set[Tag]] = self.lib.search_tags(name=base, limit=-1) - results_0 = list(tag_results[0]) - results_0.sort(key=lambda tag: tag.name.lower()) - results_1 = list(tag_results[1]) - results_1.sort(key=lambda tag: tag.name.lower()) - raw_results = list(results_0 + results_1) - priority_results: set[Tag] = set() - - for tag in raw_results: - if tag.name.lower().startswith(base.strip().lower()): - priority_results.add(tag) - all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [ - r for r in raw_results if r not in priority_results - ] - - logger.warning("parents", parent=parent, parent_results=parent_results) - if parent and parent_results: - filtered_parents: list[Tag] = [] - for tag in all_results: - logger.warning( - "parent_ids", - tag_id=tag.id, - p_ids=tag.parent_ids, - parent_results=parent_results, - ) - for p_id in tag.parent_ids: - if p_id in parent_results: - filtered_parents.append(tag) - break - all_results = [t for t in all_results if t in filtered_parents] - logger.warning( - "removed", to_remove=filtered_parents, all_results=all_results - ) - - final_tag: Tag | None = None - if len(all_results) > 0: - final_tag = all_results[0] - # tag = self.lib.get_tag_by_name(string) - if final_tag: - tag_ids.add(final_tag.id) - - if not tag_ids: - continue - - self.lib.add_tags_to_entries(entry_id, tag_ids) - - elif isinstance(result, FieldResult): - self.lib.add_field_to_entry( - entry_id, - field_id=result.name, - value=result.content, - skip_on_exists=True, - ) + results: list[Instruction] = parse_macro_file(macro_path, full_path) + exec_instructions(self.lib, entry_id, results) @property def sorting_direction(self) -> bool: