refactor: move macro processing to macro_parser.py

Travis Abendshien
2025-03-05 22:29:01 -08:00
parent 3221aafdfc
commit 25f421bca4
2 changed files with 98 additions and 96 deletions

View File

@@ -5,7 +5,7 @@
import json
from enum import StrEnum
from pathlib import Path
from typing import Any, override
from typing import TYPE_CHECKING, Any, override
import structlog
import toml
@@ -13,6 +13,11 @@ from wcmatch import glob
from tagstudio.core.library.alchemy.fields import FieldID
if TYPE_CHECKING:
from tagstudio.core.library.alchemy.library import Library
from tagstudio.core.library.alchemy.models import Tag
logger = structlog.get_logger(__name__)
SCHEMA_VERSION = "schema_version"
@@ -62,12 +67,12 @@ class OnMissing(StrEnum):
SKIP = "skip"
class DataResult:
class Instruction:
def __init__(self) -> None:
pass
class FieldResult(DataResult):
class AddFieldInstruction(Instruction):
def __init__(self, content, name: FieldID, field_type: str) -> None:
super().__init__()
self.content = content
@@ -79,7 +84,7 @@ class FieldResult(DataResult):
return str(self.content)
class TagResult(DataResult):
class AddTagInstruction(Instruction):
def __init__(
self,
tag_strings: list[str],
@@ -103,14 +108,14 @@ class TagResult(DataResult):
def parse_macro_file(
macro_path: Path,
filepath: Path,
) -> list[DataResult]:
) -> list[Instruction]:
"""Parse a macro file and return a list of actions for TagStudio to perform.
Args:
macro_path (Path): The full path of the macro file.
filepath (Path): The filepath associated with the Entry being operated upon.
"""
results: list[DataResult] = []
results: list[Instruction] = []
logger.info("[MacroParser] Parsing Macro", macro_path=macro_path, filepath=filepath)
if not macro_path.exists():
@@ -191,20 +196,20 @@ def parse_macro_file(
logger.info(f'[MacroParser] [{table_key}] "{ACTION}": {action}')
if action == Actions.IMPORT_DATA:
results.extend(import_data(table, table_key, filepath))
results.extend(_import_data(table, table_key, filepath))
elif action == Actions.ADD_DATA:
results.extend(add_data(table))
results.extend(_add_data(table))
logger.info(results)
return results
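With parse_macro_file in place, callers only need the macro path and the entry's file path. A minimal usage sketch, assuming TOML macro files (per the toml import above); both paths here are hypothetical:

from pathlib import Path

from tagstudio.core.macro_parser import parse_macro_file

# Parsing only reads the macro file and returns Instruction objects;
# nothing touches the library at this stage.
instructions = parse_macro_file(
    macro_path=Path("macros/example.toml"),  # hypothetical macro file
    filepath=Path("photos/IMG_0001.jpg"),    # file backing the Entry
)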
def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[DataResult]:
def _import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[Instruction]:
"""Process an import_data instruction and return a list of DataResults.
Importing data refers to importing data from a source external to TagStudio or any macro.
"""
results: list[DataResult] = []
results: list[Instruction] = []
source_format: str = str(table.get(SOURCE_FORMAT, ""))
if not source_format:
@@ -287,7 +292,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
logger.warning(f"[MacroParser] [{table_key}] Empty template, skipping")
continue
for k in json_dump:
template = fill_template(template, json_dump, k)
template = _fill_template(template, json_dump, k)
logger.info(f"[MacroParser] [{table_key}] Template filled!", template=template)
content_value = template
@@ -345,7 +350,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
logger.info("[MacroParser] Found tags", tag_strings=tag_strings)
results.append(
TagResult(
AddTagInstruction(
tag_strings=tag_strings,
use_context=use_context,
strict=strict,
@@ -356,7 +361,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
elif ts_type in (TEXT_LINE, TEXT_BOX, DATETIME):
results.append(
FieldResult(content=content_value, name=name, field_type=ts_type)
AddFieldInstruction(content=content_value, name=name, field_type=ts_type)
)
else:
logger.error(f'[MacroParser] [{table_key}] Unknown "{TS_TYPE}"', type=ts_type)
@@ -364,12 +369,12 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
return results
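Only the action and source_format lookups are visible in this diff, so as a rough sketch an import_data table might look like the following; the exact key spellings and the "json" value are assumptions based on the constants and JSON handling above:

import toml

# Hypothetical macro table for _import_data.
table = toml.loads('''
[sidecar]
action = "import_data"
source_format = "json"
''')["sidecar"]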
def add_data(table: dict[str, Any]) -> list[DataResult]:
def _add_data(table: dict[str, Any]) -> list[Instruction]:
"""Process an add_data instruction and return a list of DataResults.
Adding data refers to adding data defined inside a TagStudio macro, not from an external source.
"""
results: list[DataResult] = []
results: list[Instruction] = []
logger.debug(table)
for table_value in table.values():
objects: list[dict[str, Any] | str] = []
@@ -385,7 +390,7 @@ def add_data(table: dict[str, Any]) -> list[DataResult]:
tag_strings: list[str] = obj.get(VALUE, [])
logger.debug(tag_strings)
results.append(
TagResult(
AddTagInstruction(
tag_strings=tag_strings,
use_context=False,
)
@@ -400,12 +405,14 @@ def add_data(table: dict[str, Any]) -> list[DataResult]:
continue
content_value: str = str(obj.get(VALUE, ""))
results.append(FieldResult(content=content_value, name=name, field_type=ts_type))
results.append(
AddFieldInstruction(content=content_value, name=name, field_type=ts_type)
)
return results
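For comparison, the instructions _add_data is expected to emit can be built by hand with the constructors defined earlier in this file; the field name and all values here are hypothetical:

# A tag instruction sourced from the macro itself (note use_context=False),
# and a field instruction; "text_box" is assumed to match TEXT_BOX's value.
tag_inst = AddTagInstruction(
    tag_strings=["landscape", "golden_hour_(lighting)"],
    use_context=False,
)
field_inst = AddFieldInstruction(
    content="Shot on film",
    name=FieldID.DESCRIPTION,  # hypothetical FieldID member
    field_type="text_box",
)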
def fill_template(
def _fill_template(
template: str, table: dict[str, Any], table_key: str, template_key: str = ""
) -> str:
"""Replaces placeholder keys in a string with the value from that table.
@@ -427,7 +434,74 @@ def fill_template(
for v in value:
normalized_key: str = f"{key}[{str(v)}]"
template = template.replace(f"{{{normalized_key}}}", f"{{{str(v)}}}")
template = fill_template(template, value, str(v), normalized_key)
template = _fill_template(template, value, str(v), normalized_key)
value = str(value)
return template.replace(f"{{{key}}}", f"{value}")
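The flat case of the placeholder convention implemented above, illustrated with invented keys and values; nested tables recurse through the key[sub] branch instead:

template = "Shot on {camera} in {city}"
data = {"camera": "X100V", "city": "Kyoto"}
for k in data:
    template = _fill_template(template, data, k)
# template is now "Shot on X100V in Kyoto"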
def exec_instructions(library: "Library", entry_id: int, results: list[Instruction]) -> None:
for result in results:
if isinstance(result, AddTagInstruction):
_exec_add_tag(library, entry_id, result)
elif isinstance(result, AddFieldInstruction):
_exec_add_field(library, entry_id, result)
def _exec_add_tag(library: "Library", entry_id: int, result: AddTagInstruction) -> None:
tag_ids: set[int] = set()
for string in result.tag_strings:
if not string.strip():
continue
string = string.replace("_", " ")
base_and_parent = string.split("(")
parent = ""
base = base_and_parent[0].strip(" ")
parent_results: list[int] = []
if len(base_and_parent) > 1:
parent = base_and_parent[1].split(")")[0]
r: list[set[Tag]] = library.search_tags(name=parent, limit=-1)
if len(r) > 0:
parent_results = [t.id for t in r[0]]
# NOTE: The following code overlaps with update_tags() in tag_search.py
# Sort and prioritize the results
tag_results: list[set[Tag]] = library.search_tags(name=base, limit=-1)
results_0 = list(tag_results[0])
results_0.sort(key=lambda tag: tag.name.lower())
results_1 = list(tag_results[1])
results_1.sort(key=lambda tag: tag.name.lower())
raw_results = list(results_0 + results_1)
priority_results: set[Tag] = set()
for tag in raw_results:
if tag.name.lower().startswith(base.strip().lower()):
priority_results.add(tag)
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
r for r in raw_results if r not in priority_results
]
if parent and parent_results:
filtered_parents: list[Tag] = []
for tag in all_results:
for p_id in tag.parent_ids:
if p_id in parent_results:
filtered_parents.append(tag)
break
all_results = [t for t in all_results if t in filtered_parents]
final_tag: Tag | None = None
if len(all_results) > 0:
final_tag = all_results[0]
if final_tag:
tag_ids.add(final_tag.id)
if not tag_ids:
return
library.add_tags_to_entries(entry_id, tag_ids)
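For reference, the tag-string syntax parsed above, mirrored in isolation with an invented example: underscores become spaces, and a parenthesized suffix names a parent tag used to narrow ambiguous matches:

string = "shiba_inu_(dog)".replace("_", " ")  # -> "shiba inu (dog)"
base_and_parent = string.split("(")
base = base_and_parent[0].strip(" ")          # -> "shiba inu"
parent = base_and_parent[1].split(")")[0]     # -> "dog"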
def _exec_add_field(library: "Library", entry_id: int, result: AddFieldInstruction) -> None:
library.add_field_to_entry(
entry_id, field_id=result.name, value=result.content, skip_on_exists=True
)
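Together these two functions replace the inline logic previously in the Qt driver (second file below): parsing is side-effect free, while execution mutates the library. A sketch of the call site, assuming lib is an open Library and entry_id an existing Entry's id:

instructions = parse_macro_file(macro_path, full_path)
exec_instructions(lib, entry_id, instructions)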

View File

@@ -62,13 +62,12 @@ from tagstudio.core.library.alchemy.enums import (
SortingModeEnum,
)
from tagstudio.core.library.alchemy.library import Library, LibraryStatus
from tagstudio.core.library.alchemy.models import Entry, Tag
from tagstudio.core.library.alchemy.models import Entry
from tagstudio.core.library.ignore import Ignore
from tagstudio.core.library.refresh import RefreshTracker
from tagstudio.core.macro_parser import (
DataResult,
FieldResult,
TagResult,
Instruction,
exec_instructions,
parse_macro_file,
)
from tagstudio.core.media_types import MediaCategories
@@ -1139,79 +1138,8 @@ class QtDriver(DriverMixin, QObject):
entry_id=entry.id,
)
results: list[DataResult] = parse_macro_file(macro_path, full_path)
for result in results:
if isinstance(result, TagResult):
tag_ids: set[int] = set()
for string in result.tag_strings:
if not string.strip():
continue
string = string.replace("_", " ")
base_and_parent = string.split("(")
parent = ""
base = base_and_parent[0].strip(" ")
parent_results: list[int] = []
if len(base_and_parent) > 1:
parent = base_and_parent[1].split(")")[0]
r: list[set[Tag]] = self.lib.search_tags(name=parent, limit=-1)
if len(r) > 0:
parent_results = [t.id for t in r[0]]
logger.warning("split", string=string, base=base, parent=parent)
# NOTE: The following code overlaps with update_tags() in tag_search.py
# Sort and prioritize the results
tag_results: list[set[Tag]] = self.lib.search_tags(name=base, limit=-1)
results_0 = list(tag_results[0])
results_0.sort(key=lambda tag: tag.name.lower())
results_1 = list(tag_results[1])
results_1.sort(key=lambda tag: tag.name.lower())
raw_results = list(results_0 + results_1)
priority_results: set[Tag] = set()
for tag in raw_results:
if tag.name.lower().startswith(base.strip().lower()):
priority_results.add(tag)
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
r for r in raw_results if r not in priority_results
]
logger.warning("parents", parent=parent, parent_results=parent_results)
if parent and parent_results:
filtered_parents: list[Tag] = []
for tag in all_results:
logger.warning(
"parent_ids",
tag_id=tag.id,
p_ids=tag.parent_ids,
parent_results=parent_results,
)
for p_id in tag.parent_ids:
if p_id in parent_results:
filtered_parents.append(tag)
break
all_results = [t for t in all_results if t in filtered_parents]
logger.warning(
"removed", to_remove=filtered_parents, all_results=all_results
)
final_tag: Tag | None = None
if len(all_results) > 0:
final_tag = all_results[0]
# tag = self.lib.get_tag_by_name(string)
if final_tag:
tag_ids.add(final_tag.id)
if not tag_ids:
continue
self.lib.add_tags_to_entries(entry_id, tag_ids)
elif isinstance(result, FieldResult):
self.lib.add_field_to_entry(
entry_id,
field_id=result.name,
value=result.content,
skip_on_exists=True,
)
results: list[Instruction] = parse_macro_file(macro_path, full_path)
exec_instructions(self.lib, entry_id, results)
@property
def sorting_direction(self) -> bool: