mirror of
https://github.com/kavishdevar/librepods.git
synced 2026-01-28 22:01:50 +00:00
refactor: Add Python type annotations wherever appropriate (#269)
* Add Python type annotations wherever appropriate * Might as well annotate this too
This commit is contained in:
29
head-tracking/colors.py
Normal file
29
head-tracking/colors.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import logging
|
||||
from logging import Formatter, LogRecord
|
||||
from typing import Dict
|
||||
|
||||
class Colors:
|
||||
RESET: str = "\033[0m"
|
||||
BOLD: str = "\033[1m"
|
||||
RED: str = "\033[91m"
|
||||
GREEN: str = "\033[92m"
|
||||
YELLOW: str = "\033[93m"
|
||||
BLUE: str = "\033[94m"
|
||||
MAGENTA: str = "\033[95m"
|
||||
CYAN: str = "\033[96m"
|
||||
WHITE: str = "\033[97m"
|
||||
BG_BLACK: str = "\033[40m"
|
||||
|
||||
class ColorFormatter(Formatter):
|
||||
FORMATS: Dict[int, str] = {
|
||||
logging.DEBUG: f"{Colors.BLUE}[%(levelname)s] %(message)s{Colors.RESET}",
|
||||
logging.INFO: f"{Colors.GREEN}%(message)s{Colors.RESET}",
|
||||
logging.WARNING: f"{Colors.YELLOW}%(message)s{Colors.RESET}",
|
||||
logging.ERROR: f"{Colors.RED}[%(levelname)s] %(message)s{Colors.RESET}",
|
||||
logging.CRITICAL: f"{Colors.RED}{Colors.BOLD}[%(levelname)s] %(message)s{Colors.RESET}"
|
||||
}
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
log_fmt: str = self.FORMATS.get(record.levelno)
|
||||
formatter: Formatter = Formatter(log_fmt, datefmt="%H:%M:%S")
|
||||
return formatter.format(record)
|
||||
@@ -1,23 +1,25 @@
|
||||
import bluetooth
|
||||
import logging
|
||||
from bluetooth import BluetoothSocket
|
||||
from logging import Logger
|
||||
|
||||
class ConnectionManager:
|
||||
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
|
||||
def __init__(self, bt_addr="28:2D:7F:C2:05:5B", psm=0x1001, logger=None):
|
||||
self.bt_addr = bt_addr
|
||||
self.psm = psm
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
self.sock = None
|
||||
self.connected = False
|
||||
self.started = False
|
||||
def __init__(self, bt_addr: str = "28:2D:7F:C2:05:5B", psm: int = 0x1001, logger: Logger = None) -> None:
|
||||
self.bt_addr: str = bt_addr
|
||||
self.psm: int = psm
|
||||
self.logger: Logger = logger if logger else logging.getLogger(__name__)
|
||||
self.sock: BluetoothSocket = None
|
||||
self.connected: bool = False
|
||||
self.started: bool = False
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> bool:
|
||||
self.logger.info(f"Connecting to {self.bt_addr} on PSM {self.psm:#04x}...")
|
||||
try:
|
||||
self.sock = bluetooth.BluetoothSocket(bluetooth.L2CAP)
|
||||
self.sock = BluetoothSocket(bluetooth.L2CAP)
|
||||
self.sock.connect((self.bt_addr, self.psm))
|
||||
self.connected = True
|
||||
self.logger.info("Connected to AirPods.")
|
||||
@@ -28,7 +30,7 @@ class ConnectionManager:
|
||||
self.connected = False
|
||||
return self.connected
|
||||
|
||||
def send_start(self):
|
||||
def send_start(self) -> bool:
|
||||
if not self.connected:
|
||||
self.logger.error("Not connected. Cannot send START command.")
|
||||
return False
|
||||
@@ -40,7 +42,7 @@ class ConnectionManager:
|
||||
self.logger.info("START command has already been sent.")
|
||||
return True
|
||||
|
||||
def send_stop(self):
|
||||
def send_stop(self) -> None:
|
||||
if self.connected and self.started:
|
||||
try:
|
||||
self.sock.send(bytes.fromhex(self.STOP_CMD))
|
||||
@@ -51,7 +53,7 @@ class ConnectionManager:
|
||||
else:
|
||||
self.logger.info("Cannot send STOP; not started or not connected.")
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
if self.sock:
|
||||
try:
|
||||
self.sock.close()
|
||||
@@ -59,4 +61,4 @@ class ConnectionManager:
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during disconnect: {e}")
|
||||
self.connected = False
|
||||
self.started = False
|
||||
self.started = False
|
||||
|
||||
@@ -1,88 +1,65 @@
|
||||
import bluetooth
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
import statistics
|
||||
import time
|
||||
from bluetooth import BluetoothSocket
|
||||
from collections import deque
|
||||
from colors import *
|
||||
from connection_manager import ConnectionManager
|
||||
from logging import Logger, StreamHandler
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Deque, List, Optional, Tuple
|
||||
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
BG_BLACK = "\033[40m"
|
||||
|
||||
class ColorFormatter(logging.Formatter):
|
||||
FORMATS = {
|
||||
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
|
||||
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
|
||||
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler: StreamHandler = StreamHandler()
|
||||
handler.setFormatter(ColorFormatter())
|
||||
log = logging.getLogger(__name__)
|
||||
log: Logger = logging.getLogger(__name__)
|
||||
log.setLevel(logging.INFO)
|
||||
log.addHandler(handler)
|
||||
log.propagate = False
|
||||
|
||||
class GestureDetector:
|
||||
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
|
||||
def __init__(self, conn=None):
|
||||
self.sock = None
|
||||
self.bt_addr = "28:2D:7F:C2:05:5B"
|
||||
self.psm = 0x1001
|
||||
self.running = False
|
||||
self.data_lock = threading.Lock()
|
||||
def __init__(self, conn: ConnectionManager = None) -> None:
|
||||
self.sock: BluetoothSocket = None
|
||||
self.bt_addr: str = "28:2D:7F:C2:05:5B"
|
||||
self.psm: int = 0x1001
|
||||
self.running: bool = False
|
||||
self.data_lock: Lock = Lock()
|
||||
|
||||
self.horiz_buffer = deque(maxlen=100)
|
||||
self.vert_buffer = deque(maxlen=100)
|
||||
self.horiz_buffer: Deque[int] = deque(maxlen=100)
|
||||
self.vert_buffer: Deque[int] = deque(maxlen=100)
|
||||
|
||||
self.horiz_avg_buffer = deque(maxlen=5)
|
||||
self.vert_avg_buffer = deque(maxlen=5)
|
||||
self.horiz_avg_buffer: Deque[float] = deque(maxlen=5)
|
||||
self.vert_avg_buffer: Deque[float] = deque(maxlen=5)
|
||||
|
||||
self.horiz_peaks = []
|
||||
self.horiz_troughs = []
|
||||
self.vert_peaks = []
|
||||
self.vert_troughs = []
|
||||
self.horiz_peaks: List[int] = []
|
||||
self.horiz_troughs: List[int] = []
|
||||
self.vert_peaks: List[int] = []
|
||||
self.vert_troughs: List[int] = []
|
||||
|
||||
self.last_peak_time = 0
|
||||
self.peak_intervals = deque(maxlen=5)
|
||||
self.last_peak_time: float = 0
|
||||
self.peak_intervals: Deque[float] = deque(maxlen=5)
|
||||
|
||||
self.peak_threshold = 400
|
||||
self.direction_change_threshold = 175
|
||||
self.rhythm_consistency_threshold = 0.5
|
||||
self.peak_threshold: int = 400
|
||||
self.direction_change_threshold: int = 175
|
||||
self.rhythm_consistency_threshold: float = 0.5
|
||||
|
||||
self.horiz_increasing = None
|
||||
self.vert_increasing = None
|
||||
self.horiz_increasing: Optional[bool] = None
|
||||
self.vert_increasing: Optional[bool] = None
|
||||
|
||||
self.required_extremes = 3
|
||||
self.detection_timeout = 15
|
||||
self.detection_timeout: int = 15
|
||||
|
||||
self.min_confidence_threshold = 0.7
|
||||
self.min_confidence_threshold: float = 0.7
|
||||
|
||||
self.conn = conn
|
||||
self.conn: ConnectionManager = conn
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> bool:
|
||||
try:
|
||||
log.info(f"Connecting to AirPods at {self.bt_addr}...")
|
||||
if self.conn is None:
|
||||
from connection_manager import ConnectionManager
|
||||
self.conn = ConnectionManager(self.bt_addr, self.psm, logger=log)
|
||||
if not self.conn.connect():
|
||||
return False
|
||||
@@ -97,13 +74,13 @@ class GestureDetector:
|
||||
log.error(f"{Colors.RED}Connection failed: {e}{Colors.RESET}")
|
||||
return False
|
||||
|
||||
def process_data(self):
|
||||
def process_data(self) -> None:
|
||||
"""Process incoming head tracking data."""
|
||||
self.conn.send_start()
|
||||
log.info(f"{Colors.GREEN}✓ Head tracking activated{Colors.RESET}")
|
||||
|
||||
self.running = True
|
||||
start_time = time.time()
|
||||
start_time: float = time.time()
|
||||
|
||||
log.info(f"{Colors.GREEN}Ready! Make a YES or NO gesture{Colors.RESET}")
|
||||
log.info(f"{Colors.YELLOW}Tip: Use natural, moderate speed head movements{Colors.RESET}")
|
||||
@@ -118,10 +95,10 @@ class GestureDetector:
|
||||
if not self.sock:
|
||||
log.error("Socket not available.")
|
||||
break
|
||||
data = self.sock.recv(1024)
|
||||
formatted = self.format_hex(data)
|
||||
data: bytes = self.sock.recv(1024)
|
||||
formatted: str = self.format_hex(data)
|
||||
if self.is_valid_tracking_packet(formatted):
|
||||
raw_bytes = bytes.fromhex(formatted.replace(" ", ""))
|
||||
raw_bytes: bytes = bytes.fromhex(formatted.replace(" ", ""))
|
||||
horizontal, vertical = self.extract_orientation_values(raw_bytes)
|
||||
|
||||
if horizontal is not None and vertical is not None:
|
||||
@@ -132,7 +109,7 @@ class GestureDetector:
|
||||
self.vert_buffer.append(smooth_v)
|
||||
|
||||
self.detect_peaks_and_troughs()
|
||||
gesture = self.detect_gestures()
|
||||
gesture: Optional[str] = self.detect_gestures()
|
||||
|
||||
if gesture:
|
||||
self.running = False
|
||||
@@ -143,19 +120,19 @@ class GestureDetector:
|
||||
log.error(f"Data processing error: {e}")
|
||||
break
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from socket."""
|
||||
self.conn.disconnect()
|
||||
|
||||
def format_hex(self, data):
|
||||
def format_hex(self, data: bytes) -> str:
|
||||
"""Format binary data to readable hex string."""
|
||||
hex_str = data.hex()
|
||||
hex_str: str = data.hex()
|
||||
return ' '.join(hex_str[i:i+2] for i in range(0, len(hex_str), 2))
|
||||
|
||||
def is_valid_tracking_packet(self, hex_string):
|
||||
def is_valid_tracking_packet(self, hex_string: str) -> bool:
|
||||
"""Verify packet is a valid head tracking packet."""
|
||||
standard_header = "04 00 04 00 17 00 00 00 10 00 45 00"
|
||||
alternate_header = "04 00 04 00 17 00 00 00 10 00 44 00"
|
||||
standard_header: str = "04 00 04 00 17 00 00 00 10 00 45 00"
|
||||
alternate_header: str = "04 00 04 00 17 00 00 00 10 00 44 00"
|
||||
if not hex_string.startswith(standard_header) and not hex_string.startswith(alternate_header):
|
||||
return False
|
||||
|
||||
@@ -164,55 +141,55 @@ class GestureDetector:
|
||||
|
||||
return True
|
||||
|
||||
def extract_orientation_values(self, raw_bytes):
|
||||
def extract_orientation_values(self, raw_bytes: bytes) -> Tuple[Optional[int], Optional[int]]:
|
||||
"""Extract head orientation data from packet."""
|
||||
try:
|
||||
horizontal = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True)
|
||||
vertical = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True)
|
||||
horizontal: int = int.from_bytes(raw_bytes[51:53], byteorder='little', signed=True)
|
||||
vertical: int = int.from_bytes(raw_bytes[53:55], byteorder='little', signed=True)
|
||||
|
||||
return horizontal, vertical
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to extract orientation: {e}")
|
||||
return None, None
|
||||
|
||||
def apply_smoothing(self, horizontal, vertical):
|
||||
def apply_smoothing(self, horizontal: int, vertical: int) -> Tuple[float, float]:
|
||||
"""Apply moving average smoothing (Apple-like filtering)."""
|
||||
self.horiz_avg_buffer.append(horizontal)
|
||||
self.vert_avg_buffer.append(vertical)
|
||||
|
||||
smooth_horiz = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer)
|
||||
smooth_vert = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer)
|
||||
smooth_horiz: float = sum(self.horiz_avg_buffer) / len(self.horiz_avg_buffer)
|
||||
smooth_vert: float = sum(self.vert_avg_buffer) / len(self.vert_avg_buffer)
|
||||
|
||||
return smooth_horiz, smooth_vert
|
||||
|
||||
def detect_peaks_and_troughs(self):
|
||||
def detect_peaks_and_troughs(self) -> None:
|
||||
"""Detect motion direction changes with Apple-like refinements."""
|
||||
if len(self.horiz_buffer) < 4 or len(self.vert_buffer) < 4:
|
||||
return
|
||||
|
||||
h_values = list(self.horiz_buffer)[-4:]
|
||||
v_values = list(self.vert_buffer)[-4:]
|
||||
h_values: List[int] = list(self.horiz_buffer)[-4:]
|
||||
v_values: List[int] = list(self.vert_buffer)[-4:]
|
||||
|
||||
h_variance = statistics.variance(h_values) if len(h_values) > 1 else 0
|
||||
v_variance = statistics.variance(v_values) if len(v_values) > 1 else 0
|
||||
h_variance: float = statistics.variance(h_values) if len(h_values) > 1 else 0
|
||||
v_variance: float = statistics.variance(v_values) if len(v_values) > 1 else 0
|
||||
|
||||
current = self.horiz_buffer[-1]
|
||||
prev = self.horiz_buffer[-2]
|
||||
current: int = self.horiz_buffer[-1]
|
||||
prev: int = self.horiz_buffer[-2]
|
||||
|
||||
if self.horiz_increasing is None:
|
||||
self.horiz_increasing = current > prev
|
||||
|
||||
dynamic_h_threshold = max(100, min(self.direction_change_threshold, h_variance / 3))
|
||||
dynamic_h_threshold: float = max(100, min(self.direction_change_threshold, h_variance / 3))
|
||||
|
||||
if self.horiz_increasing and current < prev - dynamic_h_threshold:
|
||||
if abs(prev) > self.peak_threshold:
|
||||
self.horiz_peaks.append((len(self.horiz_buffer)-1, prev, time.time()))
|
||||
direction = "➡️ " if prev > 0 else "⬅️ "
|
||||
direction: str = "➡️ " if prev > 0 else "⬅️ "
|
||||
log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}")
|
||||
|
||||
now = time.time()
|
||||
now: float = time.time()
|
||||
if self.last_peak_time > 0:
|
||||
interval = now - self.last_peak_time
|
||||
interval: float = now - self.last_peak_time
|
||||
self.peak_intervals.append(interval)
|
||||
self.last_peak_time = now
|
||||
|
||||
@@ -221,34 +198,34 @@ class GestureDetector:
|
||||
elif not self.horiz_increasing and current > prev + dynamic_h_threshold:
|
||||
if abs(prev) > self.peak_threshold:
|
||||
self.horiz_troughs.append((len(self.horiz_buffer)-1, prev, time.time()))
|
||||
direction = "➡️ " if prev > 0 else "⬅️ "
|
||||
direction: str = "➡️ " if prev > 0 else "⬅️ "
|
||||
log.info(f"{Colors.CYAN}{direction} Horizontal max: {prev} (threshold: {dynamic_h_threshold:.1f}){Colors.RESET}")
|
||||
|
||||
now = time.time()
|
||||
now: float = time.time()
|
||||
if self.last_peak_time > 0:
|
||||
interval = now - self.last_peak_time
|
||||
interval: float = now - self.last_peak_time
|
||||
self.peak_intervals.append(interval)
|
||||
self.last_peak_time = now
|
||||
|
||||
self.horiz_increasing = True
|
||||
|
||||
current = self.vert_buffer[-1]
|
||||
prev = self.vert_buffer[-2]
|
||||
current: int = self.vert_buffer[-1]
|
||||
prev: int = self.vert_buffer[-2]
|
||||
|
||||
if self.vert_increasing is None:
|
||||
self.vert_increasing = current > prev
|
||||
|
||||
dynamic_v_threshold = max(100, min(self.direction_change_threshold, v_variance / 3))
|
||||
dynamic_v_threshold: float = max(100, min(self.direction_change_threshold, v_variance / 3))
|
||||
|
||||
if self.vert_increasing and current < prev - dynamic_v_threshold:
|
||||
if abs(prev) > self.peak_threshold:
|
||||
self.vert_peaks.append((len(self.vert_buffer)-1, prev, time.time()))
|
||||
direction = "⬆️ " if prev > 0 else "⬇️ "
|
||||
direction: str = "⬆️ " if prev > 0 else "⬇️ "
|
||||
log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}")
|
||||
|
||||
now = time.time()
|
||||
now: float = time.time()
|
||||
if self.last_peak_time > 0:
|
||||
interval = now - self.last_peak_time
|
||||
interval: float = now - self.last_peak_time
|
||||
self.peak_intervals.append(interval)
|
||||
self.last_peak_time = now
|
||||
|
||||
@@ -257,60 +234,60 @@ class GestureDetector:
|
||||
elif not self.vert_increasing and current > prev + dynamic_v_threshold:
|
||||
if abs(prev) > self.peak_threshold:
|
||||
self.vert_troughs.append((len(self.vert_buffer)-1, prev, time.time()))
|
||||
direction = "⬆️ " if prev > 0 else "⬇️ "
|
||||
direction: str = "⬆️ " if prev > 0 else "⬇️ "
|
||||
log.info(f"{Colors.MAGENTA}{direction} Vertical max: {prev} (threshold: {dynamic_v_threshold:.1f}){Colors.RESET}")
|
||||
|
||||
now = time.time()
|
||||
now: float = time.time()
|
||||
if self.last_peak_time > 0:
|
||||
interval = now - self.last_peak_time
|
||||
interval: float = now - self.last_peak_time
|
||||
self.peak_intervals.append(interval)
|
||||
self.last_peak_time = now
|
||||
|
||||
self.vert_increasing = True
|
||||
|
||||
def calculate_rhythm_consistency(self):
|
||||
def calculate_rhythm_consistency(self) -> float:
|
||||
"""Calculate how consistent the timing between peaks is (Apple-like)."""
|
||||
if len(self.peak_intervals) < 2:
|
||||
return 0
|
||||
|
||||
mean_interval = statistics.mean(self.peak_intervals)
|
||||
mean_interval: float = statistics.mean(self.peak_intervals)
|
||||
if mean_interval == 0:
|
||||
return 0
|
||||
|
||||
variances = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals]
|
||||
consistency = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold)
|
||||
variances: List[float] = [(i/mean_interval - 1.0) ** 2 for i in self.peak_intervals]
|
||||
consistency: float = 1.0 - min(1.0, statistics.mean(variances) / self.rhythm_consistency_threshold)
|
||||
return max(0, consistency)
|
||||
|
||||
def calculate_confidence_score(self, extremes, is_vertical=True):
|
||||
def calculate_confidence_score(self, extremes: List[Tuple[int, int, float]], is_vertical: bool = True) -> float:
|
||||
"""Calculate confidence score for gesture detection (Apple-like)."""
|
||||
if len(extremes) < self.required_extremes:
|
||||
return 0.0
|
||||
|
||||
sorted_extremes = sorted(extremes, key=lambda x: x[0])
|
||||
sorted_extremes: List[Tuple[int, int, float]] = sorted(extremes, key=lambda x: x[0])
|
||||
|
||||
recent = sorted_extremes[-self.required_extremes:]
|
||||
recent: List[Tuple[int, int, float]] = sorted_extremes[-self.required_extremes:]
|
||||
|
||||
avg_amplitude = sum(abs(val) for _, val, _ in recent) / len(recent)
|
||||
amplitude_factor = min(1.0, avg_amplitude / 600)
|
||||
avg_amplitude: float = sum(abs(val) for _, val, _ in recent) / len(recent)
|
||||
amplitude_factor: float = min(1.0, avg_amplitude / 600)
|
||||
|
||||
rhythm_factor = self.calculate_rhythm_consistency()
|
||||
rhythm_factor: float = self.calculate_rhythm_consistency()
|
||||
|
||||
signs = [1 if val > 0 else -1 for _, val, _ in recent]
|
||||
alternating = all(signs[i] != signs[i-1] for i in range(1, len(signs)))
|
||||
alternation_factor = 1.0 if alternating else 0.5
|
||||
signs: List[int] = [1 if val > 0 else -1 for _, val, _ in recent]
|
||||
alternating: bool = all(signs[i] != signs[i-1] for i in range(1, len(signs)))
|
||||
alternation_factor: float = 1.0 if alternating else 0.5
|
||||
|
||||
if is_vertical:
|
||||
vert_amp = sum(abs(val) for _, val, _ in recent) / len(recent)
|
||||
horiz_vals = list(self.horiz_buffer)[-len(recent)*2:]
|
||||
horiz_amp = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0
|
||||
isolation_factor = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2)
|
||||
vert_amp: float = sum(abs(val) for _, val, _ in recent) / len(recent)
|
||||
horiz_vals: List[int] = list(self.horiz_buffer)[-len(recent)*2:]
|
||||
horiz_amp: float = sum(abs(val) for val in horiz_vals) / len(horiz_vals) if horiz_vals else 0
|
||||
isolation_factor: float = min(1.0, vert_amp / (horiz_amp + 0.1) * 1.2)
|
||||
else:
|
||||
horiz_amp = sum(abs(val) for _, val, _ in recent)
|
||||
vert_vals = list(self.vert_buffer)[-len(recent)*2:]
|
||||
vert_amp = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0
|
||||
isolation_factor = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2)
|
||||
horiz_amp: float = sum(abs(val) for _, val, _ in recent)
|
||||
vert_vals: List[int] = list(self.vert_buffer)[-len(recent)*2:]
|
||||
vert_amp: float = sum(abs(val) for val in vert_vals) / len(vert_vals) if vert_vals else 0
|
||||
isolation_factor: float = min(1.0, horiz_amp / (vert_amp + 0.1) * 1.2)
|
||||
|
||||
confidence = (
|
||||
confidence: float = (
|
||||
amplitude_factor * 0.4 +
|
||||
rhythm_factor * 0.2 +
|
||||
alternation_factor * 0.2 +
|
||||
@@ -319,12 +296,12 @@ class GestureDetector:
|
||||
|
||||
return confidence
|
||||
|
||||
def detect_gestures(self):
|
||||
def detect_gestures(self) -> Optional[str]:
|
||||
"""Recognize head gesture patterns with Apple-like intelligence."""
|
||||
if len(self.vert_peaks) + len(self.vert_troughs) >= self.required_extremes:
|
||||
all_extremes = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0])
|
||||
all_extremes: List[Tuple[int, int, float]] = sorted(self.vert_peaks + self.vert_troughs, key=lambda x: x[0])
|
||||
|
||||
confidence = self.calculate_confidence_score(all_extremes, is_vertical=True)
|
||||
confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=True)
|
||||
|
||||
log.info(f"Vertical motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})")
|
||||
|
||||
@@ -333,9 +310,9 @@ class GestureDetector:
|
||||
return "YES"
|
||||
|
||||
if len(self.horiz_peaks) + len(self.horiz_troughs) >= self.required_extremes:
|
||||
all_extremes = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0])
|
||||
all_extremes: List[Tuple[int, int, float]] = sorted(self.horiz_peaks + self.horiz_troughs, key=lambda x: x[0])
|
||||
|
||||
confidence = self.calculate_confidence_score(all_extremes, is_vertical=False)
|
||||
confidence: float = self.calculate_confidence_score(all_extremes, is_vertical=False)
|
||||
|
||||
log.info(f"Horizontal motion confidence: {confidence:.2f} (need {self.min_confidence_threshold:.2f})")
|
||||
|
||||
@@ -345,7 +322,7 @@ class GestureDetector:
|
||||
|
||||
return None
|
||||
|
||||
def start_detection(self):
|
||||
def start_detection(self) -> None:
|
||||
"""Begin gesture detection process."""
|
||||
log.info(f"{Colors.BOLD}{Colors.WHITE}Starting gesture detection...{Colors.RESET}")
|
||||
|
||||
@@ -353,7 +330,7 @@ class GestureDetector:
|
||||
log.error(f"{Colors.RED}Failed to connect to AirPods.{Colors.RESET}")
|
||||
return
|
||||
|
||||
data_thread = threading.Thread(target=self.process_data)
|
||||
data_thread: Thread = Thread(target=self.process_data)
|
||||
data_thread.daemon = True
|
||||
data_thread.start()
|
||||
|
||||
@@ -377,5 +354,5 @@ if __name__ == "__main__":
|
||||
print(f"{Colors.GREEN}• YES: {Colors.WHITE}nodding head up and down{Colors.RESET}")
|
||||
print(f"{Colors.RED}• NO: {Colors.WHITE}shaking head left and right{Colors.RESET}\n")
|
||||
|
||||
detector = GestureDetector()
|
||||
detector: GestureDetector = GestureDetector()
|
||||
detector.start_detection()
|
||||
@@ -1,63 +1,43 @@
|
||||
import math
|
||||
import drawille
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from colors import *
|
||||
from drawille import Canvas
|
||||
from logging import Logger, StreamHandler
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from matplotlib.pyplot import Axes, Figure
|
||||
from numpy.typing import NDArray
|
||||
from os import terminal_size as TerminalSize
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
BG_BLACK = "\033[40m"
|
||||
|
||||
class ColorFormatter(logging.Formatter):
|
||||
FORMATS = {
|
||||
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
|
||||
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
|
||||
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler: StreamHandler = StreamHandler()
|
||||
handler.setFormatter(ColorFormatter())
|
||||
log = logging.getLogger(__name__)
|
||||
log: Logger = logging.getLogger(__name__)
|
||||
log.setLevel(logging.INFO)
|
||||
log.addHandler(handler)
|
||||
log.propagate = False
|
||||
|
||||
|
||||
class HeadOrientation:
|
||||
def __init__(self, use_terminal=False):
|
||||
self.orientation_offset = 5500
|
||||
self.o1_neutral = 19000
|
||||
self.o2_neutral = 0
|
||||
self.o3_neutral = 0
|
||||
self.calibration_samples = []
|
||||
self.calibration_complete = False
|
||||
self.calibration_sample_count = 10
|
||||
self.fig = None
|
||||
self.ax = None
|
||||
self.arrow = None
|
||||
self.animation = None
|
||||
self.use_terminal = use_terminal
|
||||
def __init__(self, use_terminal: bool = False) -> None:
|
||||
self.orientation_offset: int = 5500
|
||||
self.o1_neutral: int = 19000
|
||||
self.o2_neutral: int = 0
|
||||
self.o3_neutral: int = 0
|
||||
self.calibration_samples: List[List[int]] = []
|
||||
self.calibration_complete: bool = False
|
||||
self.calibration_sample_count: int = 10
|
||||
self.fig: Optional[Figure] = None
|
||||
self.ax: Optional[Axes] = None
|
||||
self.arrow: Any = None
|
||||
self.animation: Optional[FuncAnimation] = None
|
||||
self.use_terminal: bool = use_terminal
|
||||
|
||||
def reset_calibration(self):
|
||||
def reset_calibration(self) -> None:
|
||||
self.calibration_samples = []
|
||||
self.calibration_complete = False
|
||||
|
||||
def add_calibration_sample(self, orientation_values):
|
||||
def add_calibration_sample(self, orientation_values: List[int]) -> bool:
|
||||
if len(self.calibration_samples) < self.calibration_sample_count:
|
||||
self.calibration_samples.append(orientation_values)
|
||||
return False
|
||||
@@ -66,57 +46,58 @@ class HeadOrientation:
|
||||
return True
|
||||
return True
|
||||
|
||||
def _calculate_calibration(self):
|
||||
def _calculate_calibration(self) -> None:
|
||||
if len(self.calibration_samples) < 3:
|
||||
log.warning("Not enough calibration samples")
|
||||
return
|
||||
samples = np.array(self.calibration_samples)
|
||||
self.o1_neutral = np.mean(samples[:, 0])
|
||||
avg_o2 = np.mean(samples[:, 1])
|
||||
avg_o3 = np.mean(samples[:, 2])
|
||||
self.o2_neutral = avg_o2
|
||||
self.o3_neutral = avg_o3
|
||||
samples: NDArray[[List[int]]] = np.array(self.calibration_samples)
|
||||
self.o1_neutral: float = np.mean(samples[:, 0])
|
||||
avg_o2: float = np.mean(samples[:, 1])
|
||||
avg_o3: float = np.mean(samples[:, 2])
|
||||
self.o2_neutral: float = avg_o2
|
||||
self.o3_neutral: float = avg_o3
|
||||
log.info("Calibration complete: o1_neutral=%.2f, o2_neutral=%.2f, o3_neutral=%.2f",
|
||||
self.o1_neutral, self.o2_neutral, self.o3_neutral)
|
||||
self.calibration_complete = True
|
||||
|
||||
def calculate_orientation(self, o1, o2, o3):
|
||||
def calculate_orientation(self, o1: float, o2: float, o3: float) -> Dict[str, float]:
|
||||
if not self.calibration_complete:
|
||||
return {'pitch': 0, 'yaw': 0}
|
||||
o1_norm = o1 - self.o1_neutral
|
||||
o2_norm = o2 - self.o2_neutral
|
||||
o3_norm = o3 - self.o3_neutral
|
||||
pitch = (o2_norm + o3_norm) / 2 / 32000 * 180
|
||||
yaw = (o2_norm - o3_norm) / 2 / 32000 * 180
|
||||
o1_norm: float = o1 - self.o1_neutral
|
||||
o2_norm: float = o2 - self.o2_neutral
|
||||
o3_norm: float = o3 - self.o3_neutral
|
||||
pitch: float = (o2_norm + o3_norm) / 2 / 32000 * 180
|
||||
yaw: float = (o2_norm - o3_norm) / 2 / 32000 * 180
|
||||
return {'pitch': pitch, 'yaw': yaw}
|
||||
|
||||
def create_face_art(self, pitch, yaw):
|
||||
def create_face_art(self, pitch: float, yaw: float) -> str:
|
||||
if self.use_terminal:
|
||||
try:
|
||||
ts = os.get_terminal_size()
|
||||
ts: TerminalSize = os.get_terminal_size()
|
||||
width, height = ts.columns, ts.lines * 2
|
||||
except Exception:
|
||||
width, height = 80, 40
|
||||
else:
|
||||
width, height = 80, 40
|
||||
center_x, center_y = width // 2, height // 2
|
||||
radius = (min(width, height) // 2 - 2) // 2
|
||||
pitch_rad = math.radians(pitch)
|
||||
yaw_rad = math.radians(yaw)
|
||||
canvas = drawille.Canvas()
|
||||
def rotate_point(x, y, z, pitch_r, yaw_r):
|
||||
radius: int = (min(width, height) // 2 - 2) // 2
|
||||
pitch_rad: float = math.radians(pitch)
|
||||
yaw_rad: float = math.radians(yaw)
|
||||
canvas: Canvas = Canvas()
|
||||
|
||||
def rotate_point(x: float, y: float, z: float, pitch_r: float, yaw_r: float) -> Tuple[int, int]:
|
||||
cos_y, sin_y = math.cos(yaw_r), math.sin(yaw_r)
|
||||
cos_p, sin_p = math.cos(pitch_r), math.sin(pitch_r)
|
||||
x1 = x * cos_y - z * sin_y
|
||||
z1 = x * sin_y + z * cos_y
|
||||
y1 = y * cos_p - z1 * sin_p
|
||||
z2 = y * sin_p + z1 * cos_p
|
||||
scale = 1 + (z2 / width)
|
||||
x1: float = x * cos_y - z * sin_y
|
||||
z1: float = x * sin_y + z * cos_y
|
||||
y1: float = y * cos_p - z1 * sin_p
|
||||
z2: float = y * sin_p + z1 * cos_p
|
||||
scale: float = 1 + (z2 / width)
|
||||
return int(center_x + x1 * scale), int(center_y + y1 * scale)
|
||||
for angle in range(0, 360, 2):
|
||||
rad = math.radians(angle)
|
||||
x = radius * math.cos(rad)
|
||||
y = radius * math.sin(rad)
|
||||
rad: float = math.radians(angle)
|
||||
x: float = radius * math.cos(rad)
|
||||
y: float = radius * math.sin(rad)
|
||||
x1, y1 = rotate_point(x, y, 0, pitch_rad, yaw_rad)
|
||||
canvas.set(x1, y1)
|
||||
for eye in [(-radius//2, -radius//3, 2), (radius//2, -radius//3, 2)]:
|
||||
@@ -129,14 +110,14 @@ class HeadOrientation:
|
||||
for dx in [-1, 0, 1]:
|
||||
for dy in [-1, 0, 1]:
|
||||
canvas.set(nx + dx, ny + dy)
|
||||
smile_depth = radius // 8
|
||||
mouth_local_y = radius // 4
|
||||
mouth_length = radius
|
||||
smile_depth: int = radius // 8
|
||||
mouth_local_y: int = radius // 4
|
||||
mouth_length: int = radius
|
||||
for x_offset in range(-mouth_length // 2, mouth_length // 2 + 1):
|
||||
norm = abs(x_offset) / (mouth_length / 2)
|
||||
y_offset = int((1 - norm ** 2) * smile_depth)
|
||||
local_x = x_offset
|
||||
local_y = mouth_local_y + y_offset
|
||||
norm: float = abs(x_offset) / (mouth_length / 2)
|
||||
y_offset: int = int((1 - norm ** 2) * smile_depth)
|
||||
local_x: int = x_offset
|
||||
local_y: int = mouth_local_y + y_offset
|
||||
mx, my = rotate_point(local_x, local_y, 0, pitch_rad, yaw_rad)
|
||||
canvas.set(mx, my)
|
||||
return canvas.frame()
|
||||
|
||||
@@ -1,61 +1,41 @@
|
||||
import struct
|
||||
import bluetooth
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
import os
|
||||
import asciichartpy as acp
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from bluetooth import BluetoothSocket
|
||||
from colors import *
|
||||
from connection_manager import ConnectionManager
|
||||
from datetime import datetime as DateTime
|
||||
from drawille import Canvas
|
||||
from head_orientation import HeadOrientation
|
||||
from logging import Logger, StreamHandler
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from matplotlib.legend import Legend
|
||||
from matplotlib.pyplot import Axes, Figure
|
||||
from numpy.typing import NDArray
|
||||
from rich.live import Live
|
||||
from rich.layout import Layout
|
||||
from rich.panel import Panel
|
||||
from rich.console import Console
|
||||
import drawille
|
||||
from head_orientation import HeadOrientation
|
||||
import logging
|
||||
from connection_manager import ConnectionManager
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Dict, List, Optional, TextIO, Tuple, Union
|
||||
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
BG_BLACK = "\033[40m"
|
||||
|
||||
class ColorFormatter(logging.Formatter):
|
||||
FORMATS = {
|
||||
logging.DEBUG: Colors.BLUE + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.INFO: Colors.GREEN + "%(message)s" + Colors.RESET,
|
||||
logging.WARNING: Colors.YELLOW + "%(message)s" + Colors.RESET,
|
||||
logging.ERROR: Colors.RED + "[%(levelname)s] %(message)s" + Colors.RESET,
|
||||
logging.CRITICAL: Colors.RED + Colors.BOLD + "[%(levelname)s] %(message)s" + Colors.RESET
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler: StreamHandler = StreamHandler()
|
||||
handler.setFormatter(ColorFormatter())
|
||||
logger = logging.getLogger("airpods-head-tracking")
|
||||
logger: Logger = logging.getLogger("airpods-head-tracking")
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = True
|
||||
|
||||
INIT_CMD = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
NOTIF_CMD = "04 00 04 00 0F 00 FF FF FE FF"
|
||||
START_CMD = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
INIT_CMD: str = "00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00"
|
||||
NOTIF_CMD: str = "04 00 04 00 0F 00 FF FF FE FF"
|
||||
START_CMD: str = "04 00 04 00 17 00 00 00 10 00 10 00 08 A1 02 42 0B 08 0E 10 02 1A 05 01 40 9C 00 00"
|
||||
STOP_CMD: str = "04 00 04 00 17 00 00 00 10 00 11 00 08 7E 10 02 42 0B 08 4E 10 02 1A 05 01 00 00 00 00"
|
||||
|
||||
KEY_FIELDS = {
|
||||
KEY_FIELDS: Dict[str, Tuple[int, int]] = {
|
||||
"orientation 1": (43, 2),
|
||||
"orientation 2": (45, 2),
|
||||
"orientation 3": (47, 2),
|
||||
@@ -68,28 +48,28 @@ KEY_FIELDS = {
|
||||
}
|
||||
|
||||
class AirPodsTracker:
|
||||
def __init__(self):
|
||||
self.sock = None
|
||||
self.recording = False
|
||||
self.log_file = None
|
||||
self.listener_thread = None
|
||||
self.bt_addr = "28:2D:7F:C2:05:5B"
|
||||
self.psm = 0x1001
|
||||
self.raw_packets = []
|
||||
self.parsed_packets = []
|
||||
self.live_data = []
|
||||
self.live_plotting = False
|
||||
self.animation = None
|
||||
self.fig = None
|
||||
self.axes = None
|
||||
self.lines = {}
|
||||
self.selected_fields = []
|
||||
self.data_lock = threading.Lock()
|
||||
self.orientation_offset = 5500
|
||||
self.use_terminal = True # '--terminal' in sys.argv
|
||||
self.orientation_visualizer = HeadOrientation(use_terminal=self.use_terminal)
|
||||
def __init__(self) -> None:
|
||||
self.sock: BluetoothSocket = None
|
||||
self.recording: bool = False
|
||||
self.log_file: Optional[TextIO] = None
|
||||
self.listener_thread: Optional[Thread] = None
|
||||
self.bt_addr: str = "28:2D:7F:C2:05:5B"
|
||||
self.psm: int = 0x1001
|
||||
self.raw_packets: List[bytes] = []
|
||||
self.parsed_packets: List[bytes] = []
|
||||
self.live_data: List[bytes] = []
|
||||
self.live_plotting: bool = False
|
||||
self.animation: FuncAnimation = None
|
||||
self.fig: Optional[Figure] = None
|
||||
self.axes: Optional[Axes] = None
|
||||
self.lines: Dict[str, Any] = {}
|
||||
self.selected_fields: List[str] = []
|
||||
self.data_lock: Lock = Lock()
|
||||
self.orientation_offset: int = 5500
|
||||
self.use_terminal: bool = True # '--terminal' in sys.argv
|
||||
self.orientation_visualizer: HeadOrientation = HeadOrientation(use_terminal=self.use_terminal)
|
||||
|
||||
self.conn = None
|
||||
self.conn: Optional[ConnectionManager] = None
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
@@ -102,35 +82,35 @@ class AirPodsTracker:
|
||||
self.sock.send(bytes.fromhex(NOTIF_CMD))
|
||||
logger.info("Sent initialization command.")
|
||||
|
||||
self.listener_thread = threading.Thread(target=self.listen, daemon=True)
|
||||
self.listener_thread = Thread(target=self.listen, daemon=True)
|
||||
self.listener_thread.start()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Connection error: %s", e)
|
||||
return False
|
||||
|
||||
def start_tracking(self, duration=None):
|
||||
def start_tracking(self, duration: Optional[float] = None) -> None:
|
||||
if not self.recording:
|
||||
self.conn.send_start()
|
||||
filename = "head_tracking_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".log"
|
||||
filename: str = f"head_tracking_{DateTime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
||||
self.log_file = open(filename, "w")
|
||||
self.recording = True
|
||||
logger.info("Recording started. Saving data to %s", filename)
|
||||
|
||||
if duration is not None and duration > 0:
|
||||
def auto_stop():
|
||||
def auto_stop() -> None:
|
||||
time.sleep(duration)
|
||||
if self.recording:
|
||||
self.stop_tracking()
|
||||
logger.info("Recording automatically stopped after %s seconds.", duration)
|
||||
|
||||
timer_thread = threading.Thread(target=auto_stop, daemon=True)
|
||||
timer_thread = Thread(target=auto_stop, daemon=True)
|
||||
timer_thread.start()
|
||||
logger.info("Will automatically stop recording after %s seconds.", duration)
|
||||
else:
|
||||
logger.info("Already recording.")
|
||||
|
||||
def stop_tracking(self):
|
||||
def stop_tracking(self) -> None:
|
||||
if self.recording:
|
||||
self.conn.send_stop()
|
||||
self.recording = False
|
||||
@@ -141,39 +121,41 @@ class AirPodsTracker:
|
||||
else:
|
||||
logger.info("Not currently recording.")
|
||||
|
||||
def format_hex(self, data):
|
||||
hex_str = data.hex()
|
||||
def format_hex(self, data: bytes) -> str:
|
||||
hex_str: str = data.hex()
|
||||
return ' '.join(hex_str[i:i + 2] for i in range(0, len(hex_str), 2))
|
||||
|
||||
def parse_raw_packet(self, hex_string):
|
||||
def parse_raw_packet(self, hex_string: str) -> bytes:
|
||||
return bytes.fromhex(hex_string.replace(" ", ""))
|
||||
|
||||
def interpret_bytes(self, raw_bytes, start, length, data_type="signed_short"):
|
||||
def interpret_bytes(self, raw_bytes: bytes, start: int, length: int, data_type: str = "signed_short") -> Optional[Union[int, float]]:
|
||||
if start + length > len(raw_bytes):
|
||||
return None
|
||||
|
||||
if data_type == "signed_short":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True)
|
||||
elif data_type == "unsigned_short":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False)
|
||||
elif data_type == "signed_short_be":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True)
|
||||
elif data_type == "float_le":
|
||||
if start + 4 <= len(raw_bytes):
|
||||
return struct.unpack('<f', raw_bytes[start:start + 4])[0]
|
||||
elif data_type == "float_be":
|
||||
if start + 4 <= len(raw_bytes):
|
||||
return struct.unpack('>f', raw_bytes[start:start + 4])[0]
|
||||
return None
|
||||
match data_type:
|
||||
case "signed_short":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=True)
|
||||
case "unsigned_short":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='little', signed=False)
|
||||
case "signed_short_be":
|
||||
return int.from_bytes(raw_bytes[start:start + 2], byteorder='big', signed=True)
|
||||
case "float_le":
|
||||
if start + 4 <= len(raw_bytes):
|
||||
return struct.unpack('<f', raw_bytes[start:start + 4])[0]
|
||||
case "float_be":
|
||||
if start + 4 <= len(raw_bytes):
|
||||
return struct.unpack('>f', raw_bytes[start:start + 4])[0]
|
||||
case _:
|
||||
return None
|
||||
|
||||
def normalize_orientation(self, value, field_name):
|
||||
def normalize_orientation(self, value: Optional[Union[int, float]], field_name: str) -> Optional[Union[int, float]]:
|
||||
if 'orientation' in field_name.lower():
|
||||
return value + self.orientation_offset
|
||||
|
||||
return value
|
||||
|
||||
def parse_packet_all_fields(self, raw_bytes):
|
||||
packet = {}
|
||||
def parse_packet_all_fields(self, raw_bytes: bytes) -> Dict[str, Union[int, float]]:
|
||||
packet: Dict[str, Union[int, float]] = {}
|
||||
|
||||
packet["seq_num"] = int.from_bytes(raw_bytes[12:14], byteorder='little')
|
||||
|
||||
@@ -186,14 +168,14 @@ class AirPodsTracker:
|
||||
packet[field_name] = self.normalize_orientation(raw_value, field_name)
|
||||
|
||||
for i in range(30, min(90, len(raw_bytes) - 1), 2):
|
||||
field_name = f"byte_{i:02d}"
|
||||
raw_value = self.interpret_bytes(raw_bytes, i, 2, "signed_short")
|
||||
field_name: str = f"byte_{i:02d}"
|
||||
raw_value: Optional[Union[int, float]] = self.interpret_bytes(raw_bytes, i, 2, "signed_short")
|
||||
if raw_value is not None:
|
||||
packet[field_name] = self.normalize_orientation(raw_value, field_name)
|
||||
|
||||
return packet
|
||||
|
||||
def apply_dark_theme(self, fig, axes):
|
||||
def apply_dark_theme(self, fig: Figure, axes: List[Axes]) -> None:
|
||||
fig.patch.set_facecolor('#1e1e1e')
|
||||
for ax in axes:
|
||||
ax.set_facecolor('#2d2d2d')
|
||||
@@ -210,21 +192,21 @@ class AirPodsTracker:
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color('#555555')
|
||||
|
||||
legend = ax.get_legend()
|
||||
legend: Optional[Legend] = ax.get_legend()
|
||||
if (legend):
|
||||
legend.get_frame().set_facecolor('#2d2d2d')
|
||||
legend.get_frame().set_alpha(0.7)
|
||||
for text in legend.get_texts():
|
||||
text.set_color('white')
|
||||
|
||||
def listen(self):
|
||||
def listen(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
data = self.sock.recv(1024)
|
||||
formatted = self.format_hex(data)
|
||||
timestamp = datetime.now().isoformat()
|
||||
data: bytes = self.sock.recv(1024)
|
||||
formatted: str = self.format_hex(data)
|
||||
timestamp: str = DateTime.now().isoformat()
|
||||
|
||||
is_valid = self.is_valid_tracking_packet(formatted)
|
||||
is_valid: bool = self.is_valid_tracking_packet(formatted)
|
||||
|
||||
if not self.live_plotting:
|
||||
if is_valid:
|
||||
@@ -238,8 +220,8 @@ class AirPodsTracker:
|
||||
self.log_file.flush()
|
||||
|
||||
try:
|
||||
raw_bytes = self.parse_raw_packet(formatted)
|
||||
packet = self.parse_packet_all_fields(raw_bytes)
|
||||
raw_bytes: bytes = self.parse_raw_packet(formatted)
|
||||
packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes)
|
||||
|
||||
with self.data_lock:
|
||||
self.live_data.append(packet)
|
||||
@@ -253,7 +235,7 @@ class AirPodsTracker:
|
||||
logger.error("Error receiving data: %s", e)
|
||||
break
|
||||
|
||||
def load_log_file(self, filepath):
|
||||
def load_log_file(self, filepath: str) -> bool:
|
||||
self.raw_packets = []
|
||||
self.parsed_packets = []
|
||||
try:
|
||||
@@ -262,11 +244,11 @@ class AirPodsTracker:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
raw_bytes = self.parse_raw_packet(line)
|
||||
raw_bytes: bytes = self.parse_raw_packet(line)
|
||||
self.raw_packets.append(raw_bytes)
|
||||
packet = self.parse_packet_all_fields(raw_bytes)
|
||||
packet: Dict[str, Union[int, float]] = self.parse_packet_all_fields(raw_bytes)
|
||||
|
||||
min_seq_num = min(
|
||||
min_seq_num: int = min(
|
||||
[parsed_packet["seq_num"] for parsed_packet in self.parsed_packets], default=0
|
||||
)
|
||||
|
||||
@@ -282,26 +264,26 @@ class AirPodsTracker:
|
||||
logger.error(f"Error loading log file: {e}")
|
||||
return False
|
||||
|
||||
def extract_field_values(self, field_name, data_source='loaded'):
|
||||
def extract_field_values(self, field_name: str, data_source: str = 'loaded') -> List[Union[int, float]]:
|
||||
if data_source == 'loaded':
|
||||
data = self.parsed_packets
|
||||
data: List[Dict[str, Union[int, float]]] = self.parsed_packets
|
||||
else:
|
||||
with self.data_lock:
|
||||
data = self.live_data.copy()
|
||||
data: List[Dict[str, Union[int, float]]] = self.live_data.copy()
|
||||
|
||||
values = [packet.get(field_name, 0) for packet in data if field_name in packet]
|
||||
values: List[Union[int, float]] = [packet.get(field_name, 0) for packet in data if field_name in packet]
|
||||
|
||||
if data_source == 'live' and len(values) > 5:
|
||||
try:
|
||||
values = np.array(values, dtype=float)
|
||||
values: NDArray[Any] = np.array(values, dtype=float)
|
||||
values = np.convolve(values, np.ones(5) / 5, mode='valid')
|
||||
except Exception as e:
|
||||
logger.warning(f"Smoothing error (non-critical): {e}")
|
||||
|
||||
return values
|
||||
|
||||
def is_valid_tracking_packet(self, hex_string):
|
||||
standard_header = "04 00 04 00 17 00 00 00 10 00"
|
||||
def is_valid_tracking_packet(self, hex_string: str) -> bool:
|
||||
standard_header: str = "04 00 04 00 17 00 00 00 10 00"
|
||||
|
||||
if not hex_string.startswith(standard_header):
|
||||
if self.live_plotting:
|
||||
@@ -316,13 +298,13 @@ class AirPodsTracker:
|
||||
return True
|
||||
|
||||
|
||||
def plot_fields(self, field_names=None):
|
||||
def plot_fields(self, field_names: Optional[List[str]] = None) -> None:
|
||||
if not self.parsed_packets:
|
||||
logger.error("No data to plot. Load a log file first.")
|
||||
return
|
||||
|
||||
if field_names is None:
|
||||
field_names = list(KEY_FIELDS.keys())
|
||||
field_names: List[str] = list(KEY_FIELDS.keys())
|
||||
|
||||
if not self.orientation_visualizer.calibration_complete:
|
||||
if len(self.parsed_packets) < self.orientation_visualizer.calibration_sample_count:
|
||||
@@ -339,16 +321,16 @@ class AirPodsTracker:
|
||||
self._plot_fields_terminal(field_names)
|
||||
|
||||
else:
|
||||
acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()]
|
||||
orientation_fields = [f for f in field_names if 'orientation' in f.lower()]
|
||||
other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields]
|
||||
acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()]
|
||||
orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()]
|
||||
other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields]
|
||||
|
||||
fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True)
|
||||
self.apply_dark_theme(fig, axes)
|
||||
|
||||
acceleration_colors = ['#FFFF00', '#00FFFF']
|
||||
orientation_colors = ['#FF00FF', '#00FF00', '#FFA500']
|
||||
other_colors = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f']
|
||||
acceleration_colors: List[str] = ['#FFFF00', '#00FFFF']
|
||||
orientation_colors: List[str] = ['#FF00FF', '#00FF00', '#FFA500']
|
||||
other_colors: List[str] = ['#52b788', '#f4a261', '#e76f51', '#2a9d8f']
|
||||
|
||||
if acceleration_fields:
|
||||
for i, field in enumerate(acceleration_fields):
|
||||
@@ -375,17 +357,17 @@ class AirPodsTracker:
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
def _plot_fields_terminal(self, field_names):
|
||||
def _plot_fields_terminal(self, field_names: List[str]) -> None:
|
||||
"""Internal method for terminal-based plotting"""
|
||||
terminal_width = os.get_terminal_size().columns
|
||||
plot_width = min(terminal_width - 10, 120)
|
||||
plot_height = 15
|
||||
terminal_width: int = os.get_terminal_size().columns
|
||||
plot_width: int = min(terminal_width - 10, 120)
|
||||
plot_height: int = 15
|
||||
|
||||
acceleration_fields = [f for f in field_names if 'acceleration' in f.lower()]
|
||||
orientation_fields = [f for f in field_names if 'orientation' in f.lower()]
|
||||
other_fields = [f for f in field_names if f not in acceleration_fields + orientation_fields]
|
||||
acceleration_fields: List[str] = [f for f in field_names if 'acceleration' in f.lower()]
|
||||
orientation_fields: List[str] = [f for f in field_names if 'orientation' in f.lower()]
|
||||
other_fields: List[str] = [f for f in field_names if f not in acceleration_fields + orientation_fields]
|
||||
|
||||
def plot_group(fields, title):
|
||||
def plot_group(fields: List[str], title: str) -> None:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
@@ -393,40 +375,39 @@ class AirPodsTracker:
|
||||
print("=" * len(title))
|
||||
|
||||
for field in fields:
|
||||
values = self.extract_field_values(field)
|
||||
values: List[float] = self.extract_field_values(field)
|
||||
if len(values) > plot_width:
|
||||
values = values[-plot_width:]
|
||||
|
||||
if title == "Acceleration Data":
|
||||
chart = acp.plot(values, {'height': plot_height})
|
||||
chart: str = acp.plot(values, {'height': plot_height})
|
||||
print(chart)
|
||||
else:
|
||||
chart = acp.plot(values, {'height': plot_height})
|
||||
chart: str = acp.plot(values, {'height': plot_height})
|
||||
print(chart)
|
||||
|
||||
print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " +
|
||||
f"Mean: {np.mean(values):.2f}")
|
||||
print(f"Min: {min(values):.2f}, Max: {max(values):.2f}, " + f"Mean: {np.mean(values):.2f}")
|
||||
print()
|
||||
|
||||
plot_group(acceleration_fields, "Acceleration Data")
|
||||
plot_group(orientation_fields, "Orientation Data")
|
||||
plot_group(other_fields, "Other Fields")
|
||||
|
||||
def create_braille_plot(self, values, width=80, height=20, y_label=True, fixed_y_min=None, fixed_y_max=None):
|
||||
canvas = drawille.Canvas()
|
||||
def create_braille_plot(self, values: List[float], width: int = 80, height: int = 20, y_label: bool = True, fixed_y_min: Optional[float] = None, fixed_y_max: Optional[float] = None) -> str:
|
||||
canvas: Canvas = Canvas()
|
||||
if fixed_y_min is None or fixed_y_max is None:
|
||||
local_min, local_max = min(values), max(values)
|
||||
else:
|
||||
local_min, local_max = fixed_y_min, fixed_y_max
|
||||
y_range = local_max - local_min or 1
|
||||
x_step = max(1, len(values) // width)
|
||||
y_range: float = local_max - local_min or 1
|
||||
x_step: int = max(1, len(values) // width)
|
||||
for i, v in enumerate(values[::x_step]):
|
||||
y = int(((v - local_min) / y_range) * (height * 2 - 1))
|
||||
y: int = int(((v - local_min) / y_range) * (height * 2 - 1))
|
||||
canvas.set(i, y)
|
||||
frame = canvas.frame()
|
||||
frame: str = canvas.frame()
|
||||
if y_label:
|
||||
lines = frame.split('\n')
|
||||
labeled_lines = []
|
||||
lines: List[str] = frame.split('\n')
|
||||
labeled_lines: List[str] = []
|
||||
for idx, line in enumerate(lines):
|
||||
if idx == 0:
|
||||
labeled_lines.append(f"{local_max:6.0f} {line}")
|
||||
@@ -437,17 +418,17 @@ class AirPodsTracker:
|
||||
frame = "\n".join(labeled_lines)
|
||||
return frame
|
||||
|
||||
def _start_live_plotting_terminal(self, record_data=False, duration=None):
|
||||
def _start_live_plotting_terminal(self, record_data: bool = False, duration: Optional[float] = None) -> None:
|
||||
import sys, select, tty, termios
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
tty.setcbreak(sys.stdin.fileno())
|
||||
console = Console()
|
||||
term_width = console.width
|
||||
plot_width = round(min(term_width / 2 - 15, 120))
|
||||
ori_height = 10
|
||||
console: Console = Console()
|
||||
term_width: int = console.width
|
||||
plot_width: int = round(min(term_width / 2 - 15, 120))
|
||||
ori_height: int = 10
|
||||
|
||||
def make_compact_layout():
|
||||
layout = Layout()
|
||||
def make_compact_layout() -> Layout:
|
||||
layout: Layout = Layout()
|
||||
layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="main", ratio=1),
|
||||
@@ -466,7 +447,7 @@ class AirPodsTracker:
|
||||
)
|
||||
return layout
|
||||
|
||||
layout = make_compact_layout()
|
||||
layout: Layout = make_compact_layout()
|
||||
|
||||
try:
|
||||
import time
|
||||
@@ -479,76 +460,76 @@ class AirPodsTracker:
|
||||
logger.info("Paused" if self.paused else "Resumed")
|
||||
if self.paused:
|
||||
time.sleep(0.1)
|
||||
rec_str = " [red][REC][/red]" if record_data else ""
|
||||
left = "AirPods Head Tracking - v1.0.0"
|
||||
right = "Ctrl+C - Close | p - Pause" + rec_str
|
||||
status = "[bold red]Paused[/bold red]"
|
||||
header = list(" " * term_width)
|
||||
rec_str: str = " [red][REC][/red]" if record_data else ""
|
||||
left: str = "AirPods Head Tracking - v1.0.0"
|
||||
right: str = "Ctrl+C - Close | p - Pause" + rec_str
|
||||
status: str = "[bold red]Paused[/bold red]"
|
||||
header: List[str] = list(" " * term_width)
|
||||
header[0:len(left)] = list(left)
|
||||
header[term_width - len(right):] = list(right)
|
||||
start = (term_width - len(status)) // 2
|
||||
start: int = (term_width - len(status)) // 2
|
||||
header[start:start+len(status)] = list(status)
|
||||
header_text = "".join(header)
|
||||
header_text: str = "".join(header)
|
||||
layout["header"].update(Panel(header_text, style="bold white on black"))
|
||||
continue
|
||||
|
||||
with self.data_lock:
|
||||
if len(self.live_data) < 1:
|
||||
continue
|
||||
latest = self.live_data[-1]
|
||||
data = self.live_data[-plot_width:]
|
||||
latest: Dict[str, float] = self.live_data[-1]
|
||||
data: List[Dict[str, float]] = self.live_data[-plot_width:]
|
||||
|
||||
if not self.orientation_visualizer.calibration_complete:
|
||||
sample = [
|
||||
sample: List[float] = [
|
||||
latest.get('orientation 1', 0),
|
||||
latest.get('orientation 2', 0),
|
||||
latest.get('orientation 3', 0)
|
||||
]
|
||||
self.orientation_visualizer.add_calibration_sample(sample)
|
||||
time.sleep(0.05)
|
||||
rec_str = " [red][REC][/red]" if record_data else ""
|
||||
rec_str: str = " [red][REC][/red]" if record_data else ""
|
||||
|
||||
left = "AirPods Head Tracking - v1.0.0"
|
||||
status = "[bold yellow]Calibrating...[/bold yellow]"
|
||||
right = "Ctrl+C - Close | p - Pause"
|
||||
remaining = max(term_width - len(left) - len(right), 0)
|
||||
header_text = f"{left}{status.center(remaining)}{right}{rec_str}"
|
||||
left: str = "AirPods Head Tracking - v1.0.0"
|
||||
status: str = "[bold yellow]Calibrating...[/bold yellow]"
|
||||
right: str = "Ctrl+C - Close | p - Pause"
|
||||
remaining: int = max(term_width - len(left) - len(right), 0)
|
||||
header_text: str = f"{left}{status.center(remaining)}{right}{rec_str}"
|
||||
layout["header"].update(Panel(header_text, style="bold white on black"))
|
||||
live.refresh()
|
||||
continue
|
||||
|
||||
o1 = latest.get('orientation 1', 0)
|
||||
o2 = latest.get('orientation 2', 0)
|
||||
o3 = latest.get('orientation 3', 0)
|
||||
orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
|
||||
pitch = orientation['pitch']
|
||||
yaw = orientation['yaw']
|
||||
o1: float = latest.get('orientation 1', 0)
|
||||
o2: float = latest.get('orientation 2', 0)
|
||||
o3: float = latest.get('orientation 3', 0)
|
||||
orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
|
||||
pitch: float = orientation['pitch']
|
||||
yaw: float = orientation['yaw']
|
||||
|
||||
h_accel = [p.get('Horizontal Acceleration', 0) for p in data]
|
||||
v_accel = [p.get('Vertical Acceleration', 0) for p in data]
|
||||
h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data]
|
||||
v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data]
|
||||
if len(h_accel) > plot_width:
|
||||
h_accel = h_accel[-plot_width:]
|
||||
if len(v_accel) > plot_width:
|
||||
v_accel = v_accel[-plot_width:]
|
||||
global_min = min(min(v_accel), min(h_accel))
|
||||
global_max = max(max(v_accel), max(h_accel))
|
||||
config_acc = {'height': 20, 'min': global_min, 'max': global_max}
|
||||
vert_plot = acp.plot(v_accel, config_acc)
|
||||
horiz_plot = acp.plot(h_accel, config_acc)
|
||||
global_min: float = min(min(v_accel), min(h_accel))
|
||||
global_max: float = max(max(v_accel), max(h_accel))
|
||||
config_acc: Dict[str, float] = {'height': 20, 'min': global_min, 'max': global_max}
|
||||
vert_plot: str = acp.plot(v_accel, config_acc)
|
||||
horiz_plot: str = acp.plot(h_accel, config_acc)
|
||||
|
||||
rec_str = " [red][REC][/red]" if record_data else ""
|
||||
left = "AirPods Head Tracking - v1.0.0"
|
||||
right = "Ctrl+C - Close | p - Pause" + rec_str
|
||||
status = "[bold green]Live[/bold green]"
|
||||
header = list(" " * term_width)
|
||||
rec_str: str = " [red][REC][/red]" if record_data else ""
|
||||
left: str = "AirPods Head Tracking - v1.0.0"
|
||||
right: str = "Ctrl+C - Close | p - Pause" + rec_str
|
||||
status: str = "[bold green]Live[/bold green]"
|
||||
header: List[str] = list(" " * term_width)
|
||||
header[0:len(left)] = list(left)
|
||||
header[term_width - len(right):] = list(right)
|
||||
start = (term_width - len(status)) // 2
|
||||
start: int = (term_width - len(status)) // 2
|
||||
header[start:start+len(status)] = list(status)
|
||||
header_text = "".join(header)
|
||||
header_text: str = "".join(header)
|
||||
layout["header"].update(Panel(header_text, style="bold white on black"))
|
||||
|
||||
face_art = self.orientation_visualizer.create_face_art(pitch, yaw)
|
||||
face_art: str = self.orientation_visualizer.create_face_art(pitch, yaw)
|
||||
layout["accelerations"]["vertical"].update(Panel(
|
||||
"[bold yellow]Vertical Acceleration[/]\n" +
|
||||
vert_plot + "\n" +
|
||||
@@ -563,15 +544,15 @@ class AirPodsTracker:
|
||||
))
|
||||
layout["orientations"]["face"].update(Panel(face_art, title="[green]Orientation - Visualization[/]", style="green"))
|
||||
|
||||
o2_values = [p.get('orientation 2', 0) for p in data[-plot_width:]]
|
||||
o3_values = [p.get('orientation 3', 0) for p in data[-plot_width:]]
|
||||
o2_values = o2_values[:plot_width]
|
||||
o3_values = o3_values[:plot_width]
|
||||
common_min = min(min(o2_values), min(o3_values))
|
||||
common_max = max(max(o2_values), max(o3_values))
|
||||
config_ori = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"}
|
||||
chart_o2 = acp.plot(o2_values, config_ori)
|
||||
chart_o3 = acp.plot(o3_values, config_ori)
|
||||
o2_values: List[float] = [p.get('orientation 2', 0) for p in data[-plot_width:]]
|
||||
o3_values: List[float] = [p.get('orientation 3', 0) for p in data[-plot_width:]]
|
||||
o2_values: List[float] = o2_values[:plot_width]
|
||||
o3_values: List[float] = o3_values[:plot_width]
|
||||
common_min: float = min(min(o2_values), min(o3_values))
|
||||
common_max: float = max(max(o2_values), max(o3_values))
|
||||
config_ori: Dict[str, float] = {'height': ori_height, 'min': common_min, 'max': common_max, 'format': "{:6.0f}"}
|
||||
chart_o2: str = acp.plot(o2_values, config_ori)
|
||||
chart_o3: str = acp.plot(o3_values, config_ori)
|
||||
layout["orientations"]["raw"].update(Panel(
|
||||
"[bold yellow]Orientation 1:[/]\n" + chart_o2 + "\n" +
|
||||
f"Cur: {o2_values[-1]:6.1f} | Min: {min(o2_values):6.1f} | Max: {max(o2_values):6.1f}\n\n" +
|
||||
@@ -591,10 +572,10 @@ class AirPodsTracker:
|
||||
finally:
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
|
||||
def _start_live_plotting(self, record_data=False, duration=None):
|
||||
terminal_width = os.get_terminal_size().columns
|
||||
plot_width = min(terminal_width - 10, 80)
|
||||
plot_height = 10
|
||||
def _start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None:
|
||||
terminal_width: int = os.get_terminal_size().columns
|
||||
plot_width: int = min(terminal_width - 10, 80)
|
||||
plot_height: int = 10
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -605,13 +586,13 @@ class AirPodsTracker:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
data = self.live_data[-plot_width:]
|
||||
data: List[Dict[str, float]] = self.live_data[-plot_width:]
|
||||
|
||||
acceleration_fields = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()]
|
||||
orientation_fields = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()]
|
||||
other_fields = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields]
|
||||
acceleration_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'acceleration' in f.lower()]
|
||||
orientation_fields: List[str] = [f for f in KEY_FIELDS.keys() if 'orientation' in f.lower()]
|
||||
other_fields: List[str] = [f for f in KEY_FIELDS.keys() if f not in acceleration_fields + orientation_fields]
|
||||
|
||||
def plot_group(fields, title):
|
||||
def plot_group(fields: List[str], title: str) -> None:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
@@ -619,9 +600,9 @@ class AirPodsTracker:
|
||||
print("=" * len(title))
|
||||
|
||||
for field in fields:
|
||||
values = [packet.get(field, 0) for packet in data if field in packet]
|
||||
values: List[float] = [packet.get(field, 0) for packet in data if field in packet]
|
||||
if len(values) > 0:
|
||||
chart = acp.plot(values, {'height': plot_height})
|
||||
chart: str = acp.plot(values, {'height': plot_height})
|
||||
print(chart)
|
||||
print(f"Current: {values[-1]:.2f}, " +
|
||||
f"Min: {min(values):.2f}, Max: {max(values):.2f}")
|
||||
@@ -641,7 +622,7 @@ class AirPodsTracker:
|
||||
self.stop_tracking()
|
||||
self.live_plotting = False
|
||||
|
||||
def start_live_plotting(self, record_data=False, duration=None):
|
||||
def start_live_plotting(self, record_data: bool = False, duration: Optional[float] = None) -> None:
|
||||
if self.sock is None:
|
||||
if not self.connect():
|
||||
logger.error("Could not connect to AirPods. Live plotting aborted.")
|
||||
@@ -660,12 +641,12 @@ class AirPodsTracker:
|
||||
self._start_live_plotting_terminal(record_data, duration)
|
||||
else:
|
||||
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
|
||||
fig = plt.figure(figsize=(14, 6))
|
||||
gs = GridSpec(1, 2, width_ratios=[1, 1])
|
||||
ax_accel = fig.add_subplot(gs[0])
|
||||
subgs = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1])
|
||||
ax_head_top = fig.add_subplot(subgs[0], projection='3d')
|
||||
ax_ori = fig.add_subplot(subgs[1])
|
||||
fig: Figure = plt.figure(figsize=(14, 6))
|
||||
gs: GridSpec = GridSpec(1, 2, width_ratios=[1, 1])
|
||||
ax_accel: Axes = fig.add_subplot(gs[0])
|
||||
subgs: GridSpecFromSubplotSpec = GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[1], height_ratios=[2, 1])
|
||||
ax_head_top: Axes = fig.add_subplot(subgs[0], projection='3d')
|
||||
ax_ori: Axes = fig.add_subplot(subgs[1])
|
||||
|
||||
ax_accel.set_title("Acceleration Data")
|
||||
ax_accel.set_xlabel("Packet Index")
|
||||
@@ -676,16 +657,16 @@ class AirPodsTracker:
|
||||
self.apply_dark_theme(fig, [ax_accel, ax_head_top, ax_ori])
|
||||
plt.ion()
|
||||
|
||||
def update_plot(_):
|
||||
def update_plot(_: int) -> None:
|
||||
with self.data_lock:
|
||||
data = self.live_data.copy()
|
||||
data: List[Dict[str, float]] = self.live_data.copy()
|
||||
if len(data) == 0:
|
||||
return
|
||||
|
||||
latest = data[-1]
|
||||
latest: Dict[str, float] = data[-1]
|
||||
|
||||
if not self.orientation_visualizer.calibration_complete:
|
||||
sample = [
|
||||
sample: List[float] = [
|
||||
latest.get('orientation 1', 0),
|
||||
latest.get('orientation 2', 0),
|
||||
latest.get('orientation 3', 0)
|
||||
@@ -696,9 +677,9 @@ class AirPodsTracker:
|
||||
fig.canvas.draw_idle()
|
||||
return
|
||||
|
||||
h_accel = [p.get('Horizontal Acceleration', 0) for p in data]
|
||||
v_accel = [p.get('Vertical Acceleration', 0) for p in data]
|
||||
x_vals = list(range(len(h_accel)))
|
||||
h_accel: List[float] = [p.get('Horizontal Acceleration', 0) for p in data]
|
||||
v_accel: List[float] = [p.get('Vertical Acceleration', 0) for p in data]
|
||||
x_vals: List[int] = list(range(len(h_accel)))
|
||||
ax_accel.cla()
|
||||
ax_accel.plot(x_vals, v_accel, label='Vertical Acceleration', color='#FFFF00', linewidth=2)
|
||||
ax_accel.plot(x_vals, h_accel, label='Horizontal Acceleration', color='#00FFFF', linewidth=2)
|
||||
@@ -711,13 +692,13 @@ class AirPodsTracker:
|
||||
ax_accel.xaxis.label.set_color('white')
|
||||
ax_accel.yaxis.label.set_color('white')
|
||||
|
||||
latest = data[-1]
|
||||
o1 = latest.get('orientation 1', 0)
|
||||
o2 = latest.get('orientation 2', 0)
|
||||
o3 = latest.get('orientation 3', 0)
|
||||
orientation = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
|
||||
pitch = orientation['pitch']
|
||||
yaw = orientation['yaw']
|
||||
latest: Dict[str, float] = data[-1]
|
||||
o1: float = latest.get('orientation 1', 0)
|
||||
o2: float = latest.get('orientation 2', 0)
|
||||
o3: float = latest.get('orientation 3', 0)
|
||||
orientation: Dict[str, float] = self.orientation_visualizer.calculate_orientation(o1, o2, o3)
|
||||
pitch: float = orientation['pitch']
|
||||
yaw: float = orientation['yaw']
|
||||
|
||||
ax_head_top.cla()
|
||||
ax_head_top.set_title("Head Orientation")
|
||||
@@ -727,25 +708,25 @@ class AirPodsTracker:
|
||||
ax_head_top.set_facecolor('#2d2d2d')
|
||||
pitch_rad = np.radians(pitch)
|
||||
yaw_rad = np.radians(yaw)
|
||||
Rz = np.array([
|
||||
Rz: NDArray[Any] = np.array([
|
||||
[np.cos(yaw_rad), np.sin(yaw_rad), 0],
|
||||
[-np.sin(yaw_rad), np.cos(yaw_rad), 0],
|
||||
[0, 0, 1]
|
||||
])
|
||||
Ry = np.array([
|
||||
Ry: NDArray[Any] = np.array([
|
||||
[np.cos(pitch_rad), 0, np.sin(pitch_rad)],
|
||||
[0, 1, 0],
|
||||
[-np.sin(pitch_rad), 0, np.cos(pitch_rad)]
|
||||
])
|
||||
R = Rz @ Ry
|
||||
dir_vec = R @ np.array([1, 0, 0])
|
||||
R: NDArray[Any] = Rz @ Ry
|
||||
dir_vec: NDArray[Any] = R @ np.array([1, 0, 0])
|
||||
ax_head_top.quiver(0, 0, 0, dir_vec[0], dir_vec[1], dir_vec[2],
|
||||
color='r', length=0.8, linewidth=3)
|
||||
|
||||
ax_ori.cla()
|
||||
o2_values = [p.get('orientation 2', 0) for p in data]
|
||||
o3_values = [p.get('orientation 3', 0) for p in data]
|
||||
x_range = list(range(len(o2_values)))
|
||||
o2_values: List[float] = [p.get('orientation 2', 0) for p in data]
|
||||
o3_values: List[float] = [p.get('orientation 3', 0) for p in data]
|
||||
x_range: List[int] = list(range(len(o2_values)))
|
||||
ax_ori.plot(x_range, o2_values, label='Orientation 1', color='red', linewidth=2)
|
||||
ax_ori.plot(x_range, o3_values, label='Orientation 2', color='green', linewidth=2)
|
||||
ax_ori.set_facecolor('#2d2d2d')
|
||||
@@ -775,9 +756,9 @@ class AirPodsTracker:
|
||||
self.animation = None
|
||||
plt.ioff()
|
||||
|
||||
def interactive_mode(self):
|
||||
def interactive_mode(self) -> None:
|
||||
from prompt_toolkit import PromptSession
|
||||
session = PromptSession("> ")
|
||||
session: PromptSession = PromptSession("> ")
|
||||
logger.info("\nAirPods Head Tracking Analyzer")
|
||||
print("------------------------------")
|
||||
logger.info("Commands:")
|
||||
@@ -793,59 +774,61 @@ class AirPodsTracker:
|
||||
|
||||
while True:
|
||||
try:
|
||||
cmd_input = session.prompt("> ")
|
||||
cmd_parts = cmd_input.strip().split()
|
||||
cmd_input: str = session.prompt("> ")
|
||||
cmd_parts: List[str] = cmd_input.strip().split()
|
||||
if not cmd_parts:
|
||||
continue
|
||||
cmd = cmd_parts[0].lower()
|
||||
if cmd == "connect":
|
||||
self.connect()
|
||||
elif cmd == "start":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
self.start_tracking(duration)
|
||||
elif cmd == "stop":
|
||||
self.stop_tracking()
|
||||
elif cmd == "load" and len(cmd_parts) > 1:
|
||||
self.load_log_file(cmd_parts[1])
|
||||
elif cmd == "plot":
|
||||
self.plot_fields()
|
||||
elif cmd == "live":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
logger.info("Starting live plotting mode (without recording)%s.",
|
||||
f" for {duration} seconds" if duration else "")
|
||||
self.start_live_plotting(record_data=False, duration=duration)
|
||||
elif cmd == "liver":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
logger.info("Starting live plotting mode WITH recording%s.",
|
||||
f" for {duration} seconds" if duration else "")
|
||||
self.start_live_plotting(record_data=True, duration=duration)
|
||||
elif cmd == "gestures":
|
||||
from gestures import GestureDetector
|
||||
if self.conn is not None:
|
||||
detector = GestureDetector(conn=self.conn)
|
||||
else:
|
||||
detector = GestureDetector()
|
||||
detector.start_detection()
|
||||
elif cmd == "quit":
|
||||
logger.info("Exiting.")
|
||||
if self.conn != None:
|
||||
self.conn.disconnect()
|
||||
break
|
||||
elif cmd == "help":
|
||||
logger.info("\nAirPods Head Tracking Analyzer")
|
||||
logger.info("------------------------------")
|
||||
logger.info("Commands:")
|
||||
logger.info(" connect - connect to your AirPods")
|
||||
logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration")
|
||||
logger.info(" stop - stop recording")
|
||||
logger.info(" load <file> - load and parse a log file")
|
||||
logger.info(" plot - plot all sensor data fields")
|
||||
logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds")
|
||||
logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds")
|
||||
logger.info(" gestures - start gesture detection")
|
||||
logger.info(" quit - exit the program")
|
||||
else:
|
||||
logger.info("Unknown command. Type 'help' to see available commands.")
|
||||
match cmd:
|
||||
case "connect":
|
||||
self.connect()
|
||||
case "start":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
self.start_tracking(duration)
|
||||
case "stop":
|
||||
self.stop_tracking()
|
||||
case "load":
|
||||
if len(cmd_parts) > 1:
|
||||
self.load_log_file(cmd_parts[1])
|
||||
case "plot":
|
||||
self.plot_fields()
|
||||
case "live":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
logger.info("Starting live plotting mode (without recording)%s.",
|
||||
f" for {duration} seconds" if duration else "")
|
||||
self.start_live_plotting(record_data=False, duration=duration)
|
||||
case "liver":
|
||||
duration = float(cmd_parts[1]) if len(cmd_parts) > 1 else None
|
||||
logger.info("Starting live plotting mode WITH recording%s.",
|
||||
f" for {duration} seconds" if duration else "")
|
||||
self.start_live_plotting(record_data=True, duration=duration)
|
||||
case "gestures":
|
||||
from gestures import GestureDetector
|
||||
if self.conn is not None:
|
||||
detector: GestureDetector = GestureDetector(conn=self.conn)
|
||||
else:
|
||||
detector: GestureDetector = GestureDetector()
|
||||
detector.start_detection()
|
||||
case "quit":
|
||||
logger.info("Exiting.")
|
||||
if self.conn != None:
|
||||
self.conn.disconnect()
|
||||
break
|
||||
case "help":
|
||||
logger.info("\nAirPods Head Tracking Analyzer")
|
||||
logger.info("------------------------------")
|
||||
logger.info("Commands:")
|
||||
logger.info(" connect - connect to your AirPods")
|
||||
logger.info(" start [seconds] - start recording head tracking data, optionally for specified duration")
|
||||
logger.info(" stop - stop recording")
|
||||
logger.info(" load <file> - load and parse a log file")
|
||||
logger.info(" plot - plot all sensor data fields")
|
||||
logger.info(" live [seconds] - start live plotting (without recording), optionally stop recording after seconds")
|
||||
logger.info(" liver [seconds] - start live plotting with recording, optionally stop recording after seconds")
|
||||
logger.info(" gestures - start gesture detection")
|
||||
logger.info(" quit - exit the program")
|
||||
case _:
|
||||
logger.info("Unknown command. Type 'help' to see available commands.")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Use 'quit' to exit.")
|
||||
except EOFError:
|
||||
@@ -856,5 +839,5 @@ class AirPodsTracker:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
tracker = AirPodsTracker()
|
||||
tracker.interactive_mode()
|
||||
tracker: AirPodsTracker = AirPodsTracker()
|
||||
tracker.interactive_mode()
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import sys
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
from queue import Queue
|
||||
import logging
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
from socket import socket as Socket, TimeoutError
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
@@ -12,47 +15,47 @@ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %
|
||||
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSlider, QCheckBox, QPushButton, QLineEdit, QFormLayout, QGridLayout
|
||||
from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QObject
|
||||
|
||||
OPCODE_READ_REQUEST = 0x0A
|
||||
OPCODE_WRITE_REQUEST = 0x12
|
||||
OPCODE_HANDLE_VALUE_NTF = 0x1B
|
||||
OPCODE_READ_REQUEST: int = 0x0A
|
||||
OPCODE_WRITE_REQUEST: int = 0x12
|
||||
OPCODE_HANDLE_VALUE_NTF: int = 0x1B
|
||||
|
||||
ATT_HANDLES = {
|
||||
ATT_HANDLES: Dict[str, int] = {
|
||||
'TRANSPARENCY': 0x18,
|
||||
'LOUD_SOUND_REDUCTION': 0x1B,
|
||||
'HEARING_AID': 0x2A,
|
||||
}
|
||||
|
||||
ATT_CCCD_HANDLES = {
|
||||
ATT_CCCD_HANDLES: Dict[str, int] = {
|
||||
'TRANSPARENCY': ATT_HANDLES['TRANSPARENCY'] + 1,
|
||||
'LOUD_SOUND_REDUCTION': ATT_HANDLES['LOUD_SOUND_REDUCTION'] + 1,
|
||||
'HEARING_AID': ATT_HANDLES['HEARING_AID'] + 1,
|
||||
}
|
||||
|
||||
PSM_ATT = 31
|
||||
PSM_ATT: int = 31
|
||||
|
||||
class ATTManager:
|
||||
def __init__(self, mac_address):
|
||||
self.mac_address = mac_address
|
||||
self.sock = None
|
||||
self.responses = Queue()
|
||||
self.listeners = {}
|
||||
self.notification_thread = None
|
||||
self.running = False
|
||||
def __init__(self, mac_address: str) -> None:
|
||||
self.mac_address: str = mac_address
|
||||
self.sock: Optional[Socket] = None
|
||||
self.responses: Queue = Queue()
|
||||
self.listeners: Dict[int, List[Any]] = {}
|
||||
self.notification_thread: Optional[Thread] = None
|
||||
self.running: bool = False
|
||||
# Avoid logging full MAC address to prevent sensitive data exposure
|
||||
mac_tail = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]'
|
||||
mac_tail: str = ':'.join(mac_address.split(':')[-2:]) if isinstance(mac_address, str) and ':' in mac_address else '[redacted]'
|
||||
logging.info(f"ATTManager initialized")
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
logging.info("Attempting to connect to ATT socket")
|
||||
self.sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
|
||||
self.sock = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
|
||||
self.sock.connect((self.mac_address, PSM_ATT))
|
||||
self.sock.settimeout(0.1)
|
||||
self.running = True
|
||||
self.notification_thread = threading.Thread(target=self._listen_notifications)
|
||||
self.notification_thread = Thread(target=self._listen_notifications)
|
||||
self.notification_thread.start()
|
||||
logging.info("Connected to ATT socket")
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
logging.info("Disconnecting from ATT socket")
|
||||
self.running = False
|
||||
if self.sock:
|
||||
@@ -63,37 +66,37 @@ class ATTManager:
|
||||
self.notification_thread.join(timeout=1.0)
|
||||
logging.info("Disconnected from ATT socket")
|
||||
|
||||
def register_listener(self, handle, listener):
|
||||
def register_listener(self, handle: int, listener: Any) -> None:
|
||||
if handle not in self.listeners:
|
||||
self.listeners[handle] = []
|
||||
self.listeners[handle].append(listener)
|
||||
logging.debug(f"Registered listener for handle {handle}")
|
||||
|
||||
def unregister_listener(self, handle, listener):
|
||||
def unregister_listener(self, handle: int, listener: Any) -> None:
|
||||
if handle in self.listeners:
|
||||
self.listeners[handle].remove(listener)
|
||||
logging.debug(f"Unregistered listener for handle {handle}")
|
||||
|
||||
def enable_notifications(self, handle):
|
||||
def enable_notifications(self, handle: Any) -> None:
|
||||
self.write_cccd(handle, b'\x01\x00')
|
||||
logging.info(f"Enabled notifications for handle {handle.name}")
|
||||
|
||||
def read(self, handle):
|
||||
handle_value = ATT_HANDLES[handle.name]
|
||||
lsb = handle_value & 0xFF
|
||||
msb = (handle_value >> 8) & 0xFF
|
||||
pdu = bytes([OPCODE_READ_REQUEST, lsb, msb])
|
||||
def read(self, handle: Any) -> bytes:
|
||||
handle_value: int = ATT_HANDLES[handle.name]
|
||||
lsb: int = handle_value & 0xFF
|
||||
msb: int = (handle_value >> 8) & 0xFF
|
||||
pdu: bytes = bytes([OPCODE_READ_REQUEST, lsb, msb])
|
||||
logging.debug(f"Sending read request for handle {handle.name}: {pdu.hex()}")
|
||||
self._write_raw(pdu)
|
||||
response = self._read_response()
|
||||
response: bytes = self._read_response()
|
||||
logging.debug(f"Read response for handle {handle.name}: {response.hex()}")
|
||||
return response
|
||||
|
||||
def write(self, handle, value):
|
||||
handle_value = ATT_HANDLES[handle.name]
|
||||
lsb = handle_value & 0xFF
|
||||
msb = (handle_value >> 8) & 0xFF
|
||||
pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
|
||||
def write(self, handle: Any, value: bytes) -> None:
|
||||
handle_value: int = ATT_HANDLES[handle.name]
|
||||
lsb: int = handle_value & 0xFF
|
||||
msb: int = (handle_value >> 8) & 0xFF
|
||||
pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
|
||||
logging.debug(f"Sending write request for handle {handle.name}: {pdu.hex()}")
|
||||
self._write_raw(pdu)
|
||||
try:
|
||||
@@ -102,11 +105,11 @@ class ATTManager:
|
||||
except:
|
||||
logging.warning(f"No write response received for handle {handle.name}")
|
||||
|
||||
def write_cccd(self, handle, value):
|
||||
handle_value = ATT_CCCD_HANDLES[handle.name]
|
||||
lsb = handle_value & 0xFF
|
||||
msb = (handle_value >> 8) & 0xFF
|
||||
pdu = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
|
||||
def write_cccd(self, handle: Any, value: bytes) -> None:
|
||||
handle_value: int = ATT_CCCD_HANDLES[handle.name]
|
||||
lsb: int = handle_value & 0xFF
|
||||
msb: int = (handle_value >> 8) & 0xFF
|
||||
pdu: bytes = bytes([OPCODE_WRITE_REQUEST, lsb, msb]) + value
|
||||
logging.debug(f"Sending CCCD write request for handle {handle.name}: {pdu.hex()}")
|
||||
self._write_raw(pdu)
|
||||
try:
|
||||
@@ -115,42 +118,42 @@ class ATTManager:
|
||||
except:
|
||||
logging.warning(f"No CCCD write response received for handle {handle.name}")
|
||||
|
||||
def _write_raw(self, pdu):
|
||||
def _write_raw(self, pdu: bytes) -> None:
|
||||
self.sock.send(pdu)
|
||||
logging.debug(f"Sent PDU: {pdu.hex()}")
|
||||
|
||||
def _read_pdu(self):
|
||||
def _read_pdu(self) -> Optional[bytes]:
|
||||
try:
|
||||
data = self.sock.recv(512)
|
||||
data: bytes = self.sock.recv(512)
|
||||
logging.debug(f"Received PDU: {data.hex()}")
|
||||
return data
|
||||
except socket.timeout:
|
||||
except TimeoutError:
|
||||
return None
|
||||
except:
|
||||
raise
|
||||
|
||||
def _read_response(self, timeout=2.0):
|
||||
def _read_response(self, timeout: float = 2.0) -> bytes:
|
||||
try:
|
||||
response = self.responses.get(timeout=timeout)[1:] # Skip opcode
|
||||
response: bytes = self.responses.get(timeout=timeout)[1:] # Skip opcode
|
||||
logging.debug(f"Response received: {response.hex()}")
|
||||
return response
|
||||
except:
|
||||
logging.error("No response received within timeout")
|
||||
raise Exception("No response received")
|
||||
|
||||
def _listen_notifications(self):
|
||||
def _listen_notifications(self) -> None:
|
||||
logging.info("Starting notification listener thread")
|
||||
while self.running:
|
||||
try:
|
||||
pdu = self._read_pdu()
|
||||
pdu: Optional[bytes] = self._read_pdu()
|
||||
except:
|
||||
break
|
||||
if pdu is None:
|
||||
continue
|
||||
if len(pdu) > 0 and pdu[0] == OPCODE_HANDLE_VALUE_NTF:
|
||||
logging.debug(f"Notification PDU received: {pdu.hex()}")
|
||||
handle = pdu[1] | (pdu[2] << 8)
|
||||
value = pdu[3:]
|
||||
handle: int = pdu[1] | (pdu[2] << 8)
|
||||
value: bytes = pdu[3:]
|
||||
logging.debug(f"Notification for handle {handle}: {value.hex()}")
|
||||
if handle in self.listeners:
|
||||
for listener in self.listeners[handle]:
|
||||
@@ -165,36 +168,36 @@ class ATTManager:
|
||||
logging.error(f"Reconnection failed: {e}")
|
||||
|
||||
class HearingAidSettings:
|
||||
def __init__(self, left_eq, right_eq, left_amp, right_amp, left_tone, right_tone,
|
||||
left_conv, right_conv, left_anr, right_anr, net_amp, balance, own_voice):
|
||||
self.left_eq = left_eq
|
||||
self.right_eq = right_eq
|
||||
self.left_amplification = left_amp
|
||||
self.right_amplification = right_amp
|
||||
self.left_tone = left_tone
|
||||
self.right_tone = right_tone
|
||||
self.left_conversation_boost = left_conv
|
||||
self.right_conversation_boost = right_conv
|
||||
self.left_ambient_noise_reduction = left_anr
|
||||
self.right_ambient_noise_reduction = right_anr
|
||||
self.net_amplification = net_amp
|
||||
self.balance = balance
|
||||
self.own_voice_amplification = own_voice
|
||||
def __init__(self, left_eq: List[float], right_eq: List[float], left_amp: float, right_amp: float, left_tone: float, right_tone: float,
|
||||
left_conv: bool, right_conv: bool, left_anr: float, right_anr: float, net_amp: float, balance: float, own_voice: float) -> None:
|
||||
self.left_eq: List[float] = left_eq
|
||||
self.right_eq: List[float] = right_eq
|
||||
self.left_amplification: float = left_amp
|
||||
self.right_amplification: float = right_amp
|
||||
self.left_tone: float = left_tone
|
||||
self.right_tone: float = right_tone
|
||||
self.left_conversation_boost: bool = left_conv
|
||||
self.right_conversation_boost: bool = right_conv
|
||||
self.left_ambient_noise_reduction: float = left_anr
|
||||
self.right_ambient_noise_reduction: float = right_anr
|
||||
self.net_amplification: float = net_amp
|
||||
self.balance: float = balance
|
||||
self.own_voice_amplification: float = own_voice
|
||||
logging.debug(f"HearingAidSettings created: amp={net_amp}, balance={balance}, tone={left_tone}, anr={left_anr}, conv={left_conv}")
|
||||
|
||||
def parse_hearing_aid_settings(data):
|
||||
def parse_hearing_aid_settings(data: bytes) -> Optional[HearingAidSettings]:
|
||||
logging.debug(f"Parsing hearing aid settings from data: {data.hex()}")
|
||||
if len(data) < 104:
|
||||
logging.warning("Data too short for parsing")
|
||||
return None
|
||||
buffer = data
|
||||
offset = 0
|
||||
buffer: bytes = data
|
||||
offset: int = 0
|
||||
|
||||
offset += 4
|
||||
|
||||
logging.info(f"Parsing hearing aid settings, starting read at offset 4, value: {buffer[offset]:02x}")
|
||||
|
||||
left_eq = []
|
||||
left_eq: List[float] = []
|
||||
for i in range(8):
|
||||
val, = struct.unpack('<f', buffer[offset:offset+4])
|
||||
left_eq.append(val)
|
||||
@@ -228,23 +231,23 @@ def parse_hearing_aid_settings(data):
|
||||
|
||||
own_voice, = struct.unpack('<f', buffer[offset:offset+4])
|
||||
|
||||
avg = (left_amp + right_amp) / 2
|
||||
amplification = max(-1, min(1, avg))
|
||||
diff = right_amp - left_amp
|
||||
balance = max(-1, min(1, diff))
|
||||
avg: float = (left_amp + right_amp) / 2
|
||||
amplification: float = max(-1, min(1, avg))
|
||||
diff: float = right_amp - left_amp
|
||||
balance: float = max(-1, min(1, diff))
|
||||
|
||||
settings = HearingAidSettings(left_eq, right_eq, left_amp, right_amp, left_tone, right_tone,
|
||||
settings: HearingAidSettings = HearingAidSettings(left_eq, right_eq, left_amp, right_amp, left_tone, right_tone,
|
||||
left_conv, right_conv, left_anr, right_anr, amplification, balance, own_voice)
|
||||
logging.info(f"Parsed settings: amp={amplification}, balance={balance}")
|
||||
return settings
|
||||
|
||||
def send_hearing_aid_settings(att_manager, settings):
|
||||
def send_hearing_aid_settings(att_manager: ATTManager, settings: HearingAidSettings) -> None:
|
||||
logging.info("Sending hearing aid settings")
|
||||
data = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
|
||||
data: bytes = att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
|
||||
if len(data) < 104:
|
||||
logging.error("Read data too short for sending settings")
|
||||
return
|
||||
buffer = bytearray(data)
|
||||
buffer: bytearray = bytearray(data)
|
||||
|
||||
# Modify byte at index 2 to 0x64
|
||||
buffer[2] = 0x64
|
||||
@@ -272,16 +275,16 @@ def send_hearing_aid_settings(att_manager, settings):
|
||||
logging.info("Hearing aid settings sent")
|
||||
|
||||
class SignalEmitter(QObject):
|
||||
update_ui = pyqtSignal(HearingAidSettings)
|
||||
update_ui: pyqtSignal = pyqtSignal(HearingAidSettings)
|
||||
|
||||
class HearingAidApp(QWidget):
|
||||
def __init__(self, mac_address):
|
||||
def __init__(self, mac_address: str) -> None:
|
||||
super().__init__()
|
||||
self.mac_address = mac_address
|
||||
self.att_manager = ATTManager(mac_address)
|
||||
self.emitter = SignalEmitter()
|
||||
self.mac_address: str = mac_address
|
||||
self.att_manager: ATTManager = ATTManager(mac_address)
|
||||
self.emitter: SignalEmitter = SignalEmitter()
|
||||
self.emitter.update_ui.connect(self.on_update_ui)
|
||||
self.debounce_timer = QTimer()
|
||||
self.debounce_timer: QTimer = QTimer()
|
||||
self.debounce_timer.setSingleShot(True)
|
||||
self.debounce_timer.timeout.connect(self.send_settings)
|
||||
logging.info("HearingAidConfig initialized")
|
||||
@@ -289,25 +292,25 @@ class HearingAidApp(QWidget):
|
||||
self.init_ui()
|
||||
self.connect_att()
|
||||
|
||||
def init_ui(self):
|
||||
def init_ui(self) -> None:
|
||||
logging.debug("Initializing UI")
|
||||
self.setWindowTitle("Hearing Aid Adjustments")
|
||||
layout = QVBoxLayout()
|
||||
layout: QVBoxLayout = QVBoxLayout()
|
||||
|
||||
# EQ Inputs
|
||||
eq_layout = QGridLayout()
|
||||
self.left_eq_inputs = []
|
||||
self.right_eq_inputs = []
|
||||
eq_layout: QGridLayout = QGridLayout()
|
||||
self.left_eq_inputs: List[QLineEdit] = []
|
||||
self.right_eq_inputs: List[QLineEdit] = []
|
||||
|
||||
eq_labels = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"]
|
||||
eq_labels: List[str] = ["250Hz", "500Hz", "1kHz", "2kHz", "3kHz", "4kHz", "6kHz", "8kHz"]
|
||||
eq_layout.addWidget(QLabel("Frequency"), 0, 0)
|
||||
eq_layout.addWidget(QLabel("Left"), 0, 1)
|
||||
eq_layout.addWidget(QLabel("Right"), 0, 2)
|
||||
|
||||
for i, label in enumerate(eq_labels):
|
||||
eq_layout.addWidget(QLabel(label), i + 1, 0)
|
||||
left_input = QLineEdit()
|
||||
right_input = QLineEdit()
|
||||
left_input: QLineEdit = QLineEdit()
|
||||
right_input: QLineEdit = QLineEdit()
|
||||
left_input.setPlaceholderText("Left")
|
||||
right_input.setPlaceholderText("Right")
|
||||
self.left_eq_inputs.append(left_input)
|
||||
@@ -315,52 +318,52 @@ class HearingAidApp(QWidget):
|
||||
eq_layout.addWidget(left_input, i + 1, 1)
|
||||
eq_layout.addWidget(right_input, i + 1, 2)
|
||||
|
||||
eq_group = QWidget()
|
||||
eq_group: QWidget = QWidget()
|
||||
eq_group.setLayout(eq_layout)
|
||||
layout.addWidget(QLabel("Loss, in dBHL"))
|
||||
layout.addWidget(eq_group)
|
||||
|
||||
# Amplification
|
||||
self.amp_slider = QSlider(Qt.Horizontal)
|
||||
self.amp_slider: QSlider = QSlider(Qt.Horizontal)
|
||||
self.amp_slider.setRange(-100, 100)
|
||||
self.amp_slider.setValue(50)
|
||||
layout.addWidget(QLabel("Amplification"))
|
||||
layout.addWidget(self.amp_slider)
|
||||
|
||||
# Balance
|
||||
self.balance_slider = QSlider(Qt.Horizontal)
|
||||
self.balance_slider: QSlider = QSlider(Qt.Horizontal)
|
||||
self.balance_slider.setRange(-100, 100)
|
||||
self.balance_slider.setValue(50)
|
||||
layout.addWidget(QLabel("Balance"))
|
||||
layout.addWidget(self.balance_slider)
|
||||
|
||||
# Tone
|
||||
self.tone_slider = QSlider(Qt.Horizontal)
|
||||
self.tone_slider: QSlider = QSlider(Qt.Horizontal)
|
||||
self.tone_slider.setRange(-100, 100)
|
||||
self.tone_slider.setValue(50)
|
||||
layout.addWidget(QLabel("Tone"))
|
||||
layout.addWidget(self.tone_slider)
|
||||
|
||||
# Ambient Noise Reduction
|
||||
self.anr_slider = QSlider(Qt.Horizontal)
|
||||
self.anr_slider: QSlider = QSlider(Qt.Horizontal)
|
||||
self.anr_slider.setRange(0, 100)
|
||||
self.anr_slider.setValue(0)
|
||||
layout.addWidget(QLabel("Ambient Noise Reduction"))
|
||||
layout.addWidget(self.anr_slider)
|
||||
|
||||
# Conversation Boost
|
||||
self.conv_checkbox = QCheckBox("Conversation Boost")
|
||||
self.conv_checkbox: QCheckBox = QCheckBox("Conversation Boost")
|
||||
layout.addWidget(self.conv_checkbox)
|
||||
|
||||
# Own Voice Amplification
|
||||
self.own_voice_slider = QSlider(Qt.Horizontal)
|
||||
self.own_voice_slider: QSlider = QSlider(Qt.Horizontal)
|
||||
self.own_voice_slider.setRange(0, 100)
|
||||
self.own_voice_slider.setValue(50)
|
||||
# layout.addWidget(QLabel("Own Voice Amplification"))
|
||||
# layout.addWidget(self.own_voice_slider) # seems to have no effect
|
||||
|
||||
# Reset button
|
||||
self.reset_button = QPushButton("Reset")
|
||||
self.reset_button: QPushButton = QPushButton("Reset")
|
||||
layout.addWidget(self.reset_button)
|
||||
|
||||
# Connect signals
|
||||
@@ -377,15 +380,15 @@ class HearingAidApp(QWidget):
|
||||
self.setLayout(layout)
|
||||
logging.debug("UI initialized")
|
||||
|
||||
def connect_att(self):
|
||||
def connect_att(self) -> None:
|
||||
logging.info("Connecting to ATT in UI")
|
||||
try:
|
||||
self.att_manager.connect()
|
||||
self.att_manager.enable_notifications(type('Handle', (), {'name': 'HEARING_AID'})())
|
||||
self.att_manager.register_listener(ATT_HANDLES['HEARING_AID'], self.on_notification)
|
||||
# Initial read
|
||||
data = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
|
||||
settings = parse_hearing_aid_settings(data)
|
||||
data: bytes = self.att_manager.read(type('Handle', (), {'name': 'HEARING_AID'})())
|
||||
settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(data)
|
||||
if settings:
|
||||
self.emitter.update_ui.emit(settings)
|
||||
logging.info("Initial settings loaded")
|
||||
@@ -396,13 +399,13 @@ class HearingAidApp(QWidget):
|
||||
else:
|
||||
logging.error(f"Connection failed: {e}")
|
||||
|
||||
def on_notification(self, value):
|
||||
def on_notification(self, value: bytes) -> None:
|
||||
logging.debug("Notification received")
|
||||
settings = parse_hearing_aid_settings(value)
|
||||
settings: Optional[HearingAidSettings] = parse_hearing_aid_settings(value)
|
||||
if settings:
|
||||
self.emitter.update_ui.emit(settings)
|
||||
|
||||
def on_update_ui(self, settings):
|
||||
def on_update_ui(self, settings: HearingAidSettings) -> None:
|
||||
logging.debug("Updating UI with settings")
|
||||
self.amp_slider.setValue(int(settings.net_amplification * 100))
|
||||
self.balance_slider.setValue(int(settings.balance * 100))
|
||||
@@ -416,30 +419,30 @@ class HearingAidApp(QWidget):
|
||||
for i, value in enumerate(settings.right_eq):
|
||||
self.right_eq_inputs[i].setText(f"{value:.2f}")
|
||||
|
||||
def on_value_changed(self):
|
||||
def on_value_changed(self) -> None:
|
||||
logging.debug("UI value changed, starting debounce")
|
||||
self.debounce_timer.start(100)
|
||||
|
||||
def send_settings(self):
|
||||
def send_settings(self) -> None:
|
||||
logging.info("Sending settings from UI")
|
||||
amp = self.amp_slider.value() / 100.0
|
||||
balance = self.balance_slider.value() / 100.0
|
||||
tone = self.tone_slider.value() / 100.0
|
||||
anr = self.anr_slider.value() / 100.0
|
||||
conv = self.conv_checkbox.isChecked()
|
||||
own_voice = self.own_voice_slider.value() / 100.0
|
||||
amp: float = self.amp_slider.value() / 100.0
|
||||
balance: float = self.balance_slider.value() / 100.0
|
||||
tone: float = self.tone_slider.value() / 100.0
|
||||
anr: float = self.anr_slider.value() / 100.0
|
||||
conv: bool = self.conv_checkbox.isChecked()
|
||||
own_voice: float = self.own_voice_slider.value() / 100.0
|
||||
|
||||
left_amp = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp
|
||||
right_amp = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp
|
||||
left_amp: float = amp + (0.5 - balance) * amp * 2 if balance < 0 else amp
|
||||
right_amp: float = amp + (balance - 0.5) * amp * 2 if balance > 0 else amp
|
||||
|
||||
left_eq = [float(input_box.text() or 0) for input_box in self.left_eq_inputs]
|
||||
right_eq = [float(input_box.text() or 0) for input_box in self.right_eq_inputs]
|
||||
left_eq: List[float] = [float(input_box.text() or 0) for input_box in self.left_eq_inputs]
|
||||
right_eq: List[float] = [float(input_box.text() or 0) for input_box in self.right_eq_inputs]
|
||||
|
||||
settings = HearingAidSettings(
|
||||
settings: HearingAidSettings = HearingAidSettings(
|
||||
left_eq, right_eq, left_amp, right_amp, tone, tone,
|
||||
conv, conv, anr, anr, amp, balance, own_voice
|
||||
)
|
||||
threading.Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start()
|
||||
Thread(target=send_hearing_aid_settings, args=(self.att_manager, settings)).start()
|
||||
|
||||
def reset_settings(self):
|
||||
logging.debug("Resetting settings to defaults")
|
||||
@@ -451,26 +454,25 @@ class HearingAidApp(QWidget):
|
||||
self.own_voice_slider.setValue(50)
|
||||
self.on_value_changed()
|
||||
|
||||
def closeEvent(self, event):
|
||||
def closeEvent(self, event: Any) -> None:
|
||||
logging.info("Closing app")
|
||||
self.att_manager.disconnect()
|
||||
event.accept()
|
||||
|
||||
if __name__ == "__main__":
|
||||
mac = None
|
||||
if len(sys.argv) != 2:
|
||||
logging.error("Usage: python hearing-aid-adjustments.py <MAC_ADDRESS>")
|
||||
sys.exit(1)
|
||||
mac = sys.argv[1]
|
||||
mac_regex = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$'
|
||||
mac: str = sys.argv[1]
|
||||
mac_regex: str = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$'
|
||||
import re
|
||||
if not re.match(mac_regex, mac):
|
||||
logging.error("Invalid MAC address format")
|
||||
sys.exit(1)
|
||||
logging.info(f"Starting app")
|
||||
app = QApplication(sys.argv)
|
||||
app: QApplication = QApplication(sys.argv)
|
||||
|
||||
def quit_app(signum, frame):
|
||||
def quit_app(signum: int, frame: Any) -> None:
|
||||
app.quit()
|
||||
|
||||
signal.signal(signal.SIGINT, quit_app)
|
||||
|
||||
@@ -4,50 +4,53 @@
|
||||
# See https://github.com/google/bumble/blob/main/docs/mkdocs/src/platforms/windows.md for usage.
|
||||
# You need to associate WinUSB with your Bluetooth interface. Once done, you can roll back to the original driver from Device Manager.
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import argparse
|
||||
import colorama
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any, Optional
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from asyncio import Queue, TimeoutError
|
||||
from colorama import Fore, Style
|
||||
from logging import Formatter, LogRecord, Logger, StreamHandler
|
||||
from socket import socket as Socket
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from colorama import Fore, Style, init as colorama_init
|
||||
colorama_init(autoreset=True)
|
||||
colorama.init(autoreset=True)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
class ColorFormatter(logging.Formatter):
|
||||
COLORS = {
|
||||
handler: StreamHandler = StreamHandler()
|
||||
class ColorFormatter(Formatter):
|
||||
COLORS: Dict[int, str] = {
|
||||
logging.DEBUG: Fore.BLUE,
|
||||
logging.INFO: Fore.GREEN,
|
||||
logging.WARNING: Fore.YELLOW,
|
||||
logging.ERROR: Fore.RED,
|
||||
logging.CRITICAL: Fore.MAGENTA,
|
||||
}
|
||||
def format(self, record):
|
||||
color = self.COLORS.get(record.levelno, "")
|
||||
prefix = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}"
|
||||
def format(self, record: LogRecord) -> str:
|
||||
color: str = self.COLORS.get(record.levelno, "")
|
||||
prefix: str = f"{color}[{record.levelname}:{record.name}]{Style.RESET_ALL}"
|
||||
return f"{prefix} {record.getMessage()}"
|
||||
handler.setFormatter(ColorFormatter())
|
||||
logging.basicConfig(level=logging.INFO, handlers=[handler])
|
||||
logger = logging.getLogger("proximitykeys")
|
||||
logger: Logger = logging.getLogger("proximitykeys")
|
||||
|
||||
PROXIMITY_KEY_TYPES = {0x01: "IRK", 0x04: "ENC_KEY"}
|
||||
PROXIMITY_KEY_TYPES: Dict[int, str] = {0x01: "IRK", 0x04: "ENC_KEY"}
|
||||
|
||||
def parse_proximity_keys_response(data: bytes):
|
||||
def parse_proximity_keys_response(data: bytes) -> Optional[List[Tuple[str, bytes]]]:
|
||||
if len(data) < 7 or data[4] != 0x31:
|
||||
return None
|
||||
key_count = data[6]
|
||||
keys = []
|
||||
offset = 7
|
||||
key_count: int = data[6]
|
||||
keys: List[Tuple[str, bytes]] = []
|
||||
offset: int = 7
|
||||
for _ in range(key_count):
|
||||
if offset + 3 >= len(data):
|
||||
break
|
||||
key_type = data[offset]
|
||||
key_length = data[offset + 2]
|
||||
key_type: int = data[offset]
|
||||
key_length: int = data[offset + 2]
|
||||
offset += 4
|
||||
if offset + key_length > len(data):
|
||||
break
|
||||
key_bytes = data[offset:offset + key_length]
|
||||
key_bytes: bytes = data[offset:offset + key_length]
|
||||
keys.append((PROXIMITY_KEY_TYPES.get(key_type, f"TYPE_{key_type:02X}"), key_bytes))
|
||||
offset += key_length
|
||||
return keys
|
||||
@@ -55,7 +58,7 @@ def parse_proximity_keys_response(data: bytes):
|
||||
def hexdump(data: bytes) -> str:
|
||||
return " ".join(f"{b:02X}" for b in data)
|
||||
|
||||
async def run_bumble(bdaddr: str):
|
||||
async def run_bumble(bdaddr: str) -> int:
|
||||
try:
|
||||
from bumble.l2cap import ClassicChannelSpec
|
||||
from bumble.transport import open_transport
|
||||
@@ -68,19 +71,23 @@ async def run_bumble(bdaddr: str):
|
||||
logger.error("Bumble not installed")
|
||||
return 1
|
||||
|
||||
PSM_PROXIMITY = 0x1001
|
||||
HANDSHAKE = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
|
||||
KEY_REQ = bytes.fromhex("04 00 04 00 30 00 05 00")
|
||||
PSM_PROXIMITY: int = 0x1001
|
||||
HANDSHAKE: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
|
||||
KEY_REQ: bytes = bytes.fromhex("04 00 04 00 30 00 05 00")
|
||||
|
||||
class KeyStore:
|
||||
async def delete(self, name: str): pass
|
||||
async def update(self, name: str, keys: Any): pass
|
||||
async def get(self, _name: str) -> Optional[Any]: return None
|
||||
async def get_all(self): return []
|
||||
async def delete(self, name: str) -> None:
|
||||
pass
|
||||
async def update(self, name: str, keys: Any) -> None:
|
||||
pass
|
||||
async def get(self, _name: str) -> Optional[Any]:
|
||||
return None
|
||||
async def get_all(self) -> List[Tuple[str, Any]]:
|
||||
return []
|
||||
|
||||
async def get_resolving_keys(self) -> list[tuple[bytes, Any]]:
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
async def get_resolving_keys(self) -> List[Tuple[bytes, Any]]:
|
||||
all_keys: List[Tuple[str, Any]] = await self.get_all()
|
||||
resolving_keys: List[Tuple[bytes, Any]] = []
|
||||
for name, keys in all_keys:
|
||||
if getattr(keys, "irk", None) is not None:
|
||||
resolving_keys.append((
|
||||
@@ -89,8 +96,8 @@ async def run_bumble(bdaddr: str):
|
||||
))
|
||||
return resolving_keys
|
||||
|
||||
async def exchange_keys(channel, timeout=5.0):
|
||||
recv_q: asyncio.Queue = asyncio.Queue()
|
||||
async def exchange_keys(channel: Any, timeout: float = 5.0) -> Optional[List[Tuple[str, bytes]]]:
|
||||
recv_q: Queue = Queue()
|
||||
channel.sink = lambda sdu: recv_q.put_nowait(sdu)
|
||||
logger.info("Sending handshake packet...")
|
||||
channel.send_pdu(HANDSHAKE)
|
||||
@@ -99,19 +106,19 @@ async def run_bumble(bdaddr: str):
|
||||
channel.send_pdu(KEY_REQ)
|
||||
while True:
|
||||
try:
|
||||
pkt = await asyncio.wait_for(recv_q.get(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
pkt: bytes = await asyncio.wait_for(recv_q.get(), timeout)
|
||||
except TimeoutError:
|
||||
logger.error("Timed out waiting for SDU response")
|
||||
return None
|
||||
logger.debug("Received SDU (%d bytes): %s", len(pkt), hexdump(pkt))
|
||||
keys = parse_proximity_keys_response(pkt)
|
||||
keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt)
|
||||
if keys:
|
||||
return keys
|
||||
|
||||
async def get_device():
|
||||
async def get_device() -> Tuple[Any, Device]:
|
||||
logger.info("Opening transport...")
|
||||
transport = await open_transport("usb:0")
|
||||
device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink))
|
||||
transport: Any = await open_transport("usb:0")
|
||||
device: Device = Device(host=Host(controller_source=transport.source, controller_sink=transport.sink))
|
||||
device.classic_enabled = True
|
||||
device.le_enabled = False
|
||||
device.keystore = KeyStore()
|
||||
@@ -123,15 +130,15 @@ async def run_bumble(bdaddr: str):
|
||||
logger.info("Device powered on")
|
||||
return transport, device
|
||||
|
||||
async def create_channel_and_exchange(conn):
|
||||
spec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048)
|
||||
async def create_channel_and_exchange(conn: Any) -> None:
|
||||
spec: ClassicChannelSpec = ClassicChannelSpec(psm=PSM_PROXIMITY, mtu=2048)
|
||||
logger.info("Requesting L2CAP channel on PSM = 0x%04X", spec.psm)
|
||||
if not conn.is_encrypted:
|
||||
logger.info("Enabling link encryption...")
|
||||
await conn.encrypt()
|
||||
await asyncio.sleep(0.05)
|
||||
channel = await conn.create_l2cap_channel(spec=spec)
|
||||
keys = await exchange_keys(channel, timeout=8.0)
|
||||
channel: Any = await conn.create_l2cap_channel(spec=spec)
|
||||
keys: Optional[List[Tuple[str, bytes]]] = await exchange_keys(channel, timeout=8.0)
|
||||
if not keys:
|
||||
logger.warning("No proximity keys found")
|
||||
return
|
||||
@@ -165,14 +172,14 @@ async def run_bumble(bdaddr: str):
|
||||
logger.info("Transport closed")
|
||||
return 0
|
||||
|
||||
def run_linux(bdaddr: str):
|
||||
def run_linux(bdaddr: str) -> None:
|
||||
import socket
|
||||
PSM = 0x1001
|
||||
handshake = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
|
||||
key_req = bytes.fromhex("04 00 04 00 30 00 05 00")
|
||||
PSM: int = 0x1001
|
||||
handshake: bytes = bytes.fromhex("00 00 04 00 01 00 02 00 00 00 00 00 00 00 00 00")
|
||||
key_req: bytes = bytes.fromhex("04 00 04 00 30 00 05 00")
|
||||
|
||||
logger.info("Connecting to %s (L2CAP)...", bdaddr)
|
||||
sock = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
|
||||
sock: Socket = Socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP)
|
||||
try:
|
||||
sock.connect((bdaddr, PSM))
|
||||
logger.info("Connected, sending handshake and key request...")
|
||||
@@ -180,9 +187,9 @@ def run_linux(bdaddr: str):
|
||||
sock.send(key_req)
|
||||
|
||||
while True:
|
||||
pkt = sock.recv(1024)
|
||||
pkt: bytes = sock.recv(1024)
|
||||
logger.debug("Received packet (%d bytes): %s", len(pkt), hexdump(pkt))
|
||||
keys = parse_proximity_keys_response(pkt)
|
||||
keys: Optional[List[Tuple[str, bytes]]] = parse_proximity_keys_response(pkt)
|
||||
if keys:
|
||||
logger.info("Keys successfully retrieved")
|
||||
print(f"{Fore.CYAN}{Style.BRIGHT}Proximity Keys:{Style.RESET_ALL}")
|
||||
@@ -197,12 +204,12 @@ def run_linux(bdaddr: str):
|
||||
sock.close()
|
||||
logger.info("Connection closed")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
def main() -> None:
|
||||
parser: ArgumentParser = ArgumentParser()
|
||||
parser.add_argument("bdaddr")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
parser.add_argument("--bumble", action="store_true")
|
||||
args = parser.parse_args()
|
||||
args: Namespace = parser.parse_args()
|
||||
logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
|
||||
|
||||
if args.bumble or platform.system() == "Windows":
|
||||
|
||||
Reference in New Issue
Block a user