6 Commits

Author SHA1 Message Date
KnugiHK
1694ae7dd9 Update utility.py 2026-01-24 01:47:45 +08:00
KnugiHK
f05e0d3451 Refactor incremental_merge 2026-01-24 01:33:18 +08:00
KnugiHK
0c5f2b7f13 Add a comment on SQLi in get_chat_condition 2026-01-24 01:19:55 +08:00
KnugiHK
db01d05263 Refactor get_chat_condition to increase maintainability 2026-01-24 00:50:06 +08:00
KnugiHK
2e7953f4ca Add unit test for get_chat_condition 2026-01-24 00:03:21 +08:00
KnugiHK
95a52231be Fix the returning string for empty filter list 2026-01-24 00:03:08 +08:00
2 changed files with 372 additions and 94 deletions

View File

@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
from enum import IntEnum from enum import IntEnum
from tqdm import tqdm from tqdm import tqdm
from Whatsapp_Chat_Exporter.data_model import ChatCollection, ChatStore, Timing from Whatsapp_Chat_Exporter.data_model import ChatCollection, ChatStore, Timing
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union, Any
try: try:
from enum import StrEnum, IntEnum from enum import StrEnum, IntEnum
except ImportError: except ImportError:
@@ -257,85 +257,230 @@ def import_from_json(json_file: str, data: ChatCollection):
logger.info(f"Imported {total_row_number} chats from JSON in {convert_time_unit(total_time)}{CLEAR_LINE}") logger.info(f"Imported {total_row_number} chats from JSON in {convert_time_unit(total_time)}{CLEAR_LINE}")
def incremental_merge(source_dir: str, target_dir: str, media_dir: str, pretty_print_json: int, avoid_encoding_json: bool): class IncrementalMerger:
"""Merges JSON files from the source directory into the target directory. """Handles incremental merging of WhatsApp chat exports."""
def __init__(self, pretty_print_json: int, avoid_encoding_json: bool):
"""Initialize the merger with JSON formatting options.
Args: Args:
source_dir (str): The path to the source directory containing JSON files. pretty_print_json: JSON indentation level.
target_dir (str): The path to the target directory to merge into. avoid_encoding_json: Whether to avoid ASCII encoding.
media_dir (str): The path to the media directory. """
self.pretty_print_json = pretty_print_json
self.avoid_encoding_json = avoid_encoding_json
def _get_json_files(self, source_dir: str) -> List[str]:
"""Get list of JSON files from source directory.
Args:
source_dir: Path to the source directory.
Returns:
List of JSON filenames.
Raises:
SystemExit: If no JSON files are found.
""" """
json_files = [f for f in os.listdir(source_dir) if f.endswith('.json')] json_files = [f for f in os.listdir(source_dir) if f.endswith('.json')]
if not json_files: if not json_files:
logger.error("No JSON files found in the source directory.") logger.error("No JSON files found in the source directory.")
return raise SystemExit(1)
logger.info("JSON files found:", json_files) logger.info("JSON files found:", json_files)
return json_files
def _copy_new_file(self, source_path: str, target_path: str, target_dir: str, json_file: str) -> None:
"""Copy a new JSON file to target directory.
Args:
source_path: Path to source file.
target_path: Path to target file.
target_dir: Target directory path.
json_file: Name of the JSON file.
"""
logger.info(f"Copying '{json_file}' to target directory...")
os.makedirs(target_dir, exist_ok=True)
shutil.copy2(source_path, target_path)
def _load_chat_data(self, file_path: str) -> Dict[str, Any]:
"""Load JSON data from file.
Args:
file_path: Path to JSON file.
Returns:
Loaded JSON data.
"""
with open(file_path, 'r') as file:
return json.load(file)
def _parse_chats_from_json(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Parse JSON data into ChatStore objects.
Args:
data: Raw JSON data.
Returns:
Dictionary of JID to ChatStore objects.
"""
return {jid: ChatStore.from_json(chat) for jid, chat in data.items()}
def _merge_chat_stores(self, source_chats: Dict[str, Any], target_chats: Dict[str, Any]) -> Dict[str, Any]:
"""Merge source chats into target chats.
Args:
source_chats: Source ChatStore objects.
target_chats: Target ChatStore objects.
Returns:
Merged ChatStore objects.
"""
for jid, chat in source_chats.items():
if jid in target_chats:
target_chats[jid].merge_with(chat)
else:
target_chats[jid] = chat
return target_chats
def _serialize_chats(self, chats: Dict[str, Any]) -> Dict[str, Any]:
"""Serialize ChatStore objects to JSON format.
Args:
chats: Dictionary of ChatStore objects.
Returns:
Serialized JSON data.
"""
return {jid: chat.to_json() for jid, chat in chats.items()}
def _has_changes(self, merged_data: Dict[str, Any], original_data: Dict[str, Any]) -> bool:
"""Check if merged data differs from original data.
Args:
merged_data: Merged JSON data.
original_data: Original JSON data.
Returns:
True if changes detected, False otherwise.
"""
return json.dumps(merged_data, sort_keys=True) != json.dumps(original_data, sort_keys=True)
def _save_merged_data(self, target_path: str, merged_data: Dict[str, Any]) -> None:
"""Save merged data to target file.
Args:
target_path: Path to target file.
merged_data: Merged JSON data.
"""
with open(target_path, 'w') as merged_file:
json.dump(
merged_data,
merged_file,
indent=self.pretty_print_json,
ensure_ascii=not self.avoid_encoding_json,
)
def _merge_json_file(self, source_path: str, target_path: str, json_file: str) -> None:
"""Merge a single JSON file.
Args:
source_path: Path to source file.
target_path: Path to target file.
json_file: Name of the JSON file.
"""
logger.info(f"Merging '{json_file}' with existing file in target directory...")
source_data = self._load_chat_data(source_path)
target_data = self._load_chat_data(target_path)
source_chats = self._parse_chats_from_json(source_data)
target_chats = self._parse_chats_from_json(target_data)
merged_chats = self._merge_chat_stores(source_chats, target_chats)
merged_data = self._serialize_chats(merged_chats)
if self._has_changes(merged_data, target_data):
logger.info(f"Changes detected in '{json_file}', updating target file...")
self._save_merged_data(target_path, merged_data)
else:
logger.info(f"No changes detected in '{json_file}', skipping update.")
def _should_copy_media_file(self, source_file: str, target_file: str) -> bool:
"""Check if media file should be copied.
Args:
source_file: Path to source media file.
target_file: Path to target media file.
Returns:
True if file should be copied, False otherwise.
"""
return not os.path.exists(target_file) or os.path.getmtime(source_file) > os.path.getmtime(target_file)
def _merge_media_directories(self, source_dir: str, target_dir: str, media_dir: str) -> None:
"""Merge media directories from source to target.
Args:
source_dir: Source directory path.
target_dir: Target directory path.
media_dir: Media directory name.
"""
source_media_path = os.path.join(source_dir, media_dir)
target_media_path = os.path.join(target_dir, media_dir)
logger.info(f"Merging media directories. Source: {source_media_path}, target: {target_media_path}")
if not os.path.exists(source_media_path):
return
for root, _, files in os.walk(source_media_path):
relative_path = os.path.relpath(root, source_media_path)
target_root = os.path.join(target_media_path, relative_path)
os.makedirs(target_root, exist_ok=True)
for file in files:
source_file = os.path.join(root, file)
target_file = os.path.join(target_root, file)
if self._should_copy_media_file(source_file, target_file):
logger.info(f"Copying '{source_file}' to '{target_file}'...")
shutil.copy2(source_file, target_file)
def merge(self, source_dir: str, target_dir: str, media_dir: str) -> None:
"""Merge JSON files and media from source to target directory.
Args:
source_dir: The path to the source directory containing JSON files.
target_dir: The path to the target directory to merge into.
media_dir: The path to the media directory.
"""
json_files = self._get_json_files(source_dir)
for json_file in json_files: for json_file in json_files:
source_path = os.path.join(source_dir, json_file) source_path = os.path.join(source_dir, json_file)
target_path = os.path.join(target_dir, json_file) target_path = os.path.join(target_dir, json_file)
if not os.path.exists(target_path): if not os.path.exists(target_path):
logger.info(f"Copying '{json_file}' to target directory...") self._copy_new_file(source_path, target_path, target_dir, json_file)
os.makedirs(target_dir, exist_ok=True)
shutil.copy2(source_path, target_path)
else: else:
logger.info( self._merge_json_file(source_path, target_path, json_file)
f"Merging '{json_file}' with existing file in target directory...")
with open(source_path, 'r') as src_file, open(target_path, 'r') as tgt_file:
source_data = json.load(src_file)
target_data = json.load(tgt_file)
# Parse JSON into ChatStore objects using from_json() self._merge_media_directories(source_dir, target_dir, media_dir)
source_chats = {jid: ChatStore.from_json(
chat) for jid, chat in source_data.items()}
target_chats = {jid: ChatStore.from_json(
chat) for jid, chat in target_data.items()}
# Merge chats using merge_with()
for jid, chat in source_chats.items():
if jid in target_chats:
target_chats[jid].merge_with(chat)
else:
target_chats[jid] = chat
# Serialize merged data def incremental_merge(source_dir: str, target_dir: str, media_dir: str, pretty_print_json: int, avoid_encoding_json: bool) -> None:
merged_data = {jid: chat.to_json() """Wrapper for merging JSON files from the source directory into the target directory.
for jid, chat in target_chats.items()}
# Check if the merged data differs from the original target data Args:
if json.dumps(merged_data, sort_keys=True) != json.dumps(target_data, sort_keys=True): source_dir: The path to the source directory containing JSON files.
logger.info( target_dir: The path to the target directory to merge into.
f"Changes detected in '{json_file}', updating target file...") media_dir: The path to the media directory.
with open(target_path, 'w') as merged_file: pretty_print_json: JSON indentation level.
json.dump( avoid_encoding_json: Whether to avoid ASCII encoding.
merged_data, """
merged_file, merger = IncrementalMerger(pretty_print_json, avoid_encoding_json)
indent=pretty_print_json, merger.merge(source_dir, target_dir, media_dir)
ensure_ascii=not avoid_encoding_json,
)
else:
logger.info(
f"No changes detected in '{json_file}', skipping update.")
# Merge media directories
source_media_path = os.path.join(source_dir, media_dir)
target_media_path = os.path.join(target_dir, media_dir)
logger.info(
f"Merging media directories. Source: {source_media_path}, target: {target_media_path}")
if os.path.exists(source_media_path):
for root, _, files in os.walk(source_media_path):
relative_path = os.path.relpath(root, source_media_path)
target_root = os.path.join(target_media_path, relative_path)
os.makedirs(target_root, exist_ok=True)
for file in files:
source_file = os.path.join(root, file)
target_file = os.path.join(target_root, file)
# we only copy if the file doesn't exist in the target or if the source is newer
if not os.path.exists(target_file) or os.path.getmtime(source_file) > os.path.getmtime(target_file):
logger.info(f"Copying '{source_file}' to '{target_file}'...")
shutil.copy2(source_file, target_file)
def get_file_name(contact: str, chat: ChatStore) -> Tuple[str, str]: def get_file_name(contact: str, chat: ChatStore) -> Tuple[str, str]:
@@ -384,9 +529,41 @@ def get_cond_for_empty(enable: bool, jid_field: str, broadcast_field: str) -> st
return f"AND (chat.hidden=0 OR {jid_field}='status@broadcast' OR {broadcast_field}>0)" if enable else "" return f"AND (chat.hidden=0 OR {jid_field}='status@broadcast' OR {broadcast_field}>0)" if enable else ""
def get_chat_condition(filter: Optional[List[str]], include: bool, columns: List[str], jid: Optional[str] = None, platform: Optional[str] = None) -> str: def _get_group_condition(jid: str, platform: str) -> str:
"""Generate platform-specific group identification condition.
Args:
jid: The JID column name.
platform: The platform ("android" or "ios").
Returns:
SQL condition string for group identification.
Raises:
ValueError: If platform is not supported.
"""
if platform == "android":
return f"{jid}.type == 1"
elif platform == "ios":
return f"{jid} IS NOT NULL"
else:
raise ValueError(
"Only android and ios are supported for argument platform if jid is not None")
def get_chat_condition(
filter: Optional[List[str]],
include: bool,
columns: List[str],
jid: Optional[str] = None,
platform: Optional[str] = None
) -> str:
"""Generates a SQL condition for filtering chats based on inclusion or exclusion criteria. """Generates a SQL condition for filtering chats based on inclusion or exclusion criteria.
SQL injection risks from chat filters were evaluated during development and deemed negligible
due to the tool's offline, trusted-input model (user running this tool on WhatsApp
backups/databases on their own device).
Args: Args:
filter: A list of phone numbers to include or exclude. filter: A list of phone numbers to include or exclude.
include: True to include chats that match the filter, False to exclude them. include: True to include chats that match the filter, False to exclude them.
@@ -400,35 +577,39 @@ def get_chat_condition(filter: Optional[List[str]], include: bool, columns: List
Raises: Raises:
ValueError: If the column count is invalid or an unsupported platform is provided. ValueError: If the column count is invalid or an unsupported platform is provided.
""" """
if filter is not None: if not filter:
conditions = [] return ""
if len(columns) < 2 and jid is not None:
if jid is not None and len(columns) < 2:
raise ValueError( raise ValueError(
"There must be at least two elements in argument columns if jid is not None") "There must be at least two elements in argument columns if jid is not None")
# Get group condition if needed
is_group_condition = None
if jid is not None: if jid is not None:
if platform == "android": is_group_condition = _get_group_condition(jid, platform)
is_group = f"{jid}.type == 1"
elif platform == "ios": # Build conditions for each chat filter
is_group = f"{jid} IS NOT NULL" conditions = []
else:
raise ValueError(
"Only android and ios are supported for argument platform if jid is not None")
for index, chat in enumerate(filter): for index, chat in enumerate(filter):
# Add connector for subsequent conditions (with double space)
connector = " OR" if include else " AND"
prefix = connector if index > 0 else ""
# Primary column condition
operator = "LIKE" if include else "NOT LIKE"
conditions.append(f"{prefix} {columns[0]} {operator} '%{chat}%'")
# Secondary column condition for groups
if len(columns) > 1 and is_group_condition:
if include: if include:
conditions.append( group_condition = f" OR ({columns[1]} {operator} '%{chat}%' AND {is_group_condition})"
f"{' OR' if index > 0 else ''} {columns[0]} LIKE '%{chat}%'")
if len(columns) > 1:
conditions.append(
f" OR ({columns[1]} LIKE '%{chat}%' AND {is_group})")
else: else:
conditions.append( group_condition = f" AND ({columns[1]} {operator} '%{chat}%' AND {is_group_condition})"
f"{' AND' if index > 0 else ''} {columns[0]} NOT LIKE '%{chat}%'") conditions.append(group_condition)
if len(columns) > 1:
conditions.append( combined_conditions = "".join(conditions)
f" AND ({columns[1]} NOT LIKE '%{chat}%' AND {is_group})") return f"AND ({combined_conditions})"
return f"AND ({' '.join(conditions)})"
else:
return ""
# Android Specific # Android Specific
@@ -584,7 +765,7 @@ def check_jid_map(db: sqlite3.Connection) -> bool:
""" """
cursor = db.cursor() cursor = db.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='jid_map'") cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='jid_map'")
return cursor.fetchone()is not None return cursor.fetchone() is not None
def get_jid_map_join(jid_map_exists: bool) -> str: def get_jid_map_join(jid_map_exists: bool) -> str:
@@ -634,6 +815,7 @@ def get_transcription_selection(db: sqlite3.Connection) -> str:
else: else:
return "NULL AS transcription_text" return "NULL AS transcription_text"
def setup_template(template: Optional[str], no_avatar: bool, experimental: bool = False) -> jinja2.Template: def setup_template(template: Optional[str], no_avatar: bool, experimental: bool = False) -> jinja2.Template:
""" """
Sets up the Jinja2 template environment and loads the template. Sets up the Jinja2 template environment and loads the template.

View File

@@ -254,3 +254,99 @@ class TestSafeName:
def test_safe_name(self, input_text, expected_output): def test_safe_name(self, input_text, expected_output):
result = safe_name(input_text) result = safe_name(input_text)
assert result == expected_output assert result == expected_output
class TestGetChatCondition:
def test_no_filter(self):
"""Test when filter is None"""
result = get_chat_condition(None, True, ["column1", "column2"])
assert result == ""
result = get_chat_condition(None, False, ["column1"])
assert result == ""
def test_include_single_chat_single_column(self):
"""Test including a single chat with single column"""
result = get_chat_condition(["1234567890"], True, ["phone"])
assert result == "AND ( phone LIKE '%1234567890%')"
def test_include_multiple_chats_single_column(self):
"""Test including multiple chats with single column"""
result = get_chat_condition(["1234567890", "0987654321"], True, ["phone"])
assert result == "AND ( phone LIKE '%1234567890%' OR phone LIKE '%0987654321%')"
def test_exclude_single_chat_single_column(self):
"""Test excluding a single chat with single column"""
result = get_chat_condition(["1234567890"], False, ["phone"])
assert result == "AND ( phone NOT LIKE '%1234567890%')"
def test_exclude_multiple_chats_single_column(self):
"""Test excluding multiple chats with single column"""
result = get_chat_condition(["1234567890", "0987654321"], False, ["phone"])
assert result == "AND ( phone NOT LIKE '%1234567890%' AND phone NOT LIKE '%0987654321%')"
def test_include_with_jid_android(self):
"""Test including chats with JID for Android platform"""
result = get_chat_condition(["1234567890"], True, ["phone", "name"], "jid", "android")
assert result == "AND ( phone LIKE '%1234567890%' OR (name LIKE '%1234567890%' AND jid.type == 1))"
def test_include_with_jid_ios(self):
"""Test including chats with JID for iOS platform"""
result = get_chat_condition(["1234567890"], True, ["phone", "name"], "jid", "ios")
assert result == "AND ( phone LIKE '%1234567890%' OR (name LIKE '%1234567890%' AND jid IS NOT NULL))"
def test_exclude_with_jid_android(self):
"""Test excluding chats with JID for Android platform"""
result = get_chat_condition(["1234567890"], False, ["phone", "name"], "jid", "android")
assert result == "AND ( phone NOT LIKE '%1234567890%' AND (name NOT LIKE '%1234567890%' AND jid.type == 1))"
def test_exclude_with_jid_ios(self):
"""Test excluding chats with JID for iOS platform"""
result = get_chat_condition(["1234567890"], False, ["phone", "name"], "jid", "ios")
assert result == "AND ( phone NOT LIKE '%1234567890%' AND (name NOT LIKE '%1234567890%' AND jid IS NOT NULL))"
def test_multiple_chats_with_jid_android(self):
"""Test multiple chats with JID for Android platform"""
result = get_chat_condition(["1234567890", "0987654321"], True, ["phone", "name"], "jid", "android")
expected = "AND ( phone LIKE '%1234567890%' OR (name LIKE '%1234567890%' AND jid.type == 1) OR phone LIKE '%0987654321%' OR (name LIKE '%0987654321%' AND jid.type == 1))"
assert result == expected
def test_multiple_chats_exclude_with_jid_android(self):
"""Test excluding multiple chats with JID for Android platform"""
result = get_chat_condition(["1234567890", "0987654321"], False, ["phone", "name"], "jid", "android")
expected = "AND ( phone NOT LIKE '%1234567890%' AND (name NOT LIKE '%1234567890%' AND jid.type == 1) AND phone NOT LIKE '%0987654321%' AND (name NOT LIKE '%0987654321%' AND jid.type == 1))"
assert result == expected
def test_invalid_column_count_with_jid(self):
"""Test error when column count is less than 2 but jid is provided"""
with pytest.raises(ValueError, match="There must be at least two elements in argument columns if jid is not None"):
get_chat_condition(["1234567890"], True, ["phone"], "jid", "android")
def test_unsupported_platform(self):
"""Test error when unsupported platform is provided"""
with pytest.raises(ValueError, match="Only android and ios are supported for argument platform if jid is not None"):
get_chat_condition(["1234567890"], True, ["phone", "name"], "jid", "windows")
def test_empty_filter_list(self):
"""Test with empty filter list"""
result = get_chat_condition([], True, ["phone"])
assert result == ""
result = get_chat_condition([], False, ["phone"])
assert result == ""
def test_filter_with_empty_strings(self):
"""Test with filter containing empty strings"""
result = get_chat_condition(["", "1234567890"], True, ["phone"])
assert result == "AND ( phone LIKE '%%' OR phone LIKE '%1234567890%')"
result = get_chat_condition([""], True, ["phone"])
assert result == "AND ( phone LIKE '%%')"
def test_special_characters_in_filter(self):
"""Test with special characters in filter values"""
result = get_chat_condition(["test@example.com"], True, ["email"])
assert result == "AND ( email LIKE '%test@example.com%')"
result = get_chat_condition(["user-name"], True, ["username"])
assert result == "AND ( username LIKE '%user-name%')"