diff --git a/tagstudio/src/qt/widgets/preview/field_containers.py b/tagstudio/src/qt/widgets/preview/field_containers.py index 5b8e5d9c..851559fb 100644 --- a/tagstudio/src/qt/widgets/preview/field_containers.py +++ b/tagstudio/src/qt/widgets/preview/field_containers.py @@ -151,30 +151,71 @@ class FieldContainers(QWidget): cats: dict[Tag | None, set[Tag]] = {} cats[None] = set() - # Initialize all categories from parents + base_tag_ids: set[int] = {x.id for x in tags} + exhausted: set[int] = set() + cluster_map: dict[int, set[int]] = {} + + def add_to_cluster(tag_id: int, p_ids: list[int] | None = None): + """Maps a Tag's subtag's ID's back to it's parent Tag's ID. + + Example: + Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: + "Cartoon Network" -> Johnny Bravo, + "Character" -> "Johnny Bravo", + "TV" -> Johnny Bravo" + """ + tag_obj = self.lib.get_tag(tag_id) # Get full object + if p_ids is None: + p_ids = tag_obj.subtag_ids + + for p_id in p_ids: + if cluster_map.get(p_id) is None: + cluster_map[p_id] = set() + # If the p_tag has p_tags of its own, recursively link those to the original Tag. + if tag.id not in cluster_map[p_id]: + cluster_map[p_id].add(tag.id) + p_tag = self.lib.get_tag(p_id) # Get full object + if p_tag.subtag_ids: + add_to_cluster( + tag_id, + [sub_id for sub_id in p_tag.subtag_ids if sub_id != tag_id], + ) + exhausted.add(p_id) + exhausted.add(tag_id) + for tag in tags: - for p_tag in list(tag.subtags) + [tag]: - logger.info(f"[{tag.name}] is {p_tag.name} a category? ({p_tag.is_category})") - if p_tag.is_category: - cats[p_tag] = set() + add_to_cluster(tag.id) + + logger.info("Entry Cluster", entry_cluster=exhausted) + logger.info("Cluster Map", cluster_map=cluster_map) + + # Initialize all categories from parents. + tags_ = {self.lib.get_tag(x) for x in exhausted} + for tag in tags_: + if tag.is_category: + cats[tag] = set() logger.info("Blank Tag Categories", cats=cats) - # Add tags to any applicable categories - for tag in tags: - is_general = True - for p_tag in list(cats.keys()): - logger.info(f"[{tag.name}] Checking category tag key {p_tag}") - if not p_tag: - pass - elif p_tag in tag.subtags: - cats[p_tag].add(tag) - is_general = False - elif tag == p_tag: - cats[p_tag].add(tag) - is_general = False - pass - if is_general: - cats[None].add(tag) + # Add tags to any applicable categories. + added_ids: set[int] = set() + for key in cats: + logger.info("Checking category tag key", key=key) + + if key: + logger.info("Key cluster:", key=key, cluster=cluster_map.get(key.id)) + + if final_tags := cluster_map.get(key.id): + cats[key] = {self.lib.get_tag(x) for x in final_tags if x in base_tag_ids} + added_ids = added_ids.union({x for x in final_tags if x in base_tag_ids}) + + # Add remaining tags to None key (general case). + cats[None] = {self.lib.get_tag(x) for x in base_tag_ids if x not in added_ids} + logger.info( + f"[{key}] Key cluster: None, general case!", + general_tags=cats[key], + added=added_ids, + base_tag_ids=base_tag_ids, + ) # Remove unused categories empty: list[Tag] = []