mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-01-31 15:19:10 +00:00
refactor: move macro processing to macro_parser.py
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, override
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
import structlog
|
||||
import toml
|
||||
@@ -13,6 +13,11 @@ from wcmatch import glob
|
||||
|
||||
from tagstudio.core.library.alchemy.fields import FieldID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tagstudio.core.library.alchemy.library import Library
|
||||
from tagstudio.core.library.alchemy.models import Tag
|
||||
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
SCHEMA_VERSION = "schema_version"
|
||||
@@ -62,12 +67,12 @@ class OnMissing(StrEnum):
|
||||
SKIP = "skip"
|
||||
|
||||
|
||||
class DataResult:
|
||||
class Instruction:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FieldResult(DataResult):
|
||||
class AddFieldInstruction(Instruction):
|
||||
def __init__(self, content, name: FieldID, field_type: str) -> None:
|
||||
super().__init__()
|
||||
self.content = content
|
||||
@@ -79,7 +84,7 @@ class FieldResult(DataResult):
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class TagResult(DataResult):
|
||||
class AddTagInstruction(Instruction):
|
||||
def __init__(
|
||||
self,
|
||||
tag_strings: list[str],
|
||||
@@ -103,14 +108,14 @@ class TagResult(DataResult):
|
||||
def parse_macro_file(
|
||||
macro_path: Path,
|
||||
filepath: Path,
|
||||
) -> list[DataResult]:
|
||||
) -> list[Instruction]:
|
||||
"""Parse a macro file and return a list of actions for TagStudio to perform.
|
||||
|
||||
Args:
|
||||
macro_path (Path): The full path of the macro file.
|
||||
filepath (Path): The filepath associated with Entry being operated upon.
|
||||
"""
|
||||
results: list[DataResult] = []
|
||||
results: list[Instruction] = []
|
||||
logger.info("[MacroParser] Parsing Macro", macro_path=macro_path, filepath=filepath)
|
||||
|
||||
if not macro_path.exists():
|
||||
@@ -191,20 +196,20 @@ def parse_macro_file(
|
||||
logger.info(f'[MacroParser] [{table_key}] "{ACTION}": {action}')
|
||||
|
||||
if action == Actions.IMPORT_DATA:
|
||||
results.extend(import_data(table, table_key, filepath))
|
||||
results.extend(_import_data(table, table_key, filepath))
|
||||
elif action == Actions.ADD_DATA:
|
||||
results.extend(add_data(table))
|
||||
results.extend(_add_data(table))
|
||||
|
||||
logger.info(results)
|
||||
return results
|
||||
|
||||
|
||||
def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[DataResult]:
|
||||
def _import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[Instruction]:
|
||||
"""Process an import_data instruction and return a list of DataResults.
|
||||
|
||||
Importing data refers to importing data from a source external to TagStudio or any macro.
|
||||
"""
|
||||
results: list[DataResult] = []
|
||||
results: list[Instruction] = []
|
||||
|
||||
source_format: str = str(table.get(SOURCE_FORMAT, ""))
|
||||
if not source_format:
|
||||
@@ -287,7 +292,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
|
||||
logger.warning(f"[MacroParser] [{table_key}] Empty template, skipping")
|
||||
continue
|
||||
for k in json_dump:
|
||||
template = fill_template(template, json_dump, k)
|
||||
template = _fill_template(template, json_dump, k)
|
||||
logger.info(f"[MacroParser] [{table_key}] Template filled!", template=template)
|
||||
content_value = template
|
||||
|
||||
@@ -345,7 +350,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
|
||||
|
||||
logger.info("[MacroParser] Found tags", tag_strings=tag_strings)
|
||||
results.append(
|
||||
TagResult(
|
||||
AddTagInstruction(
|
||||
tag_strings=tag_strings,
|
||||
use_context=use_context,
|
||||
strict=strict,
|
||||
@@ -356,7 +361,7 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
|
||||
|
||||
elif ts_type in (TEXT_LINE, TEXT_BOX, DATETIME):
|
||||
results.append(
|
||||
FieldResult(content=content_value, name=name, field_type=ts_type)
|
||||
AddFieldInstruction(content=content_value, name=name, field_type=ts_type)
|
||||
)
|
||||
else:
|
||||
logger.error('[MacroParser] [{table_key}] Unknown "{TS_TYPE}"', type=ts_type)
|
||||
@@ -364,12 +369,12 @@ def import_data(table: dict[str, Any], table_key: str, filepath: Path) -> list[D
|
||||
return results
|
||||
|
||||
|
||||
def add_data(table: dict[str, Any]) -> list[DataResult]:
|
||||
def _add_data(table: dict[str, Any]) -> list[Instruction]:
|
||||
"""Process an add_data instruction and return a list of DataResults.
|
||||
|
||||
Adding data refers to adding data defined inside a TagStudio macro, not from an external source.
|
||||
"""
|
||||
results: list[DataResult] = []
|
||||
results: list[Instruction] = []
|
||||
logger.error(table)
|
||||
for table_value in table.values():
|
||||
objects: list[dict[str, Any] | str] = []
|
||||
@@ -385,7 +390,7 @@ def add_data(table: dict[str, Any]) -> list[DataResult]:
|
||||
tag_strings: list[str] = obj.get(VALUE, [])
|
||||
logger.error(tag_strings)
|
||||
results.append(
|
||||
TagResult(
|
||||
AddTagInstruction(
|
||||
tag_strings=tag_strings,
|
||||
use_context=False,
|
||||
)
|
||||
@@ -400,12 +405,14 @@ def add_data(table: dict[str, Any]) -> list[DataResult]:
|
||||
continue
|
||||
|
||||
content_value: str = str(obj.get(VALUE, ""))
|
||||
results.append(FieldResult(content=content_value, name=name, field_type=ts_type))
|
||||
results.append(
|
||||
AddFieldInstruction(content=content_value, name=name, field_type=ts_type)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def fill_template(
|
||||
def _fill_template(
|
||||
template: str, table: dict[str, Any], table_key: str, template_key: str = ""
|
||||
) -> str:
|
||||
"""Replaces placeholder keys in a string with the value from that table.
|
||||
@@ -427,7 +434,74 @@ def fill_template(
|
||||
for v in value:
|
||||
normalized_key: str = f"{key}[{str(v)}]"
|
||||
template.replace(f"{{{normalized_key}}}", f"{{{str(v)}}}")
|
||||
template = fill_template(template, value, str(v), normalized_key)
|
||||
template = _fill_template(template, value, str(v), normalized_key)
|
||||
|
||||
value = str(value)
|
||||
return template.replace(f"{{{key}}}", f"{value}")
|
||||
|
||||
|
||||
def exec_instructions(library: "Library", entry_id: int, results: list[Instruction]) -> None:
|
||||
for result in results:
|
||||
if isinstance(result, AddTagInstruction):
|
||||
_exec_add_tag(library, entry_id, result)
|
||||
elif isinstance(result, AddFieldInstruction):
|
||||
_exec_add_field(library, entry_id, result)
|
||||
|
||||
|
||||
def _exec_add_tag(library: "Library", entry_id: int, result: AddTagInstruction):
|
||||
tag_ids: set[int] = set()
|
||||
for string in result.tag_strings:
|
||||
if not string.strip():
|
||||
continue
|
||||
string = string.replace("_", " ")
|
||||
base_and_parent = string.split("(")
|
||||
parent = ""
|
||||
base = base_and_parent[0].strip(" ")
|
||||
parent_results: list[int] = []
|
||||
if len(base_and_parent) > 1:
|
||||
parent = base_and_parent[1].split(")")[0]
|
||||
r: list[set[Tag]] = library.search_tags(name=parent, limit=-1)
|
||||
if len(r) > 0:
|
||||
parent_results = [t.id for t in r[0]]
|
||||
# NOTE: The following code overlaps with update_tags() in tag_search.py
|
||||
# Sort and prioritize the results
|
||||
tag_results: list[set[Tag]] = library.search_tags(name=base, limit=-1)
|
||||
results_0 = list(tag_results[0])
|
||||
results_0.sort(key=lambda tag: tag.name.lower())
|
||||
results_1 = list(tag_results[1])
|
||||
results_1.sort(key=lambda tag: tag.name.lower())
|
||||
raw_results = list(results_0 + results_1)
|
||||
priority_results: set[Tag] = set()
|
||||
|
||||
for tag in raw_results:
|
||||
if tag.name.lower().startswith(base.strip().lower()):
|
||||
priority_results.add(tag)
|
||||
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
|
||||
r for r in raw_results if r not in priority_results
|
||||
]
|
||||
|
||||
if parent and parent_results:
|
||||
filtered_parents: list[Tag] = []
|
||||
for tag in all_results:
|
||||
for p_id in tag.parent_ids:
|
||||
if p_id in parent_results:
|
||||
filtered_parents.append(tag)
|
||||
break
|
||||
all_results = [t for t in all_results if t in filtered_parents]
|
||||
|
||||
final_tag: Tag | None = None
|
||||
if len(all_results) > 0:
|
||||
final_tag = all_results[0]
|
||||
if final_tag:
|
||||
tag_ids.add(final_tag.id)
|
||||
|
||||
if not tag_ids:
|
||||
return
|
||||
|
||||
library.add_tags_to_entries(entry_id, tag_ids)
|
||||
|
||||
|
||||
def _exec_add_field(library: "Library", entry_id: int, result: AddFieldInstruction):
|
||||
library.add_field_to_entry(
|
||||
entry_id, field_id=result.name, value=result.content, skip_on_exists=True
|
||||
)
|
||||
|
||||
@@ -62,13 +62,12 @@ from tagstudio.core.library.alchemy.enums import (
|
||||
SortingModeEnum,
|
||||
)
|
||||
from tagstudio.core.library.alchemy.library import Library, LibraryStatus
|
||||
from tagstudio.core.library.alchemy.models import Entry, Tag
|
||||
from tagstudio.core.library.alchemy.models import Entry
|
||||
from tagstudio.core.library.ignore import Ignore
|
||||
from tagstudio.core.library.refresh import RefreshTracker
|
||||
from tagstudio.core.macro_parser import (
|
||||
DataResult,
|
||||
FieldResult,
|
||||
TagResult,
|
||||
Instruction,
|
||||
exec_instructions,
|
||||
parse_macro_file,
|
||||
)
|
||||
from tagstudio.core.media_types import MediaCategories
|
||||
@@ -1139,79 +1138,8 @@ class QtDriver(DriverMixin, QObject):
|
||||
entry_id=entry.id,
|
||||
)
|
||||
|
||||
results: list[DataResult] = parse_macro_file(macro_path, full_path)
|
||||
for result in results:
|
||||
if isinstance(result, TagResult):
|
||||
tag_ids: set[int] = set()
|
||||
for string in result.tag_strings:
|
||||
if not string.strip():
|
||||
continue
|
||||
string = string.replace("_", " ")
|
||||
base_and_parent = string.split("(")
|
||||
parent = ""
|
||||
base = base_and_parent[0].strip(" ")
|
||||
parent_results: list[int] = []
|
||||
if len(base_and_parent) > 1:
|
||||
parent = base_and_parent[1].split(")")[0]
|
||||
r: list[set[Tag]] = self.lib.search_tags(name=parent, limit=-1)
|
||||
if len(r) > 0:
|
||||
parent_results = [t.id for t in r[0]]
|
||||
logger.warning("split", string=string, base=base, parent=parent)
|
||||
# NOTE: The following code overlaps with update_tags() in tag_search.py
|
||||
# Sort and prioritize the results
|
||||
tag_results: list[set[Tag]] = self.lib.search_tags(name=base, limit=-1)
|
||||
results_0 = list(tag_results[0])
|
||||
results_0.sort(key=lambda tag: tag.name.lower())
|
||||
results_1 = list(tag_results[1])
|
||||
results_1.sort(key=lambda tag: tag.name.lower())
|
||||
raw_results = list(results_0 + results_1)
|
||||
priority_results: set[Tag] = set()
|
||||
|
||||
for tag in raw_results:
|
||||
if tag.name.lower().startswith(base.strip().lower()):
|
||||
priority_results.add(tag)
|
||||
all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
|
||||
r for r in raw_results if r not in priority_results
|
||||
]
|
||||
|
||||
logger.warning("parents", parent=parent, parent_results=parent_results)
|
||||
if parent and parent_results:
|
||||
filtered_parents: list[Tag] = []
|
||||
for tag in all_results:
|
||||
logger.warning(
|
||||
"parent_ids",
|
||||
tag_id=tag.id,
|
||||
p_ids=tag.parent_ids,
|
||||
parent_results=parent_results,
|
||||
)
|
||||
for p_id in tag.parent_ids:
|
||||
if p_id in parent_results:
|
||||
filtered_parents.append(tag)
|
||||
break
|
||||
all_results = [t for t in all_results if t in filtered_parents]
|
||||
logger.warning(
|
||||
"removed", to_remove=filtered_parents, all_results=all_results
|
||||
)
|
||||
|
||||
final_tag: Tag | None = None
|
||||
if len(all_results) > 0:
|
||||
final_tag = all_results[0]
|
||||
# tag = self.lib.get_tag_by_name(string)
|
||||
if final_tag:
|
||||
tag_ids.add(final_tag.id)
|
||||
|
||||
if not tag_ids:
|
||||
continue
|
||||
|
||||
self.lib.add_tags_to_entries(entry_id, tag_ids)
|
||||
|
||||
elif isinstance(result, FieldResult):
|
||||
self.lib.add_field_to_entry(
|
||||
entry_id,
|
||||
field_id=result.name,
|
||||
value=result.content,
|
||||
skip_on_exists=True,
|
||||
)
|
||||
results: list[Instruction] = parse_macro_file(macro_path, full_path)
|
||||
exec_instructions(self.lib, entry_id, results)
|
||||
|
||||
@property
|
||||
def sorting_direction(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user