diff --git a/head-tracking/colors.py b/head-tracking/colors.py new file mode 100644 index 0000000..cc1ba21 --- /dev/null +++ b/head-tracking/colors.py @@ -0,0 +1,29 @@ +import logging +from logging import Formatter, LogRecord +from typing import Dict + +class Colors: + RESET: str = "\033[0m" + BOLD: str = "\033[1m" + RED: str = "\033[91m" + GREEN: str = "\033[92m" + YELLOW: str = "\033[93m" + BLUE: str = "\033[94m" + MAGENTA: str = "\033[95m" + CYAN: str = "\033[96m" + WHITE: str = "\033[97m" + BG_BLACK: str = "\033[40m" + +class ColorFormatter(Formatter): + FORMATS: Dict[int, str] = { + logging.DEBUG: f"{Colors.BLUE}[%(levelname)s] %(message)s{Colors.RESET}", + logging.INFO: f"{Colors.GREEN}%(message)s{Colors.RESET}", + logging.WARNING: f"{Colors.YELLOW}%(message)s{Colors.RESET}", + logging.ERROR: f"{Colors.RED}[%(levelname)s] %(message)s{Colors.RESET}", + logging.CRITICAL: f"{Colors.RED}{Colors.BOLD}[%(levelname)s] %(message)s{Colors.RESET}" + } + + def format(self, record: LogRecord) -> str: + log_fmt: str = self.FORMATS.get(record.levelno) + formatter: Formatter = Formatter(log_fmt, datefmt="%H:%M:%S") + return formatter.format(record) diff --git a/head-tracking/connection_manager.py b/head-tracking/connection_manager.py index 1e18b04..ae92dc3 100644 --- a/head-tracking/connection_manager.py +++ b/head-tracking/connection_manager.py @@ -1,23 +1,25 @@ import bluetooth import logging +from bluetooth import BluetoothSocket +from logging import Logger class ConnectionManager: - INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" - START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" - STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" + INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" + START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" + STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" - def __init__(self, bt_addr="28:2D:7F:C2:05:5B", psm=0x1001, logger=None): - self.bt_addr = bt_addr - self.psm = psm - self.logger = logger if logger else logging.getLogger(__name__) - self.sock = None - self.connected = False - self.started = False + def __init__(self, bt_addr: str = "28:2D:7F:C2:05:5B", psm: int = 0x1001, logger: Logger = None) -> None: + self.bt_addr: str = bt_addr + self.psm: int = psm + self.logger: Logger = logger if logger else logging.getLogger(__name__) + self.sock: BluetoothSocket = None + self.connected: bool = False + self.started: bool = False - def connect(self): + def connect(self) -> bool: self.logger.info(f"Connecting to {self.bt_addr} on PSM {self.psm:#04x}...") try: - self.sock = bluetooth.BluetoothSocket(bluetooth.L2CAP) + self.sock = BluetoothSocket(bluetooth.L2CAP) self.sock.connect((self.bt_addr, self.psm)) self.connected = True self.logger.info("Connected to AirPods.") @@ -28,7 +30,7 @@ class ConnectionManager: self.connected = False return self.connected - def send_start(self): + def send_start(self) -> bool: if not self.connected: self.logger.error("Not connected. Cannot send START command.") return False @@ -40,7 +42,7 @@ class ConnectionManager: self.logger.info("START command has already been sent.") return True - def send_stop(self): + def send_stop(self) -> None: if self.connected and self.started: try: self.sock.send(bytes.fromhex(self.STOP_CMD)) @@ -51,7 +53,7 @@ class ConnectionManager: else: self.logger.info("Cannot send STOP; not started or not connected.") - def disconnect(self): + def disconnect(self) -> None: if self.sock: try: self.sock.close() @@ -59,4 +61,4 @@ class ConnectionManager: except Exception as e: self.logger.error(f"Error during disconnect: {e}") self.connected = False - self.started = False \ No newline at end of file + self.started = False diff --git a/head-tracking/gestures.py b/head-tracking/gestures.py index 394b72a..a598409 100644 --- a/head-tracking/gestures.py +++ b/head-tracking/gestures.py @@ -1,88 +1,65 @@ -import bluetooth -import threading -import time import logging import statistics +import time +from bluetooth import BluetoothSocket from collections import deque +from colors import * +from connection_manager import ConnectionManager +from logging import Logger, StreamHandler +from threading import Lock, Thread +from typing import Any, Deque, List, Optional, Tuple -class Colors: - RESET = "\033[0m" - BOLD = "\033[1m" - RED = "\033[91m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - MAGENTA = "\033[95m" - CYAN = "\033[96m" - WHITE = "\033[97m" - BG_BLACK = "\033[40m" - -class ColorFormatter(logging.Formatter): - FORMATS = { - logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET, - logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET, - logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S") - return formatter.format(record) - -handler = logging.StreamHandler() +handler: StreamHandler = StreamHandler() handler.setFormatter(ColorFormatter()) -log = logging.getLogger(__name__) +log: Logger = logging.getLogger(__name__) log.setLevel(logging.INFO) log.addHandler(handler) log.propagate = False class GestureDetector: - INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" - START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" - STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" + INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" + START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" + STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" - def __init__(self, conn=None): - self.sock = None - self.bt_addr = "28:2D:7F:C2:05:5B" - self.psm = 0x1001 - self.running = False - self.data_lock = threading.Lock() + def __init__(self, conn: ConnectionManager = None) -> None: + self.sock: BluetoothSocket = None + self.bt_addr: str = "28:2D:7F:C2:05:5B" + self.psm: int = 0x1001 + self.running: bool = False + self.data_lock: Lock = Lock() - self.horiz_buffer = deque(maxlen=100) - self.vert_buffer = deque(maxlen=100) + self.horiz_buffer: Deque[int] = deque(maxlen=100) + self.vert_buffer: Deque[int] = deque(maxlen=100) - self.horiz_avg_buffer = deque(maxlen=5) - self.vert_avg_buffer = deque(maxlen=5) + self.horiz_avg_buffer: Deque[float] = deque(maxlen=5) + self.vert_avg_buffer: Deque[float] = deque(maxlen=5) - self.horiz_peaks = [] - self.horiz_troughs = [] - self.vert_peaks = [] - self.vert_troughs = [] + self.horiz_peaks: List[int] = [] + self.horiz_troughs: List[int] = [] + self.vert_peaks: List[int] = [] + self.vert_troughs: List[int] = [] - self.last_peak_time = 0 - self.peak_intervals = deque(maxlen=5) + self.last_peak_time: float = 0 + self.peak_intervals: Deque[float] = deque(maxlen=5) - self.peak_threshold = 400 - self.direction_change_threshold = 175 - self.rhythm_consistency_threshold = 0.5 + self.peak_threshold: int = 400 + self.direction_change_threshold: int = 175 + self.rhythm_consistency_threshold: float = 0.5 - self.horiz_increasing = None - self.vert_increasing = None + self.horiz_increasing: Optional[bool] = None + self.vert_increasing: Optional[bool] = None self.required_extremes = 3 - self.detection_timeout = 15 + self.detection_timeout: int = 15 - self.min_confidence_threshold = 0.7 + self.min_confidence_threshold: float = 0.7 - self.conn = conn + self.conn: ConnectionManager = conn - def connect(self): + def connect(self) -> bool: try: log.info(f"Connecting to AirPods at {self.bt_addr}...") if self.conn is None: - from connection_manager import ConnectionManager self.conn = ConnectionManager(self.bt_addr, self.psm, logger=log) if not self.conn.connect(): return False @@ -97,13 +74,13 @@ class GestureDetector: log.error(f"{Colors.RED}Connection failed: {e}{Colors.RESET}") return False - def process_data(self): + def process_data(self) -> None: """Process incoming head tracking data.""" self.conn.send_start() log.info(f"{Colors.GREEN}✓ Head tracking activated{Colors.RESET}") self.running = True - start_time = time.time() + start_time: float = time.time() log.info(f"{Colors.GREEN}Ready! Make a YES or NO gesture{Colors.RESET}") log.info(f"{Colors.YELLOW}Tip: Use natural, moderate speed head movements{Colors.RESET}") @@ -118,10 +95,10 @@ class GestureDetector: if not self.sock: log.error("Socket not available.") break - data = self.sock.recv(1024) - formatted = self.format_hex(data) + data: bytes = self.sock.recv(1024) + formatted: str = self.format_hex(data) if self.is_valid_tracking_packet(formatted): - raw_bytes = bytes.fromhex(formatted.replace(" ", "")) + raw_bytes: bytes = bytes.fromhex(formatted.replace(" ", "")) horizontal, vertical = self.extract_orientation_values(raw_bytes) if horizontal is not None and vertical is not None: @@ -132,7 +109,7 @@ class GestureDetector: self.vert_buffer.append(smooth_v) self.detect_peaks_and_troughs() - gesture = self.detect_gestures() + gesture: Optional[str] = self.detect_gestures() if gesture: self.running = False @@ -143,19 +120,19 @@ class GestureDetector: log.error(f"Data processing error: {e}") break - def disconnect(self): + def disconnect(self) -> None: """Disconnect from socket.""" self.conn.disconnect() - def format_hex(self, data): + def format_hex(self, data: bytes) -> str: """Format binary data to readable hex string.""" - hex_str = data.hex() + hex_str: str = data.hex() return ' '.join(hex_str[i:i+2] for i in range(0, len(hex_str), 2)) - def is_valid_tracking_packet(self, hex_string): + def is_valid_tracking_packet(self, hex_string: str) -> bool: """Verify packet is a valid head tracking packet.""" - standard_header = "04 00 04 00 17 00 00 00 10 00 45 00" - alternate_header = "04 00 04 00 17 00 00 00 10 00 44 00" + standard_header: str = "04 00 04 00 17 00 00 00 10 00 45 00" + alternate_header: str = "04 00 04 00 17 00 00 00 10 00 44 00" if not hex_string.startswith(standard_header) and not hex_string.startswith(alternate_header): return False @@ -164,55 +141,55 @@ class GestureDetector: return True - def extract_orientation_values(self, raw_bytes): + def extract_orientation_values(self, raw_bytes: bytes) -> Tuple[Optional[int], Optional[int]]: """Extract head orientation data from packet.""" try: - horizontal = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True) - vertical = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True) + horizontal: int = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True) + vertical: int = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True) return horizontal, vertical except Exception as e: log.debug(f"Failed to extract orientation: {e}") return None, None - def apply_smoothing(self, horizontal, vertical): + def apply_smoothing(self, horizontal: int, vertical: int) -> Tuple[float, float]: """Apply moving average smoothing (Apple-like filtering).""" self.horiz_avg_buffer.append(horizontal) self.vert_avg_buffer.append(vertical) - smooth_horiz = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer) - smooth_vert = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer) + smooth_horiz: float = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer) + smooth_vert: float = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer) return smooth_horiz, smooth_vert - def detect_peaks_and_troughs(self): + def detect_peaks_and_troughs(self) -> None: """Detect motion direction changes with Apple-like refinements.""" if len(self.horiz_buffer) < 4 or len(self.vert_buffer) < 4: return - h_values = list(self.horiz_buffer)[-4:] - v_values = list(self.vert_buffer)[-4:] + h_values: List[int] = list(self.horiz_buffer)[-4:] + v_values: List[int] = list(self.vert_buffer)[-4:] - h_variance = statistics.variance(h_values) if len(h_values) > 1 else 0 - v_variance = statistics.variance(v_values) if len(v_values) > 1 else 0 + h_variance: float = statistics.variance(h_values) if len(h_values) > 1 else 0 + v_variance: float = statistics.variance(v_values) if len(v_values) > 1 else 0 - current = self.horiz_buffer[-1] - prev = self.horiz_buffer[-2] + current: int = self.horiz_buffer[-1] + prev: int = self.horiz_buffer[-2] if self.horiz_increasing is None: self.horiz_increasing = current > prev - dynamic_h_threshold = max(100, min(self.direction_change_threshold, h_variance / 3)) + dynamic_h_threshold: float = max(100, min(self.direction_change_threshold, h_variance / 3)) if self.horiz_increasing and current < prev - dynamic_h_threshold: if abs(prev) > self.peak_threshold: self.horiz_peaks.append((len(self.horiz_buffer)-1, prev, time.time())) - direction = "➡️ " if prev > 0 else "⬅️ " + direction: str = "➡️ " if prev > 0 else "⬅️ " log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}") - now = time.time() + now: float = time.time() if self.last_peak_time > 0: - interval = now - self.last_peak_time + interval: float = now - self.last_peak_time self.peak_intervals.append(interval) self.last_peak_time = now @@ -221,34 +198,34 @@ class GestureDetector: elif not self.horiz_increasing and current > prev + dynamic_h_threshold: if abs(prev) > self.peak_threshold: self.horiz_troughs.append((len(self.horiz_buffer)-1, prev, time.time())) - direction = "➡️ " if prev > 0 else "⬅️ " + direction: str = "➡️ " if prev > 0 else "⬅️ " log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}") - now = time.time() + now: float = time.time() if self.last_peak_time > 0: - interval = now - self.last_peak_time + interval: float = now - self.last_peak_time self.peak_intervals.append(interval) self.last_peak_time = now self.horiz_increasing = True - current = self.vert_buffer[-1] - prev = self.vert_buffer[-2] + current: int = self.vert_buffer[-1] + prev: int = self.vert_buffer[-2] if self.vert_increasing is None: self.vert_increasing = current > prev - dynamic_v_threshold = max(100, min(self.direction_change_threshold, v_variance / 3)) + dynamic_v_threshold: float = max(100, min(self.direction_change_threshold, v_variance / 3)) if self.vert_increasing and current < prev - dynamic_v_threshold: if abs(prev) > self.peak_threshold: self.vert_peaks.append((len(self.vert_buffer)-1, prev, time.time())) - direction = "⬆️ " if prev > 0 else "⬇️ " + direction: str = "⬆️ " if prev > 0 else "⬇️ " log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}") - now = time.time() + now: float = time.time() if self.last_peak_time > 0: - interval = now - self.last_peak_time + interval: float = now - self.last_peak_time self.peak_intervals.append(interval) self.last_peak_time = now @@ -257,60 +234,60 @@ class GestureDetector: elif not self.vert_increasing and current > prev + dynamic_v_threshold: if abs(prev) > self.peak_threshold: self.vert_troughs.append((len(self.vert_buffer)-1, prev, time.time())) - direction = "⬆️ " if prev > 0 else "⬇️ " + direction: str = "⬆️ " if prev > 0 else "⬇️ " log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}") - now = time.time() + now: float = time.time() if self.last_peak_time > 0: - interval = now - self.last_peak_time + interval: float = now - self.last_peak_time self.peak_intervals.append(interval) self.last_peak_time = now self.vert_increasing = True - def calculate_rhythm_consistency(self): + def calculate_rhythm_consistency(self) -> float: """Calculate how consistent the timing between peaks is (Apple-like).""" if len(self.peak_intervals) < 2: return 0 - mean_interval = statistics.mean(self.peak_intervals) + mean_interval: float = statistics.mean(self.peak_intervals) if mean_interval == 0: return 0 - variances = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals] - consistency = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold) + variances: List[float] = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals] + consistency: float = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold) return max(0, consistency) - def calculate_confidence_score(self, extremes, is_vertical=True): + def calculate_confidence_score(self, extremes: List[Tuple[int, int, float]], is_vertical: bool = True) -> float: """Calculate confidence score for gesture detection (Apple-like).""" if len(extremes) < self.required_extremes: return 0.0 - sorted_extremes = sorted(extremes, key=lambda x: x[0]) + sorted_extremes: List[Tuple[int, int, float]] = sorted(extremes, key=lambda x: x[0]) - recent = sorted_extremes[-self.required_extremes:] + recent: List[Tuple[int, int, float]] = sorted_extremes[-self.required_extremes:] - avg_amplitude = sum(abs(val) for _, val, _ in recent) / len(recent) - amplitude_factor = min(1.0, avg_amplitude / 600) + avg_amplitude: float = sum(abs(val) for _, val, _ in recent) / len(recent) + amplitude_factor: float = min(1.0, avg_amplitude / 600) - rhythm_factor = self.calculate_rhythm_consistency() + rhythm_factor: float = self.calculate_rhythm_consistency() - signs = [1 if val > 0 else -1 for _, val, _ in recent] - alternating = all(signs[i] != signs[i-1] for i in range(1, len(signs))) - alternation_factor = 1.0 if alternating else 0.5 + signs: List[int] = [1 if val > 0 else -1 for _, val, _ in recent] + alternating: bool = all(signs[i] != signs[i-1] for i in range(1, len(signs))) + alternation_factor: float = 1.0 if alternating else 0.5 if is_vertical: - vert_amp = sum(abs(val) for _, val, _ in recent) / len(recent) - horiz_vals = list(self.horiz_buffer)[-len(recent)*2:] - horiz_amp = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0 - isolation_factor = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2) + vert_amp: float = sum(abs(val) for _, val, _ in recent) / len(recent) + horiz_vals: List[int] = list(self.horiz_buffer)[-len(recent)*2:] + horiz_amp: float = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0 + isolation_factor: float = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2) else: - horiz_amp = sum(abs(val) for _, val, _ in recent) - vert_vals = list(self.vert_buffer)[-len(recent)*2:] - vert_amp = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0 - isolation_factor = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2) + horiz_amp: float = sum(abs(val) for _, val, _ in recent) + vert_vals: List[int] = list(self.vert_buffer)[-len(recent)*2:] + vert_amp: float = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0 + isolation_factor: float = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2) - confidence = ( + confidence: float = ( amplitude_factor * 0.4 + rhythm_factor * 0.2 + alternation_factor * 0.2 + @@ -319,12 +296,12 @@ class GestureDetector: return confidence - def detect_gestures(self): + def detect_gestures(self) -> Optional[str]: """Recognize head gesture patterns with Apple-like intelligence.""" if len(self.vert_peaks) + len(self.vert_troughs) >= self.required_extremes: - all_extremes = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0]) + all_extremes: List[Tuple[int, int, float]] = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0]) - confidence = self.calculate_confidence_score(all_extremes, is_vertical=True) + confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=True) log.info(f"Vertical motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})") @@ -333,9 +310,9 @@ class GestureDetector: return "YES" if len(self.horiz_peaks) + len(self.horiz_troughs) >= self.required_extremes: - all_extremes = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0]) + all_extremes: List[Tuple[int, int, float]] = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0]) - confidence = self.calculate_confidence_score(all_extremes, is_vertical=False) + confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=False) log.info(f"Horizontal motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})") @@ -345,7 +322,7 @@ class GestureDetector: return None - def start_detection(self): + def start_detection(self) -> None: """Begin gesture detection process.""" log.info(f"{Colors.BOLD}{Colors.WHITE}Starting gesture detection...{Colors.RESET}") @@ -353,7 +330,7 @@ class GestureDetector: log.error(f"{Colors.RED}Failed to connect to AirPods.{Colors.RESET}") return - data_thread = threading.Thread(target=self.process_data) + data_thread: Thread = Thread(target=self.process_data) data_thread.daemon = True data_thread.start() @@ -377,5 +354,5 @@ if __name__ == "__main__": print(f"{Colors.GREEN}• YES: {Colors.WHITE}nodding head up and down{Colors.RESET}") print(f"{Colors.RED}• NO: {Colors.WHITE}shaking head left and right{Colors.RESET}\n") - detector = GestureDetector() + detector: GestureDetector = GestureDetector() detector.start_detection() \ No newline at end of file diff --git a/head-tracking/head_orientation.py b/head-tracking/head_orientation.py index d27cb85..1f90990 100644 --- a/head-tracking/head_orientation.py +++ b/head-tracking/head_orientation.py @@ -1,63 +1,43 @@ import math -import drawille import numpy as np import logging import os +from colors import * +from drawille import Canvas +from logging import Logger, StreamHandler +from matplotlib.animation import FuncAnimation +from matplotlib.pyplot import Axes, Figure +from numpy.typing import NDArray +from os import terminal_size as TerminalSize +from typing import Any, Dict, List, Optional, Tuple -class Colors: - RESET = "\033[0m" - BOLD = "\033[1m" - RED = "\033[91m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - MAGENTA = "\033[95m" - CYAN = "\033[96m" - WHITE = "\033[97m" - BG_BLACK = "\033[40m" - -class ColorFormatter(logging.Formatter): - FORMATS = { - logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET, - logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET, - logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S") - return formatter.format(record) - -handler = logging.StreamHandler() +handler: StreamHandler = StreamHandler() handler.setFormatter(ColorFormatter()) -log = logging.getLogger(__name__) +log: Logger = logging.getLogger(__name__) log.setLevel(logging.INFO) log.addHandler(handler) log.propagate = False - class HeadOrientation: - def __init__(self, use_terminal=False): - self.orientation_offset = 5500 - self.o1_neutral = 19000 - self.o2_neutral = 0 - self.o3_neutral = 0 - self.calibration_samples = [] - self.calibration_complete = False - self.calibration_sample_count = 10 - self.fig = None - self.ax = None - self.arrow = None - self.animation = None - self.use_terminal = use_terminal + def __init__(self, use_terminal: bool = False) -> None: + self.orientation_offset: int = 5500 + self.o1_neutral: int = 19000 + self.o2_neutral: int = 0 + self.o3_neutral: int = 0 + self.calibration_samples: List[List[int]] = [] + self.calibration_complete: bool = False + self.calibration_sample_count: int = 10 + self.fig: Optional[Figure] = None + self.ax: Optional[Axes] = None + self.arrow: Any = None + self.animation: Optional[FuncAnimation] = None + self.use_terminal: bool = use_terminal - def reset_calibration(self): + def reset_calibration(self) -> None: self.calibration_samples = [] self.calibration_complete = False - def add_calibration_sample(self, orientation_values): + def add_calibration_sample(self, orientation_values: List[int]) -> bool: if len(self.calibration_samples) < self.calibration_sample_count: self.calibration_samples.append(orientation_values) return False @@ -66,57 +46,58 @@ class HeadOrientation: return True return True - def _calculate_calibration(self): + def _calculate_calibration(self) -> None: if len(self.calibration_samples) < 3: log.warning("Not enough calibration samples") return - samples = np.array(self.calibration_samples) - self.o1_neutral = np.mean(samples[:, 0]) - avg_o2 = np.mean(samples[:, 1]) - avg_o3 = np.mean(samples[:, 2]) - self.o2_neutral = avg_o2 - self.o3_neutral = avg_o3 + samples: NDArray[[List[int]]] = np.array(self.calibration_samples) + self.o1_neutral: float = np.mean(samples[:, 0]) + avg_o2: float = np.mean(samples[:, 1]) + avg_o3: float = np.mean(samples[:, 2]) + self.o2_neutral: float = avg_o2 + self.o3_neutral: float = avg_o3 log.info("Calibration complete: o1_neutral=%.2f, o2_neutral=%.2f, o3_neutral=%.2f", self.o1_neutral, self.o2_neutral, self.o3_neutral) self.calibration_complete = True - def calculate_orientation(self, o1, o2, o3): + def calculate_orientation(self, o1: float, o2: float, o3: float) -> Dict[str, float]: if not self.calibration_complete: return {'pitch': 0, 'yaw': 0} - o1_norm = o1 - self.o1_neutral - o2_norm = o2 - self.o2_neutral - o3_norm = o3 - self.o3_neutral - pitch = (o2_norm + o3_norm) / 2 / 32000 * 180 - yaw = (o2_norm - o3_norm) / 2 / 32000 * 180 + o1_norm: float = o1 - self.o1_neutral + o2_norm: float = o2 - self.o2_neutral + o3_norm: float = o3 - self.o3_neutral + pitch: float = (o2_norm + o3_norm) / 2 / 32000 * 180 + yaw: float = (o2_norm - o3_norm) / 2 / 32000 * 180 return {'pitch': pitch, 'yaw': yaw} - def create_face_art(self, pitch, yaw): + def create_face_art(self, pitch: float, yaw: float) -> str: if self.use_terminal: try: - ts = os.get_terminal_size() + ts: TerminalSize = os.get_terminal_size() width, height = ts.columns, ts.lines * 2 except Exception: width, height = 80, 40 else: width, height = 80, 40 center_x, center_y = width // 2, height // 2 - radius = (min(width, height) // 2 - 2) // 2 - pitch_rad = math.radians(pitch) - yaw_rad = math.radians(yaw) - canvas = drawille.Canvas() - def rotate_point(x, y, z, pitch_r, yaw_r): + radius: int = (min(width, height) // 2 - 2) // 2 + pitch_rad: float = math.radians(pitch) + yaw_rad: float = math.radians(yaw) + canvas: Canvas = Canvas() + + def rotate_point(x: float, y: float, z: float, pitch_r: float, yaw_r: float) -> Tuple[int, int]: cos_y, sin_y = math.cos(yaw_r), math.sin(yaw_r) cos_p, sin_p = math.cos(pitch_r), math.sin(pitch_r) - x1 = x * cos_y - z * sin_y - z1 = x * sin_y + z * cos_y - y1 = y * cos_p - z1 * sin_p - z2 = y * sin_p + z1 * cos_p - scale = 1 + (z2 / width) + x1: float = x * cos_y - z * sin_y + z1: float = x * sin_y + z * cos_y + y1: float = y * cos_p - z1 * sin_p + z2: float = y * sin_p + z1 * cos_p + scale: float = 1 + (z2 / width) return int(center_x + x1 * scale), int(center_y + y1 * scale) for angle in range(0, 360, 2): - rad = math.radians(angle) - x = radius * math.cos(rad) - y = radius * math.sin(rad) + rad: float = math.radians(angle) + x: float = radius * math.cos(rad) + y: float = radius * math.sin(rad) x1, y1 = rotate_point(x, y, 0, pitch_rad, yaw_rad) canvas.set(x1, y1) for eye in [(-radius//2, -radius//3, 2), (radius//2, -radius//3, 2)]: @@ -129,14 +110,14 @@ class HeadOrientation: for dx in [-1, 0, 1]: for dy in [-1, 0, 1]: canvas.set(nx + dx, ny + dy) - smile_depth = radius // 8 - mouth_local_y = radius // 4 - mouth_length = radius + smile_depth: int = radius // 8 + mouth_local_y: int = radius // 4 + mouth_length: int = radius for x_offset in range(-mouth_length // 2, mouth_length // 2 + 1): - norm = abs(x_offset) / (mouth_length / 2) - y_offset = int((1 - norm ** 2) * smile_depth) - local_x = x_offset - local_y = mouth_local_y + y_offset + norm: float = abs(x_offset) / (mouth_length / 2) + y_offset: int = int((1 - norm ** 2) * smile_depth) + local_x: int = x_offset + local_y: int = mouth_local_y + y_offset mx, my = rotate_point(local_x, local_y, 0, pitch_rad, yaw_rad) canvas.set(mx, my) return canvas.frame() diff --git a/head-tracking/plot.py b/head-tracking/plot.py index 38ccea1..b1ad79b 100644 --- a/head-tracking/plot.py +++ b/head-tracking/plot.py @@ -1,61 +1,41 @@ -import struct -import bluetooth -import threading -import time -from datetime import datetime -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.animation import FuncAnimation -import os import asciichartpy as acp +import logging +import matplotlib.pyplot as plt +import numpy as np +import os +import struct +import time +from bluetooth import BluetoothSocket +from colors import * +from connection_manager import ConnectionManager +from datetime import datetime as DateTime +from drawille import Canvas +from head_orientation import HeadOrientation +from logging import Logger, StreamHandler +from matplotlib.animation import FuncAnimation +from matplotlib.legend import Legend +from matplotlib.pyplot import Axes, Figure +from numpy.typing import NDArray from rich.live import Live from rich.layout import Layout from rich.panel import Panel from rich.console import Console -import drawille -from head_orientation import HeadOrientation -import logging -from connection_manager import ConnectionManager +from threading import Lock, Thread +from typing import Any, Dict, List, Optional, TextIO, Tuple, Union -class Colors: - RESET = "\033[0m" - BOLD = "\033[1m" - RED = "\033[91m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - BLUE = "\033[94m" - MAGENTA = "\033[95m" - CYAN = "\033[96m" - WHITE = "\033[97m" - BG_BLACK = "\033[40m" - -class ColorFormatter(logging.Formatter): - FORMATS = { - logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET, - logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET, - logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET, - logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S") - return formatter.format(record) - -handler = logging.StreamHandler() +handler: StreamHandler = StreamHandler() handler.setFormatter(ColorFormatter()) -logger = logging.getLogger("airpods-head-tracking") +logger: Logger = logging.getLogger("airpods-head-tracking") logger.setLevel(logging.INFO) logger.addHandler(handler) logger.propagate = True -INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" -NOTIF_CMD = "04 00 04 00 0F 00 FF FF FE FF" -START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" -STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" +INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00" +NOTIF_CMD: str = "04 00 04 00 0F 00 FF FF FE FF" +START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00" +STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00" -KEY_FIELDS = { +KEY_FIELDS: Dict[str, Tuple[int, int]] = { "orientation 1": (43, 2), "orientation 2": (45, 2), "orientation 3": (47, 2), @@ -68,28 +48,28 @@ KEY_FIELDS = { } class AirPodsTracker: - def __init__(self): - self.sock = None - self.recording = False - self.log_file = None - self.listener_thread = None - self.bt_addr = "28:2D:7F:C2:05:5B" - self.psm = 0x1001 - self.raw_packets = [] - self.parsed_packets = [] - self.live_data = [] - self.live_plotting = False - self.animation = None - self.fig = None - self.axes = None - self.lines = {} - self.selected_fields = [] - self.data_lock = threading.Lock() - self.orientation_offset = 5500 - self.use_terminal = True # '--terminal' in sys.argv - self.orientation_visualizer = HeadOrientation(use_terminal=self.use_terminal) + def __init__(self) -> None: + self.sock: BluetoothSocket = None + self.recording: bool = False + self.log_file: Optional[TextIO] = None + self.listener_thread: Optional[Thread] = None + self.bt_addr: str = "28:2D:7F:C2:05:5B" + self.psm: int = 0x1001 + self.raw_packets: List[bytes] = [] + self.parsed_packets: List[bytes] = [] + self.live_data: List[bytes] = [] + self.live_plotting: bool = False + self.animation: FuncAnimation = None + self.fig: Optional[Figure] = None + self.axes: Optional[Axes] = None + self.lines: Dict[str, Any] = {} + self.selected_fields: List[str] = [] + self.data_lock: Lock = Lock() + self.orientation_offset: int = 5500 + self.use_terminal: bool = True # '--terminal' in sys.argv + self.orientation_visualizer: HeadOrientation = HeadOrientation(use_terminal=self.use_terminal) - self.conn = None + self.conn: Optional[ConnectionManager] = None def connect(self): try: @@ -102,35 +82,35 @@ class AirPodsTracker: self.sock.send(bytes.fromhex(NOTIF_CMD)) logger.info("Sent initialization command.") - self.listener_thread = threading.Thread(target=self.listen, daemon=True) + self.listener_thread = Thread(target=self.listen, daemon=True) self.listener_thread.start() return True except Exception as e: logger.error("Connection error: %s", e) return False - def start_tracking(self, duration=None): + def start_tracking(self, duration: Optional[float] = None) -> None: if not self.recording: self.conn.send_start() - filename = "head_tracking_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".log" + filename: str = f"head_tracking_{DateTime.now().strftime('%Y%m%d_%H%M%S')}.log" self.log_file = open(filename, "w") self.recording = True logger.info("Recording started. Saving data to %s", filename) if duration is not None and duration > 0: - def auto_stop(): + def auto_stop() -> None: time.sleep(duration) if self.recording: self.stop_tracking() logger.info("Recording automatically stopped after %s seconds.", duration) - timer_thread = threading.Thread(target=auto_stop, daemon=True) + timer_thread = Thread(target=auto_stop, daemon=True) timer_thread.start() logger.info("Will automatically stop recording after %s seconds.", duration) else: logger.info("Already recording.") - def stop_tracking(self): + def stop_tracking(self) -> None: if self.recording: self.conn.send_stop() self.recording = False @@ -141,39 +121,41 @@ class AirPodsTracker: else: logger.info("Not currently recording.") - def format_hex(self, data): - hex_str = data.hex() + def format_hex(self, data: bytes) -> str: + hex_str: str = data.hex() return ' '.join(hex_str[i:i + 2] for i in range(0, len(hex_str), 2)) - def parse_raw_packet(self, hex_string): + def parse_raw_packet(self, hex_string: str) -> bytes: return bytes.fromhex(hex_string.replace(" ", "")) - def interpret_bytes(self, raw_bytes, start, length, data_type="signed_short"): + def interpret_bytes(self, raw_bytes: bytes, start: int, length: int, data_type: str = "signed_short") -> Optional[Union[int, float]]: if start + length > len(raw_bytes): return None - if data_type == "signed_short": - return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True) - elif data_type == "unsigned_short": - return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False) - elif data_type == "signed_short_be": - return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True) - elif data_type == "float_le": - if start + 4 <= len(raw_bytes): - return struct.unpack('f', raw_bytes[start:start + 4])[0] - return None + match data_type: + case "signed_short": + return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True) + case "unsigned_short": + return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False) + case "signed_short_be": + return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True) + case "float_le": + if start + 4 <= len(raw_bytes): + return struct.unpack('f', raw_bytes[start:start + 4])[0] + case _: + return None - def normalize_orientation(self, value, field_name): + def normalize_orientation(self, value: Optional[Union[int, float]], field_name: str) -> Optional[Union[int, float]]: if 'orientation' in field_name.lower(): return value + self.orientation_offset return value - def parse_packet_all_fields(self, raw_bytes): - packet = {} + def parse_packet_all_fields(self, raw_bytes: bytes) -> Dict[str, Union[int, float]]: + packet: Dict[str, Union[int, float]] = {} packet["seq_num"] = int.from_bytes(raw_bytes[12:14], byteorder='little') @@ -186,14 +168,14 @@ class AirPodsTracker: packet[field_name] = self.normalize_orientation(raw_value, field_name) for i in range(30, min(90, len(raw_bytes) - 1), 2): - field_name = f"byte_{i:02d}" - raw_value = self.interpret_bytes(raw_bytes, i, 2, "signed_short") + field_name: str = f"byte_{i:02d}" + raw_value: Optional[Union[int, float]] = self.interpret_bytes(raw_bytes, i, 2, "signed_short") if raw_value is not None: packet[field_name] = self.normalize_orientation(raw_value, field_name) return packet - def apply_dark_theme(self, fig, axes): + def apply_dark_theme(self, fig: Figure, axes: List[Axes]) -> None: fig.patch.set_facecolor('#1e1e1e') for ax in axes: ax.set_facecolor('#2d2d2d') @@ -210,21 +192,21 @@ class AirPodsTracker: for spine in ax.spines.values(): spine.set_color('#555555') - legend = ax.get_legend() + legend: Optional[Legend] = ax.get_legend() if (legend): legend.get_frame().set_facecolor('#2d2d2d') legend.get_frame().set_alpha(0.7) for text in legend.get_texts(): text.set_color('white') - def listen(self): + def listen(self) -> None: while True: try: - data = self.sock.recv(1024) - formatted = self.format_hex(data) - timestamp = datetime.now().isoformat() + data: bytes = self.sock.recv(1024) + formatted: str = self.format_hex(data) + timestamp: str = DateTime.now().isoformat() - is_valid = self.is_valid_tracking_packet(formatted) + is_valid: bool = self.is_valid_tracking_packet(formatted) if not self.live_plotting: if is_valid: @@ -238,8 +220,8 @@ class AirPodsTracker: self.log_file.flush() try: - raw_bytes = self.parse_raw_packet(formatted) - packet = self.parse_packet_all_fields(raw_bytes) + raw_bytes: bytes = self.parse_raw_packet(formatted) + packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes) with self.data_lock: self.live_data.append(packet) @@ -253,7 +235,7 @@ class AirPodsTracker: logger.error("Error receiving data: %s", e) break - def load_log_file(self, filepath): + def load_log_file(self, filepath: str) -> bool: self.raw_packets = [] self.parsed_packets = [] try: @@ -262,11 +244,11 @@ class AirPodsTracker: line = line.strip() if line: try: - raw_bytes = self.parse_raw_packet(line) + raw_bytes: bytes = self.parse_raw_packet(line) self.raw_packets.append(raw_bytes) - packet = self.parse_packet_all_fields(raw_bytes) + packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes) - min_seq_num = min( + min_seq_num: int = min( [parsed_packet["seq_num"] for parsed_packet in self.parsed_packets], default=0 ) @@ -282,26 +264,26 @@ class AirPodsTracker: logger.error(f"Error loading log file: {e}") return False - def extract_field_values(self, field_name, data_source='loaded'): + def extract_field_values(self, field_name: str, data_source: str = 'loaded') -> List[Union[int, float]]: if data_source == 'loaded': - data = self.parsed_packets + data: List[Dict[str, Union[int, float]]] = self.parsed_packets else: with self.data_lock: - data = self.live_data.copy() + data: List[Dict[str, Union[int, float]]] = self.live_data.copy() - values = [packet.get(field_name, 0) for packet in data if field_name in packet] + values: List[Union[int, float]] = [packet.get(field_name, 0) for packet in data if field_name in packet] if data_source == 'live' and len(values) > 5: try: - values = np.array(values, dtype=float) + values: NDArray[Any] = np.array(values, dtype=float) values = np.convolve(values, np.ones(5) / 5, mode='valid') except Exception as e: logger.warning(f"Smoothing error (non-critical): {e}") return values - def is_valid_tracking_packet(self, hex_string): - standard_header = "04 00 04 00 17 00 00 00 10 00" + def is_valid_tracking_packet(self, hex_string: str) -> bool: + standard_header: str = "04 00 04 00 17 00 00 00 10 00" if not hex_string.startswith(standard_header): if self.live_plotting: @@ -316,13 +298,13 @@ class AirPodsTracker: return True - def plot_fields(self, field_names=None): + def plot_fields(self, field_names: Optional[List[str]] = None) -> None: if not self.parsed_packets: logger.error("No data to plot. Load a log file first.") return if field_names is None: - field_names = list(KEY_FIELDS.keys()) + field_names: List[str] = list(KEY_FIELDS.keys()) if not self.orientation_visualizer.calibration_complete: if len(self.parsed_packets) < self.orientation_visualizer.calibration_sample_count: @@ -339,16 +321,16 @@ class AirPodsTracker: self._plot_fields_terminal(field_names) else: - acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()] - orientation_fields = [f for f in field_names if 'orientation' in f.lower()] - other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields] + acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()] + orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()] + other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields] fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True) self.apply_dark_theme(fig, axes) - acceleration_colors = ['#FFFF00', '#00FFFF'] - orientation_colors = ['#FF00FF', '#00FF00', '#FFA500'] - other_colors = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f'] + acceleration_colors: List[str] = ['#FFFF00', '#00FFFF'] + orientation_colors: List[str] = ['#FF00FF', '#00FF00', '#FFA500'] + other_colors: List[str] = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f'] if acceleration_fields: for i, field in enumerate(acceleration_fields): @@ -375,17 +357,17 @@ class AirPodsTracker: plt.tight_layout() plt.show() - def _plot_fields_terminal(self, field_names): + def _plot_fields_terminal(self, field_names: List[str]) -> None: """Internal method for terminal-based plotting""" - terminal_width = os.get_terminal_size().columns - plot_width = min(terminal_width - 10, 120) - plot_height = 15 + terminal_width: int = os.get_terminal_size().columns + plot_width: int = min(terminal_width - 10, 120) + plot_height: int = 15 - acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()] - orientation_fields = [f for f in field_names if 'orientation' in f.lower()] - other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields] + acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()] + orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()] + other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields] - def plot_group(fields, title): + def plot_group(fields: List[str], title: str) -> None: if not fields: return @@ -393,40 +375,39 @@ class AirPodsTracker: print("=" * len(title)) for field in fields: - values = self.extract_field_values(field) + values: List[float] = self.extract_field_values(field) if len(values) > plot_width: values = values[-plot_width:] if title == "Acceleration Data": - chart = acp.plot(values, {'height': plot_height}) + chart: str = acp.plot(values, {'height': plot_height}) print(chart) else: - chart = acp.plot(values, {'height': plot_height}) + chart: str = acp.plot(values, {'height': plot_height}) print(chart) - print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " + - f"Mean: {np.mean(values):.2f}") + print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " + f"Mean: {np.mean(values):.2f}") print() plot_group(acceleration_fields, "Acceleration Data") plot_group(orientation_fields, "Orientation Data") plot_group(other_fields, "Other Fields") - def create_braille_plot(self, values, width=80, height=20, y_label=True, fixed_y_min=None, fixed_y_max=None): - canvas = drawille.Canvas() + def create_braille_plot(self, values: List[float], width: int = 80, height: int = 20, y_label: bool = True, fixed_y_min: Optional[float] = None, fixed_y_max: Optional[float] = None) -> str: + canvas: Canvas = Canvas() if fixed_y_min is None or fixed_y_max is None: local_min, local_max = min(values), max(values) else: local_min, local_max = fixed_y_min, fixed_y_max - y_range = local_max - local_min or 1 - x_step = max(1, len(values) // width) + y_range: float = local_max - local_min or 1 + x_step: int = max(1, len(values) // width) for i, v in enumerate(values[::x_step]): - y = int(((v - local_min) / y_range) * (height * 2 - 1)) + y: int = int(((v - local_min) / y_range) * (height * 2 - 1)) canvas.set(i, y) - frame = canvas.frame() + frame: str = canvas.frame() if y_label: - lines = frame.split('\n') - labeled_lines = [] + lines: List[str] = frame.split('\n') + labeled_lines: List[str] = [] for idx, line in enumerate(lines): if idx == 0: labeled_lines.append(f"{local_max:6.0f} {line}") @@ -437,17 +418,17 @@ class AirPodsTracker: frame = "\n".join(labeled_lines) return frame - def _start_live_plotting_terminal(self, record_data=False, duration=None): + def _start_live_plotting_terminal(self, record_data: bool = False, duration: Optional[float] = None) -> None: import sys, select, tty, termios old_settings = termios.tcgetattr(sys.stdin) tty.setcbreak(sys.stdin.fileno()) - console = Console() - term_width = console.width - plot_width = round(min(term_width / 2 - 15, 120)) - ori_height = 10 + console: Console = Console() + term_width: int = console.width + plot_width: int = round(min(term_width / 2 - 15, 120)) + ori_height: int = 10 - def make_compact_layout(): - layout = Layout() + def make_compact_layout() -> Layout: + layout: Layout = Layout() layout.split_column( Layout(name="header", size=3), Layout(name="main", ratio=1), @@ -466,7 +447,7 @@ class AirPodsTracker: ) return layout - layout = make_compact_layout() + layout: Layout = make_compact_layout() try: import time @@ -479,76 +460,76 @@ class AirPodsTracker: logger.info("Paused" if self.paused else "Resumed") if self.paused: time.sleep(0.1) - rec_str = " [red][REC][/red]" if record_data else "" - left = "AirPods Head Tracking - v1.0.0" - right = "Ctrl+C - Close | p - Pause" + rec_str - status = "[bold red]Paused[/bold red]" - header = list(" " * term_width) + rec_str: str = " [red][REC][/red]" if record_data else "" + left: str = "AirPods Head Tracking - v1.0.0" + right: str = "Ctrl+C - Close | p - Pause" + rec_str + status: str = "[bold red]Paused[/bold red]" + header: List[str] = list(" " * term_width) header[0:len(left)] = list(left) header[term_width - len(right):] = list(right) - start = (term_width - len(status)) // 2 + start: int = (term_width - len(status)) // 2 header[start:start+len(status)] = list(status) - header_text = "".join(header) + header_text: str = "".join(header) layout["header"].update(Panel(header_text, style="bold white on black")) continue with self.data_lock: if len(self.live_data) < 1: continue - latest = self.live_data[-1] - data = self.live_data[-plot_width:] + latest: Dict[str, float] = self.live_data[-1] + data: List[Dict[str, float]] = self.live_data[-plot_width:] if not self.orientation_visualizer.calibration_complete: - sample = [ + sample: List[float] = [ latest.get('orientation 1', 0), latest.get('orientation 2', 0), latest.get('orientation 3', 0) ] self.orientation_visualizer.add_calibration_sample(sample) time.sleep(0.05) - rec_str = " [red][REC][/red]" if record_data else "" + rec_str: str = " [red][REC][/red]" if record_data else "" - left = "AirPods Head Tracking - v1.0.0" - status = "[bold yellow]Calibrating...[/bold yellow]" - right = "Ctrl+C - Close | p - Pause" - remaining = max(term_width - len(left) - len(right), 0) - header_text = f"{left}{status.center(remaining)}{right}{rec_str}" + left: str = "AirPods Head Tracking - v1.0.0" + status: str = "[bold yellow]Calibrating...[/bold yellow]" + right: str = "Ctrl+C - Close | p - Pause" + remaining: int = max(term_width - len(left) - len(right), 0) + header_text: str = f"{left}{status.center(remaining)}{right}{rec_str}" layout["header"].update(Panel(header_text, style="bold white on black")) live.refresh() continue - o1 = latest.get('orientation 1', 0) - o2 = latest.get('orientation 2', 0) - o3 = latest.get('orientation 3', 0) - orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3) - pitch = orientation['pitch'] - yaw = orientation['yaw'] + o1: float = latest.get('orientation 1', 0) + o2: float = latest.get('orientation 2', 0) + o3: float = latest.get('orientation 3', 0) + orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3) + pitch: float = orientation['pitch'] + yaw: float = orientation['yaw'] - h_accel = [p.get('Horizontal Acceleration', 0) for p in data] - v_accel = [p.get('Vertical Acceleration', 0) for p in data] + h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data] + v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data] if len(h_accel) > plot_width: h_accel = h_accel[-plot_width:] if len(v_accel) > plot_width: v_accel = v_accel[-plot_width:] - global_min = min(min(v_accel), min(h_accel)) - global_max = max(max(v_accel), max(h_accel)) - config_acc = {'height': 20, 'min': global_min, 'max': global_max} - vert_plot = acp.plot(v_accel, config_acc) - horiz_plot = acp.plot(h_accel, config_acc) + global_min: float = min(min(v_accel), min(h_accel)) + global_max: float = max(max(v_accel), max(h_accel)) + config_acc: Dict[str, float] = {'height': 20, 'min': global_min, 'max': global_max} + vert_plot: str = acp.plot(v_accel, config_acc) + horiz_plot: str = acp.plot(h_accel, config_acc) - rec_str = " [red][REC][/red]" if record_data else "" - left = "AirPods Head Tracking - v1.0.0" - right = "Ctrl+C - Close | p - Pause" + rec_str - status = "[bold green]Live[/bold green]" - header = list(" " * term_width) + rec_str: str = " [red][REC][/red]" if record_data else "" + left: str = "AirPods Head Tracking - v1.0.0" + right: str = "Ctrl+C - Close | p - Pause" + rec_str + status: str = "[bold green]Live[/bold green]" + header: List[str] = list(" " * term_width) header[0:len(left)] = list(left) header[term_width - len(right):] = list(right) - start = (term_width - len(status)) // 2 + start: int = (term_width - len(status)) // 2 header[start:start+len(status)] = list(status) - header_text = "".join(header) + header_text: str = "".join(header) layout["header"].update(Panel(header_text, style="bold white on black")) - face_art = self.orientation_visualizer.create_face_art(pitch, yaw) + face_art: str = self.orientation_visualizer.create_face_art(pitch, yaw) layout["accelerations"]["vertical"].update(Panel( "[bold yellow]Vertical Acceleration[/]\n" + vert_plot + "\n" + @@ -563,15 +544,15 @@ class AirPodsTracker: )) layout["orientations"]["face"].update(Panel(face_art, title="[green]Orientation - Visualization[/]", style="green")) - o2_values = [p.get('orientation 2', 0) for p in data[-plot_width:]] - o3_values = [p.get('orientation 3', 0) for p in data[-plot_width:]] - o2_values = o2_values[:plot_width] - o3_values = o3_values[:plot_width] - common_min = min(min(o2_values), min(o3_values)) - common_max = max(max(o2_values), max(o3_values)) - config_ori = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"} - chart_o2 = acp.plot(o2_values, config_ori) - chart_o3 = acp.plot(o3_values, config_ori) + o2_values: List[float] = [p.get('orientation 2', 0) for p in data[-plot_width:]] + o3_values: List[float] = [p.get('orientation 3', 0) for p in data[-plot_width:]] + o2_values: List[float] = o2_values[:plot_width] + o3_values: List[float] = o3_values[:plot_width] + common_min: float = min(min(o2_values), min(o3_values)) + common_max: float = max(max(o2_values), max(o3_values)) + config_ori: Dict[str, float] = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"} + chart_o2: str = acp.plot(o2_values, config_ori) + chart_o3: str = acp.plot(o3_values, config_ori) layout["orientations"]["raw"].update(Panel( "[bold yellow]Orientation 1:[/]\n" + chart_o2 + "\n" + f"Cur: {o2_values[-1]:6.1f} | Min: {min(o2_values):6.1f} | Max: {max(o2_values):6.1f}\n\n" + @@ -591,10 +572,10 @@ class AirPodsTracker: finally: termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) - def _start_live_plotting(self, record_data=False, duration=None): - terminal_width = os.get_terminal_size().columns - plot_width = min(terminal_width - 10, 80) - plot_height = 10 + def _start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None: + terminal_width: int = os.get_terminal_size().columns + plot_width: int = min(terminal_width - 10, 80) + plot_height: int = 10 try: while True: @@ -605,13 +586,13 @@ class AirPodsTracker: time.sleep(0.1) continue - data = self.live_data[-plot_width:] + data: List[Dict[str, float]] = self.live_data[-plot_width:] - acceleration_fields = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()] - orientation_fields = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()] - other_fields = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields] + acceleration_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()] + orientation_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()] + other_fields: List[str] = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields] - def plot_group(fields, title): + def plot_group(fields: List[str], title: str) -> None: if not fields: return @@ -619,9 +600,9 @@ class AirPodsTracker: print("=" * len(title)) for field in fields: - values = [packet.get(field, 0) for packet in data if field in packet] + values: List[float] = [packet.get(field, 0) for packet in data if field in packet] if len(values) > 0: - chart = acp.plot(values, {'height': plot_height}) + chart: str = acp.plot(values, {'height': plot_height}) print(chart) print(f"Current: {values[-1]:.2f}, " + f"Min: {min(values):.2f}, Max: {max(values):.2f}") @@ -641,7 +622,7 @@ class AirPodsTracker: self.stop_tracking() self.live_plotting = False - def start_live_plotting(self, record_data=False, duration=None): + def start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None: if self.sock is None: if not self.connect(): logger.error("Could not connect to AirPods. Live plotting aborted.") @@ -660,12 +641,12 @@ class AirPodsTracker: self._start_live_plotting_terminal(record_data, duration) else: from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec - fig = plt.figure(figsize=(14, 6)) - gs = GridSpec(1, 2, width_ratios=[1, 1]) - ax_accel = fig.add_subplot(gs[0]) - subgs = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1]) - ax_head_top = fig.add_subplot(subgs[0], projection='3d') - ax_ori = fig.add_subplot(subgs[1]) + fig: Figure = plt.figure(figsize=(14, 6)) + gs: GridSpec = GridSpec(1, 2, width_ratios=[1, 1]) + ax_accel: Axes = fig.add_subplot(gs[0]) + subgs: GridSpecFromSubplotSpec = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1]) + ax_head_top: Axes = fig.add_subplot(subgs[0], projection='3d') + ax_ori: Axes = fig.add_subplot(subgs[1]) ax_accel.set_title("Acceleration Data") ax_accel.set_xlabel("Packet Index") @@ -676,16 +657,16 @@ class AirPodsTracker: self.apply_dark_theme(fig, [ax_accel, ax_head_top, ax_ori]) plt.ion() - def update_plot(_): + def update_plot(_: int) -> None: with self.data_lock: - data = self.live_data.copy() + data: List[Dict[str, float]] = self.live_data.copy() if len(data) == 0: return - latest = data[-1] + latest: Dict[str, float] = data[-1] if not self.orientation_visualizer.calibration_complete: - sample = [ + sample: List[float] = [ latest.get('orientation 1', 0), latest.get('orientation 2', 0), latest.get('orientation 3', 0) @@ -696,9 +677,9 @@ class AirPodsTracker: fig.canvas.draw_idle() return - h_accel = [p.get('Horizontal Acceleration', 0) for p in data] - v_accel = [p.get('Vertical Acceleration', 0) for p in data] - x_vals = list(range(len(h_accel))) + h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data] + v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data] + x_vals: List[int] = list(range(len(h_accel))) ax_accel.cla() ax_accel.plot(x_vals, v_accel, label='Vertical Acceleration', color='#FFFF00', linewidth=2) ax_accel.plot(x_vals, h_accel, label='Horizontal Acceleration', color='#00FFFF', linewidth=2) @@ -711,13 +692,13 @@ class AirPodsTracker: ax_accel.xaxis.label.set_color('white') ax_accel.yaxis.label.set_color('white') - latest = data[-1] - o1 = latest.get('orientation 1', 0) - o2 = latest.get('orientation 2', 0) - o3 = latest.get('orientation 3', 0) - orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3) - pitch = orientation['pitch'] - yaw = orientation['yaw'] + latest: Dict[str, float] = data[-1] + o1: float = latest.get('orientation 1', 0) + o2: float = latest.get('orientation 2', 0) + o3: float = latest.get('orientation 3', 0) + orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3) + pitch: float = orientation['pitch'] + yaw: float = orientation['yaw'] ax_head_top.cla() ax_head_top.set_title("Head Orientation") @@ -727,25 +708,25 @@ class AirPodsTracker: ax_head_top.set_facecolor('#2d2d2d') pitch_rad = np.radians(pitch) yaw_rad = np.radians(yaw) - Rz = np.array([ + Rz: NDArray[Any] = np.array([ [np.cos(yaw_rad), np.sin(yaw_rad), 0], [-np.sin(yaw_rad), np.cos(yaw_rad), 0], [0, 0, 1] ]) - Ry = np.array([ + Ry: NDArray[Any] = np.array([ [np.cos(pitch_rad), 0, np.sin(pitch_rad)], [0, 1, 0], [-np.sin(pitch_rad), 0, np.cos(pitch_rad)] ]) - R = Rz @ Ry - dir_vec = R @ np.array([1, 0, 0]) + R: NDArray[Any] = Rz @ Ry + dir_vec: NDArray[Any] = R @ np.array([1, 0, 0]) ax_head_top.quiver(0, 0, 0, dir_vec[0], dir_vec[1], dir_vec[2], color='r', length=0.8, linewidth=3) ax_ori.cla() - o2_values = [p.get('orientation 2', 0) for p in data] - o3_values = [p.get('orientation 3', 0) for p in data] - x_range = list(range(len(o2_values))) + o2_values: List[float] = [p.get('orientation 2', 0) for p in data] + o3_values: List[float] = [p.get('orientation 3', 0) for p in data] + x_range: List[int] = list(range(len(o2_values))) ax_ori.plot(x_range, o2_values, label='Orientation 1', color='red', linewidth=2) ax_ori.plot(x_range, o3_values, label='Orientation 2', color='green', linewidth=2) ax_ori.set_facecolor('#2d2d2d') @@ -775,9 +756,9 @@ class AirPodsTracker: self.animation = None plt.ioff() - def interactive_mode(self): + def interactive_mode(self) -> None: from prompt_toolkit import PromptSession - session = PromptSession("> ") + session: PromptSession = PromptSession("> ") logger.info("\nAirPods Head Tracking Analyzer") print("------------------------------") logger.info("Commands:") @@ -793,59 +774,61 @@ class AirPodsTracker: while True: try: - cmd_input = session.prompt("> ") - cmd_parts = cmd_input.strip().split() + cmd_input: str = session.prompt("> ") + cmd_parts: List[str] = cmd_input.strip().split() if not cmd_parts: continue cmd = cmd_parts[0].lower() - if cmd == "connect": - self.connect() - elif cmd == "start": - duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None - self.start_tracking(duration) - elif cmd == "stop": - self.stop_tracking() - elif cmd == "load" and len(cmd_parts) > 1: - self.load_log_file(cmd_parts[1]) - elif cmd == "plot": - self.plot_fields() - elif cmd == "live": - duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None - logger.info("Starting live plotting mode (without recording)%s.", - f" for {duration} seconds" if duration else "") - self.start_live_plotting(record_data=False, duration=duration) - elif cmd == "liver": - duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None - logger.info("Starting live plotting mode WITH recording%s.", - f" for {duration} seconds" if duration else "") - self.start_live_plotting(record_data=True, duration=duration) - elif cmd == "gestures": - from gestures import GestureDetector - if self.conn is not None: - detector = GestureDetector(conn=self.conn) - else: - detector = GestureDetector() - detector.start_detection() - elif cmd == "quit": - logger.info("Exiting.") - if self.conn != None: - self.conn.disconnect() - break - elif cmd == "help": - logger.info("\nAirPods Head Tracking Analyzer") - logger.info("------------------------------") - logger.info("Commands:") - logger.info(" connect - connect to your AirPods") - logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration") - logger.info(" stop - stop recording") - logger.info(" load - load and parse a log file") - logger.info(" plot - plot all sensor data fields") - logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds") - logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds") - logger.info(" gestures - start gesture detection") - logger.info(" quit - exit the program") - else: - logger.info("Unknown command. Type 'help' to see available commands.") + match cmd: + case "connect": + self.connect() + case "start": + duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None + self.start_tracking(duration) + case "stop": + self.stop_tracking() + case "load": + if len(cmd_parts) > 1: + self.load_log_file(cmd_parts[1]) + case "plot": + self.plot_fields() + case "live": + duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None + logger.info("Starting live plotting mode (without recording)%s.", + f" for {duration} seconds" if duration else "") + self.start_live_plotting(record_data=False, duration=duration) + case "liver": + duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None + logger.info("Starting live plotting mode WITH recording%s.", + f" for {duration} seconds" if duration else "") + self.start_live_plotting(record_data=True, duration=duration) + case "gestures": + from gestures import GestureDetector + if self.conn is not None: + detector: GestureDetector = GestureDetector(conn=self.conn) + else: + detector: GestureDetector = GestureDetector() + detector.start_detection() + case "quit": + logger.info("Exiting.") + if self.conn != None: + self.conn.disconnect() + break + case "help": + logger.info("\nAirPods Head Tracking Analyzer") + logger.info("------------------------------") + logger.info("Commands:") + logger.info(" connect - connect to your AirPods") + logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration") + logger.info(" stop - stop recording") + logger.info(" load - load and parse a log file") + logger.info(" plot - plot all sensor data fields") + logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds") + logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds") + logger.info(" gestures - start gesture detection") + logger.info(" quit - exit the program") + case _: + logger.info("Unknown command. Type 'help' to see available commands.") except KeyboardInterrupt: logger.info("Use 'quit' to exit.") except EOFError: @@ -856,5 +839,5 @@ class AirPodsTracker: if __name__ == "__main__": import sys - tracker = AirPodsTracker() - tracker.interactive_mode() \ No newline at end of file + tracker: AirPodsTracker = AirPodsTracker() + tracker.interactive_mode() diff --git a/linux/hearing-aid-adjustments.py b/linux/hearing-aid-adjustments.py index 2312b8e..dfa7c1e 100644 --- a/linux/hearing-aid-adjustments.py +++ b/linux/hearing-aid-adjustments.py @@ -1,10 +1,13 @@ -import sys -import socket -import struct -import threading -from queue import Queue import logging import signal +import socket +import struct +import sys +import threading +from socket import socket as Socket, TimeoutError +from queue import Queue +from threading import Thread +from typing import Any, Dict, List, Optional # Configure logging logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') @@ -12,47 +15,47 @@ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - % from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSlider, QCheckBox, QPushButton, QLineEdit, QFormLayout, QGridLayout from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QObject -OPCODE_READ_REQUEST = 0x0A -OPCODE_WRITE_REQUEST = 0x12 -OPCODE_HANDLE_VALUE_NTF = 0x1B +OPCODE_READ_REQUEST: int = 0x0A +OPCODE_WRITE_REQUEST: int = 0x12 +OPCODE_HANDLE_VALUE_NTF: int = 0x1B -ATT_HANDLES = { +ATT_HANDLES: Dict[str, int] = { 'TRANSPARENCY': 0x18, 'LOUD_SOUND_REDUCTION': 0x1B, 'HEARING_AID': 0x2A, } -ATT_CCCD_HANDLES = { +ATT_CCCD_HANDLES: Dict[str, int] = { 'TRANSPARENCY': ATT_HANDLES['TRANSPARENCY'] + 1, 'LOUD_SOUND_REDUCTION': ATT_HANDLES['LOUD_SOUND_REDUCTION'] + 1, 'HEARING_AID': ATT_HANDLES['HEARING_AID'] + 1, } -PSM_ATT = 31 +PSM_ATT: int = 31 class ATTManager: - def __init__(self, mac_address): - self.mac_address = mac_address - self.sock = None - self.responses = Queue() - self.listeners = {} - self.notification_thread = None - self.running = False + def __init__(self, mac_address: str) -> None: + self.mac_address: str = mac_address + self.sock: Optional[Socket] = None + self.responses: Queue = Queue() + self.listeners: Dict[int, List[Any]] = {} + self.notification_thread: Optional[Thread] = None + self.running: bool = False # Avoid logging full MAC address to prevent sensitive data exposure - mac_tail = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]' + mac_tail: str = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]' logging.info(f"ATTManager initialized") - def connect(self): + def connect(self) -> None: logging.info("Attempting to connect to ATT socket") - self.sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + self.sock = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) self.sock.connect((self.mac_address, PSM_ATT)) self.sock.settimeout(0.1) self.running = True - self.notification_thread = threading.Thread(target=self._listen_notifications) + self.notification_thread = Thread(target=self._listen_notifications) self.notification_thread.start() logging.info("Connected to ATT socket") - def disconnect(self): + def disconnect(self) -> None: logging.info("Disconnecting from ATT socket") self.running = False if self.sock: @@ -63,37 +66,37 @@ class ATTManager: self.notification_thread.join(timeout=1.0) logging.info("Disconnected from ATT socket") - def register_listener(self, handle, listener): + def register_listener(self, handle: int, listener: Any) -> None: if handle not in self.listeners: self.listeners[handle] = [] self.listeners[handle].append(listener) logging.debug(f"Registered listener for handle {handle}") - def unregister_listener(self, handle, listener): + def unregister_listener(self, handle: int, listener: Any) -> None: if handle in self.listeners: self.listeners[handle].remove(listener) logging.debug(f"Unregistered listener for handle {handle}") - def enable_notifications(self, handle): + def enable_notifications(self, handle: Any) -> None: self.write_cccd(handle, b'\x01\x00') logging.info(f"Enabled notifications for handle {handle.name}") - def read(self, handle): - handle_value = ATT_HANDLES[handle.name] - lsb = handle_value & 0xFF - msb = (handle_value >> 8) & 0xFF - pdu = bytes([OPCODE_READ_REQUEST, lsb, msb]) + def read(self, handle: Any) -> bytes: + handle_value: int = ATT_HANDLES[handle.name] + lsb: int = handle_value & 0xFF + msb: int = (handle_value >> 8) & 0xFF + pdu: bytes = bytes([OPCODE_READ_REQUEST, lsb, msb]) logging.debug(f"Sending read request for handle {handle.name}: {pdu.hex()}") self._write_raw(pdu) - response = self._read_response() + response: bytes = self._read_response() logging.debug(f"Read response for handle {handle.name}: {response.hex()}") return response - def write(self, handle, value): - handle_value = ATT_HANDLES[handle.name] - lsb = handle_value & 0xFF - msb = (handle_value >> 8) & 0xFF - pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value + def write(self, handle: Any, value: bytes) -> None: + handle_value: int = ATT_HANDLES[handle.name] + lsb: int = handle_value & 0xFF + msb: int = (handle_value >> 8) & 0xFF + pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value logging.debug(f"Sending write request for handle {handle.name}: {pdu.hex()}") self._write_raw(pdu) try: @@ -102,11 +105,11 @@ class ATTManager: except: logging.warning(f"No write response received for handle {handle.name}") - def write_cccd(self, handle, value): - handle_value = ATT_CCCD_HANDLES[handle.name] - lsb = handle_value & 0xFF - msb = (handle_value >> 8) & 0xFF - pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value + def write_cccd(self, handle: Any, value: bytes) -> None: + handle_value: int = ATT_CCCD_HANDLES[handle.name] + lsb: int = handle_value & 0xFF + msb: int = (handle_value >> 8) & 0xFF + pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value logging.debug(f"Sending CCCD write request for handle {handle.name}: {pdu.hex()}") self._write_raw(pdu) try: @@ -115,42 +118,42 @@ class ATTManager: except: logging.warning(f"No CCCD write response received for handle {handle.name}") - def _write_raw(self, pdu): + def _write_raw(self, pdu: bytes) -> None: self.sock.send(pdu) logging.debug(f"Sent PDU: {pdu.hex()}") - def _read_pdu(self): + def _read_pdu(self) -> Optional[bytes]: try: - data = self.sock.recv(512) + data: bytes = self.sock.recv(512) logging.debug(f"Received PDU: {data.hex()}") return data - except socket.timeout: + except TimeoutError: return None except: raise - def _read_response(self, timeout=2.0): + def _read_response(self, timeout: float = 2.0) -> bytes: try: - response = self.responses.get(timeout=timeout)[1:] # Skip opcode + response: bytes = self.responses.get(timeout=timeout)[1:] # Skip opcode logging.debug(f"Response received: {response.hex()}") return response except: logging.error("No response received within timeout") raise Exception("No response received") - def _listen_notifications(self): + def _listen_notifications(self) -> None: logging.info("Starting notification listener thread") while self.running: try: - pdu = self._read_pdu() + pdu: Optional[bytes] = self._read_pdu() except: break if pdu is None: continue if len(pdu) > 0 and pdu[0] == OPCODE_HANDLE_VALUE_NTF: logging.debug(f"Notification PDU received: {pdu.hex()}") - handle = pdu[1] | (pdu[2] << 8) - value = pdu[3:] + handle: int = pdu[1] | (pdu[2] << 8) + value: bytes = pdu[3:] logging.debug(f"Notification for handle {handle}: {value.hex()}") if handle in self.listeners: for listener in self.listeners[handle]: @@ -165,36 +168,36 @@ class ATTManager: logging.error(f"Reconnection failed: {e}") class HearingAidSettings: - def __init__(self, left_eq, right_eq, left_amp, right_amp, left_tone, right_tone, - left_conv, right_conv, left_anr, right_anr, net_amp, balance, own_voice): - self.left_eq = left_eq - self.right_eq = right_eq - self.left_amplification = left_amp - self.right_amplification = right_amp - self.left_tone = left_tone - self.right_tone = right_tone - self.left_conversation_boost = left_conv - self.right_conversation_boost = right_conv - self.left_ambient_noise_reduction = left_anr - self.right_ambient_noise_reduction = right_anr - self.net_amplification = net_amp - self.balance = balance - self.own_voice_amplification = own_voice + def __init__(self, left_eq: List[float], right_eq: List[float], left_amp: float, right_amp: float, left_tone: float, right_tone: float, + left_conv: bool, right_conv: bool, left_anr: float, right_anr: float, net_amp: float, balance: float, own_voice: float) -> None: + self.left_eq: List[float] = left_eq + self.right_eq: List[float] = right_eq + self.left_amplification: float = left_amp + self.right_amplification: float = right_amp + self.left_tone: float = left_tone + self.right_tone: float = right_tone + self.left_conversation_boost: bool = left_conv + self.right_conversation_boost: bool = right_conv + self.left_ambient_noise_reduction: float = left_anr + self.right_ambient_noise_reduction: float = right_anr + self.net_amplification: float = net_amp + self.balance: float = balance + self.own_voice_amplification: float = own_voice logging.debug(f"HearingAidSettings created: amp={net_amp}, balance={balance}, tone={left_tone}, anr={left_anr}, conv={left_conv}") -def parse_hearing_aid_settings(data): +def parse_hearing_aid_settings(data: bytes) -> Optional[HearingAidSettings]: logging.debug(f"Parsing hearing aid settings from data: {data.hex()}") if len(data) < 104: logging.warning("Data too short for parsing") return None - buffer = data - offset = 0 + buffer: bytes = data + offset: int = 0 offset += 4 logging.info(f"Parsing hearing aid settings, starting read at offset 4, value: {buffer[offset]:02x}") - left_eq = [] + left_eq: List[float] = [] for i in range(8): val, = struct.unpack(' None: logging.info("Sending hearing aid settings") - data = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) + data: bytes = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) if len(data) < 104: logging.error("Read data too short for sending settings") return - buffer = bytearray(data) + buffer: bytearray = bytearray(data) # Modify byte at index 2 to 0x64 buffer[2] = 0x64 @@ -272,16 +275,16 @@ def send_hearing_aid_settings(att_manager, settings): logging.info("Hearing aid settings sent") class SignalEmitter(QObject): - update_ui = pyqtSignal(HearingAidSettings) + update_ui: pyqtSignal = pyqtSignal(HearingAidSettings) class HearingAidApp(QWidget): - def __init__(self, mac_address): + def __init__(self, mac_address: str) -> None: super().__init__() - self.mac_address = mac_address - self.att_manager = ATTManager(mac_address) - self.emitter = SignalEmitter() + self.mac_address: str = mac_address + self.att_manager: ATTManager = ATTManager(mac_address) + self.emitter: SignalEmitter = SignalEmitter() self.emitter.update_ui.connect(self.on_update_ui) - self.debounce_timer = QTimer() + self.debounce_timer: QTimer = QTimer() self.debounce_timer.setSingleShot(True) self.debounce_timer.timeout.connect(self.send_settings) logging.info("HearingAidConfig initialized") @@ -289,25 +292,25 @@ class HearingAidApp(QWidget): self.init_ui() self.connect_att() - def init_ui(self): + def init_ui(self) -> None: logging.debug("Initializing UI") self.setWindowTitle("Hearing Aid Adjustments") - layout = QVBoxLayout() + layout: QVBoxLayout = QVBoxLayout() # EQ Inputs - eq_layout = QGridLayout() - self.left_eq_inputs = [] - self.right_eq_inputs = [] + eq_layout: QGridLayout = QGridLayout() + self.left_eq_inputs: List[QLineEdit] = [] + self.right_eq_inputs: List[QLineEdit] = [] - eq_labels = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"] + eq_labels: List[str] = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"] eq_layout.addWidget(QLabel("Frequency"), 0, 0) eq_layout.addWidget(QLabel("Left"), 0, 1) eq_layout.addWidget(QLabel("Right"), 0, 2) for i, label in enumerate(eq_labels): eq_layout.addWidget(QLabel(label), i + 1, 0) - left_input = QLineEdit() - right_input = QLineEdit() + left_input: QLineEdit = QLineEdit() + right_input: QLineEdit = QLineEdit() left_input.setPlaceholderText("Left") right_input.setPlaceholderText("Right") self.left_eq_inputs.append(left_input) @@ -315,52 +318,52 @@ class HearingAidApp(QWidget): eq_layout.addWidget(left_input, i + 1, 1) eq_layout.addWidget(right_input, i + 1, 2) - eq_group = QWidget() + eq_group: QWidget = QWidget() eq_group.setLayout(eq_layout) layout.addWidget(QLabel("Loss, in dBHL")) layout.addWidget(eq_group) # Amplification - self.amp_slider = QSlider(Qt.Horizontal) + self.amp_slider: QSlider = QSlider(Qt.Horizontal) self.amp_slider.setRange(-100, 100) self.amp_slider.setValue(50) layout.addWidget(QLabel("Amplification")) layout.addWidget(self.amp_slider) # Balance - self.balance_slider = QSlider(Qt.Horizontal) + self.balance_slider: QSlider = QSlider(Qt.Horizontal) self.balance_slider.setRange(-100, 100) self.balance_slider.setValue(50) layout.addWidget(QLabel("Balance")) layout.addWidget(self.balance_slider) # Tone - self.tone_slider = QSlider(Qt.Horizontal) + self.tone_slider: QSlider = QSlider(Qt.Horizontal) self.tone_slider.setRange(-100, 100) self.tone_slider.setValue(50) layout.addWidget(QLabel("Tone")) layout.addWidget(self.tone_slider) # Ambient Noise Reduction - self.anr_slider = QSlider(Qt.Horizontal) + self.anr_slider: QSlider = QSlider(Qt.Horizontal) self.anr_slider.setRange(0, 100) self.anr_slider.setValue(0) layout.addWidget(QLabel("Ambient Noise Reduction")) layout.addWidget(self.anr_slider) # Conversation Boost - self.conv_checkbox = QCheckBox("Conversation Boost") + self.conv_checkbox: QCheckBox = QCheckBox("Conversation Boost") layout.addWidget(self.conv_checkbox) # Own Voice Amplification - self.own_voice_slider = QSlider(Qt.Horizontal) + self.own_voice_slider: QSlider = QSlider(Qt.Horizontal) self.own_voice_slider.setRange(0, 100) self.own_voice_slider.setValue(50) # layout.addWidget(QLabel("Own Voice Amplification")) # layout.addWidget(self.own_voice_slider) # seems to have no effect # Reset button - self.reset_button = QPushButton("Reset") + self.reset_button: QPushButton = QPushButton("Reset") layout.addWidget(self.reset_button) # Connect signals @@ -377,15 +380,15 @@ class HearingAidApp(QWidget): self.setLayout(layout) logging.debug("UI initialized") - def connect_att(self): + def connect_att(self) -> None: logging.info("Connecting to ATT in UI") try: self.att_manager.connect() self.att_manager.enable_notifications(type('Handle', (), {'name': 'HEARING_AID'})()) self.att_manager.register_listener(ATT_HANDLES['HEARING_AID'], self.on_notification) # Initial read - data = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) - settings = parse_hearing_aid_settings(data) + data: bytes = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})()) + settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(data) if settings: self.emitter.update_ui.emit(settings) logging.info("Initial settings loaded") @@ -396,13 +399,13 @@ class HearingAidApp(QWidget): else: logging.error(f"Connection failed: {e}") - def on_notification(self, value): + def on_notification(self, value: bytes) -> None: logging.debug("Notification received") - settings = parse_hearing_aid_settings(value) + settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(value) if settings: self.emitter.update_ui.emit(settings) - def on_update_ui(self, settings): + def on_update_ui(self, settings: HearingAidSettings) -> None: logging.debug("Updating UI with settings") self.amp_slider.setValue(int(settings.net_amplification * 100)) self.balance_slider.setValue(int(settings.balance * 100)) @@ -416,30 +419,30 @@ class HearingAidApp(QWidget): for i, value in enumerate(settings.right_eq): self.right_eq_inputs[i].setText(f"{value:.2f}") - def on_value_changed(self): + def on_value_changed(self) -> None: logging.debug("UI value changed, starting debounce") self.debounce_timer.start(100) - def send_settings(self): + def send_settings(self) -> None: logging.info("Sending settings from UI") - amp = self.amp_slider.value() / 100.0 - balance = self.balance_slider.value() / 100.0 - tone = self.tone_slider.value() / 100.0 - anr = self.anr_slider.value() / 100.0 - conv = self.conv_checkbox.isChecked() - own_voice = self.own_voice_slider.value() / 100.0 + amp: float = self.amp_slider.value() / 100.0 + balance: float = self.balance_slider.value() / 100.0 + tone: float = self.tone_slider.value() / 100.0 + anr: float = self.anr_slider.value() / 100.0 + conv: bool = self.conv_checkbox.isChecked() + own_voice: float = self.own_voice_slider.value() / 100.0 - left_amp = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp - right_amp = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp + left_amp: float = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp + right_amp: float = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp - left_eq = [float(input_box.text() or 0) for input_box in self.left_eq_inputs] - right_eq = [float(input_box.text() or 0) for input_box in self.right_eq_inputs] + left_eq: List[float] = [float(input_box.text() or 0) for input_box in self.left_eq_inputs] + right_eq: List[float] = [float(input_box.text() or 0) for input_box in self.right_eq_inputs] - settings = HearingAidSettings( + settings: HearingAidSettings = HearingAidSettings( left_eq, right_eq, left_amp, right_amp, tone, tone, conv, conv, anr, anr, amp, balance, own_voice ) - threading.Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start() + Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start() def reset_settings(self): logging.debug("Resetting settings to defaults") @@ -451,26 +454,25 @@ class HearingAidApp(QWidget): self.own_voice_slider.setValue(50) self.on_value_changed() - def closeEvent(self, event): + def closeEvent(self, event: Any) -> None: logging.info("Closing app") self.att_manager.disconnect() event.accept() if __name__ == "__main__": - mac = None if len(sys.argv) != 2: logging.error("Usage: python hearing-aid-adjustments.py ") sys.exit(1) - mac = sys.argv[1] - mac_regex = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$' + mac: str = sys.argv[1] + mac_regex: str = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$' import re if not re.match(mac_regex, mac): logging.error("Invalid MAC address format") sys.exit(1) logging.info(f"Starting app") - app = QApplication(sys.argv) + app: QApplication = QApplication(sys.argv) - def quit_app(signum, frame): + def quit_app(signum: int, frame: Any) -> None: app.quit() signal.signal(signal.SIGINT, quit_app) diff --git a/proximity_keys.py b/proximity_keys.py index a0a9e42..29ec444 100644 --- a/proximity_keys.py +++ b/proximity_keys.py @@ -4,50 +4,53 @@ # See https://github.com/google/bumble/blob/main/docs/mkdocs/src/platforms/windows.md for usage. # You need to associate WinUSB with your Bluetooth interface. Once done, you can roll back to the original driver from Device Manager. -import sys import asyncio -import argparse +import colorama import logging import platform -from typing import Any, Optional +from argparse import ArgumentParser, Namespace +from asyncio import Queue, TimeoutError +from colorama import Fore, Style +from logging import Formatter, LogRecord, Logger, StreamHandler +from socket import socket as Socket +from typing import Any, Dict, List, Optional, Tuple -from colorama import Fore, Style, init as colorama_init -colorama_init(autoreset=True) +colorama.init(autoreset=True) -handler = logging.StreamHandler() -class ColorFormatter(logging.Formatter): - COLORS = { +handler: StreamHandler = StreamHandler() +class ColorFormatter(Formatter): + COLORS: Dict[int, str] = { logging.DEBUG: Fore.BLUE, logging.INFO: Fore.GREEN, logging.WARNING: Fore.YELLOW, logging.ERROR: Fore.RED, logging.CRITICAL: Fore.MAGENTA, } - def format(self, record): - color = self.COLORS.get(record.levelno, "") - prefix = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}" + def format(self, record: LogRecord) -> str: + color: str = self.COLORS.get(record.levelno, "") + prefix: str = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}" return f"{prefix} {record.getMessage()}" handler.setFormatter(ColorFormatter()) logging.basicConfig(level=logging.INFO, handlers=[handler]) -logger = logging.getLogger("proximitykeys") +logger: Logger = logging.getLogger("proximitykeys") -PROXIMITY_KEY_TYPES = {0x01: "IRK", 0x04: "ENC_KEY"} +PROXIMITY_KEY_TYPES: Dict[int, str] = {0x01: "IRK", 0x04: "ENC_KEY"} -def parse_proximity_keys_response(data: bytes): +def parse_proximity_keys_response(data: bytes) -> Optional[List[Tuple[str, bytes]]]: if len(data) < 7 or data[4] != 0x31: return None - key_count = data[6] - keys = [] - offset = 7 + key_count: int = data[6] + keys: List[Tuple[str, bytes]] = [] + offset: int = 7 for _ in range(key_count): if offset + 3 >= len(data): break - key_type = data[offset] - key_length = data[offset + 2] + key_type: int = data[offset] + key_length: int = data[offset + 2] offset += 4 if offset + key_length > len(data): break - key_bytes = data[offset:offset + key_length] + key_bytes: bytes = data[offset:offset + key_length] keys.append((PROXIMITY_KEY_TYPES.get(key_type, f"TYPE_{key_type:02X}"), key_bytes)) offset += key_length return keys @@ -55,7 +58,7 @@ def parse_proximity_keys_response(data: bytes): def hexdump(data: bytes) -> str: return " ".join(f"{b:02X}" for b in data) -async def run_bumble(bdaddr: str): +async def run_bumble(bdaddr: str) -> int: try: from bumble.l2cap import ClassicChannelSpec from bumble.transport import open_transport @@ -68,19 +71,23 @@ async def run_bumble(bdaddr: str): logger.error("Bumble not installed") return 1 - PSM_PROXIMITY = 0x1001 - HANDSHAKE = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") - KEY_REQ = bytes.fromhex("04 00 04 00 30 00 05 00") + PSM_PROXIMITY: int = 0x1001 + HANDSHAKE: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") + KEY_REQ: bytes = bytes.fromhex("04 00 04 00 30 00 05 00") class KeyStore: - async def delete(self, name: str): pass - async def update(self, name: str, keys: Any): pass - async def get(self, _name: str) -> Optional[Any]: return None - async def get_all(self): return [] + async def delete(self, name: str) -> None: + pass + async def update(self, name: str, keys: Any) -> None: + pass + async def get(self, _name: str) -> Optional[Any]: + return None + async def get_all(self) -> List[Tuple[str, Any]]: + return [] - async def get_resolving_keys(self) -> list[tuple[bytes, Any]]: - all_keys = await self.get_all() - resolving_keys = [] + async def get_resolving_keys(self) -> List[Tuple[bytes, Any]]: + all_keys: List[Tuple[str, Any]] = await self.get_all() + resolving_keys: List[Tuple[bytes, Any]] = [] for name, keys in all_keys: if getattr(keys, "irk", None) is not None: resolving_keys.append(( @@ -89,8 +96,8 @@ async def run_bumble(bdaddr: str): )) return resolving_keys - async def exchange_keys(channel, timeout=5.0): - recv_q: asyncio.Queue = asyncio.Queue() + async def exchange_keys(channel: Any, timeout: float = 5.0) -> Optional[List[Tuple[str, bytes]]]: + recv_q: Queue = Queue() channel.sink = lambda sdu: recv_q.put_nowait(sdu) logger.info("Sending handshake packet...") channel.send_pdu(HANDSHAKE) @@ -99,19 +106,19 @@ async def run_bumble(bdaddr: str): channel.send_pdu(KEY_REQ) while True: try: - pkt = await asyncio.wait_for(recv_q.get(), timeout) - except asyncio.TimeoutError: + pkt: bytes = await asyncio.wait_for(recv_q.get(), timeout) + except TimeoutError: logger.error("Timed out waiting for SDU response") return None logger.debug("Received SDU (%d bytes): %s", len(pkt), hexdump(pkt)) - keys = parse_proximity_keys_response(pkt) + keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt) if keys: return keys - async def get_device(): + async def get_device() -> Tuple[Any, Device]: logger.info("Opening transport...") - transport = await open_transport("usb:0") - device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink)) + transport: Any = await open_transport("usb:0") + device: Device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink)) device.classic_enabled = True device.le_enabled = False device.keystore = KeyStore() @@ -123,15 +130,15 @@ async def run_bumble(bdaddr: str): logger.info("Device powered on") return transport, device - async def create_channel_and_exchange(conn): - spec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048) + async def create_channel_and_exchange(conn: Any) -> None: + spec: ClassicChannelSpec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048) logger.info("Requesting L2CAP channel on PSM = 0x%04X", spec.psm) if not conn.is_encrypted: logger.info("Enabling link encryption...") await conn.encrypt() await asyncio.sleep(0.05) - channel = await conn.create_l2cap_channel(spec=spec) - keys = await exchange_keys(channel, timeout=8.0) + channel: Any = await conn.create_l2cap_channel(spec=spec) + keys: Optional[List[Tuple[str, bytes]]] = await exchange_keys(channel, timeout=8.0) if not keys: logger.warning("No proximity keys found") return @@ -165,14 +172,14 @@ async def run_bumble(bdaddr: str): logger.info("Transport closed") return 0 -def run_linux(bdaddr: str): +def run_linux(bdaddr: str) -> None: import socket - PSM = 0x1001 - handshake = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") - key_req = bytes.fromhex("04 00 04 00 30 00 05 00") + PSM: int = 0x1001 + handshake: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00") + key_req: bytes = bytes.fromhex("04 00 04 00 30 00 05 00") logger.info("Connecting to %s (L2CAP)...", bdaddr) - sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + sock: Socket = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) try: sock.connect((bdaddr, PSM)) logger.info("Connected, sending handshake and key request...") @@ -180,9 +187,9 @@ def run_linux(bdaddr: str): sock.send(key_req) while True: - pkt = sock.recv(1024) + pkt: bytes = sock.recv(1024) logger.debug("Received packet (%d bytes): %s", len(pkt), hexdump(pkt)) - keys = parse_proximity_keys_response(pkt) + keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt) if keys: logger.info("Keys successfully retrieved") print(f"{Fore.CYAN}{Style.BRIGHT}Proximity Keys:{Style.RESET_ALL}") @@ -197,12 +204,12 @@ def run_linux(bdaddr: str): sock.close() logger.info("Connection closed") -def main(): - parser = argparse.ArgumentParser() +def main() -> None: + parser: ArgumentParser = ArgumentParser() parser.add_argument("bdaddr") parser.add_argument("--debug", action="store_true") parser.add_argument("--bumble", action="store_true") - args = parser.parse_args() + args: Namespace = parser.parse_args() logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO) if args.bumble or platform.system() == "Windows":