Add more docstrings

This commit is contained in:
KnugiHK
2025-03-02 00:28:47 +08:00
parent 272454c2ce
commit 86cb44ced9
2 changed files with 154 additions and 23 deletions

View File

@@ -2,11 +2,11 @@
import os
from datetime import datetime, tzinfo, timedelta
from typing import Union
from typing import Union, Optional
class Timing():
def __init__(self, timezone_offset: Union[int, None]):
def __init__(self, timezone_offset: Optional[int]):
self.timezone_offset = timezone_offset
def format_timestamp(self, timestamp, format):
@@ -80,9 +80,9 @@ class Message():
def __init__(
self,
*,
from_me: Union[bool,int],
from_me: Union[bool, int],
timestamp: int,
time: Union[int,float,str],
time: Union[int, float, str],
key_id: int,
received_timestamp: int,
read_timestamp: int,

View File

@@ -1,3 +1,4 @@
import sqlite3
import jinja2
import json
import os
@@ -9,6 +10,7 @@ from markupsafe import Markup
from datetime import datetime, timedelta
from enum import IntEnum
from Whatsapp_Chat_Exporter.data_model import ChatStore
from typing import Dict, List, Optional, Tuple, Union
try:
from enum import StrEnum, IntEnum
except ImportError:
@@ -26,7 +28,15 @@ ROW_SIZE = 0x3D0
CURRENT_TZ_OFFSET = datetime.now().astimezone().utcoffset().seconds / 3600
def convert_time_unit(time_second: int):
def convert_time_unit(time_second: int) -> str:
"""Converts a time duration in seconds to a human-readable string.
Args:
time_second: The time duration in seconds.
Returns:
str: A human-readable string representing the time duration.
"""
time = str(timedelta(seconds=time_second))
if "day" not in time:
if time_second < 1:
@@ -46,11 +56,19 @@ def convert_time_unit(time_second: int):
return time
def bytes_to_readable(size_bytes: int):
"""From https://stackoverflow.com/a/14822210/9478891
def bytes_to_readable(size_bytes: int) -> str:
"""Converts a file size in bytes to a human-readable string with units.
From https://stackoverflow.com/a/14822210/9478891
Authors: james-sapam & other contributors
Licensed under CC BY-SA 3.0
See git commit logs for changes, if any.
Args:
size_bytes: The file size in bytes.
Returns:
A human-readable string representing the file size.
"""
if size_bytes == 0:
return "0B"
@@ -61,7 +79,18 @@ def bytes_to_readable(size_bytes: int):
return "%s %s" % (s, size_name[i])
def readable_to_bytes(size_str: str):
def readable_to_bytes(size_str: str) -> int:
"""Converts a human-readable file size string to bytes.
Args:
size_str: The human-readable file size string (e.g., "1024KB", "1MB", "2GB").
Returns:
The file size in bytes.
Raises:
ValueError: If the input string is invalid.
"""
SIZE_UNITS = {
'B': 1,
'KB': 1024,
@@ -80,11 +109,28 @@ def readable_to_bytes(size_str: str):
return int(number) * SIZE_UNITS[unit]
def sanitize_except(html):
def sanitize_except(html: str) -> Markup:
"""Sanitizes HTML, allowing only the <br> tag.
Args:
html: The HTML string to sanitize.
Returns:
A Markup object containing the sanitized HTML.
"""
return Markup(sanitize(html, tags=["br"]))
def determine_day(last, current):
def determine_day(last: int, current: int) -> Optional[datetime.date]:
"""Determines if the day has changed between two timestamps. Exposed to Jinja's environment.
Args:
last: The timestamp of the previous message.
current: The timestamp of the current message.
Returns:
The date of the current message if it's a different day than the last message, otherwise None.
"""
last = datetime.fromtimestamp(last).date()
current = datetime.fromtimestamp(current).date()
if last == current:
@@ -169,7 +215,13 @@ class Device(StrEnum):
EXPORTED = "exported"
def import_from_json(json_file, data):
def import_from_json(json_file: str, data: Dict[str, ChatStore]):
"""Imports chat data from a JSON file into the data dictionary.
Args:
json_file: The path to the JSON file.
data: The dictionary to store the imported chat data.
"""
from Whatsapp_Chat_Exporter.data_model import ChatStore, Message
with open(json_file, "r") as f:
temp_data = json.loads(f.read())
@@ -204,11 +256,31 @@ def import_from_json(json_file, data):
print(f"Importing chats from JSON...({index + 1}/{total_row_number})", end="\r")
def sanitize_filename(file_name: str):
def sanitize_filename(file_name: str) -> str:
"""Sanitizes a filename by keeping only alphanumeric characters, hyphens, and spaces.
Args:
file_name: The filename to sanitize.
Returns:
The sanitized filename.
"""
return "".join(x for x in file_name if x.isalnum() or x in "- ")
def get_file_name(contact: str, chat: ChatStore):
def get_file_name(contact: str, chat: ChatStore) -> Tuple[str, str]:
"""Generates a sanitized filename and contact name for a chat.
Args:
contact: The contact identifier (e.g., a phone number or group ID).
chat: The ChatStore object for the chat.
Returns:
A tuple containing the sanitized filename and the contact name.
Raises:
ValueError: If the contact format is unexpected.
"""
if "@" not in contact and contact not in ("000000000000000", "000000000000001", "ExportedChat"):
raise ValueError("Unexpected contact format: " + contact)
phone_number = contact.split('@')[0]
@@ -228,11 +300,36 @@ def get_file_name(contact: str, chat: ChatStore):
return sanitize_filename(file_name), name
def get_cond_for_empty(enable, jid_field: str, broadcast_field: str):
def get_cond_for_empty(enable: bool, jid_field: str, broadcast_field: str) -> str:
"""Generates a SQL condition for filtering empty chats.
Args:
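enable: True to return the filtering condition (keeps chats that are not hidden, the status broadcast, or rows whose broadcast field is > 0); False to return an empty string, i.e. no filtering.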
jid_field: The name of the JID field in the SQL query.
broadcast_field: The column name of the broadcast field in the SQL query.
Returns:
A SQL condition string.
"""
return f"AND (chat.hidden=0 OR {jid_field}='status@broadcast' OR {broadcast_field}>0)" if enable else ""
def get_chat_condition(filter, include, columns, jid=None, platform=None):
def get_chat_condition(filter: Optional[List[str]], include: bool, columns: List[str], jid: Optional[str] = None, platform: Optional[str] = None) -> str:
"""Generates a SQL condition for filtering chats based on inclusion or exclusion criteria.
Args:
filter: A list of phone numbers to include or exclude.
include: True to include chats that match the filter, False to exclude them.
columns: A list of column names to check against the filter.
jid: The JID column name (used for group identification).
platform: The platform ("android" or "ios") for platform-specific JID queries.
Returns:
A SQL condition string.
Raises:
ValueError: If the column count is invalid or an unsupported platform is provided.
"""
if filter is not None:
conditions = []
if len(columns) < 2 and jid is not None:
@@ -280,7 +377,16 @@ class DbType(StrEnum):
CONTACT = "contact"
def determine_metadata(content, init_msg):
def determine_metadata(content: sqlite3.Row, init_msg: Optional[str]) -> Optional[str]:
"""Determines the metadata of a message.
Args:
content (sqlite3.Row): A row from the messages table.
init_msg (Optional[str]): The initial message, if any.
Returns:
The metadata as a string or None if the type is unsupported.
"""
msg = init_msg if init_msg else ""
if content["is_me_joined"] == 1: # Override
return f"You were added into the group by {msg}"
@@ -361,7 +467,17 @@ def determine_metadata(content, init_msg):
return msg
def get_status_location(output_folder, offline_static):
def get_status_location(output_folder: str, offline_static: str) -> str:
"""
Gets the location of the W3.CSS file, either from web or local storage.
Args:
output_folder (str): The folder where offline static files will be stored.
offline_static (str): The subfolder name for static files. If falsy, returns web URL.
Returns:
str: The path or URL to the W3.CSS file.
"""
w3css = "https://www.w3schools.com/w3css/4/w3.css"
if not offline_static:
return w3css
@@ -376,7 +492,18 @@ def get_status_location(output_folder, offline_static):
w3css = os.path.join(offline_static, "w3.css")
def setup_template(template, no_avatar, experimental=False):
def setup_template(template: Optional[str], no_avatar: bool, experimental: bool = False) -> jinja2.Template:
"""
Sets up the Jinja2 template environment and loads the template.
Args:
template (Optional[str]): Path to custom template file. If None, uses default template.
no_avatar (bool): Whether to disable avatar display in the template.
experimental (bool, optional): Whether to use experimental template features. Defaults to False.
Returns:
jinja2.Template: The configured Jinja2 template object.
"""
if template is None or experimental:
template_dir = os.path.dirname(__file__)
template_file = "whatsapp.html" if not experimental else template
@@ -396,13 +523,17 @@ def setup_template(template, no_avatar, experimental=False):
APPLE_TIME = 978307200
def slugify(value, allow_unicode=False):
def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert text to a URL-safe slug (ASCII-only unless 'allow_unicode' is True).
Taken from https://github.com/django/django/blob/master/django/utils/text.py
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Args:
value (str): The string to convert to a slug.
allow_unicode (bool, optional): Whether to allow Unicode characters. Defaults to False.
Returns:
str: The slugified string with only alphanumerics, underscores, or hyphens.
"""
value = str(value)
if allow_unicode: