"""
Aria2 Download Manager
High-level interface for downloading files via aria2c RPC
Handles progress tracking, error handling, and fallback logic
"""

import os
import time
import logging
import threading
from pathlib import Path
from typing import Optional, Callable, Dict, Any, Tuple
from urllib.parse import urlencode

import aria2p

from .error_parser import parse_aria2_error, format_error_for_user, Aria2ErrorCategory

logger = logging.getLogger(__name__)

# Global singleton instance
_aria2_manager: Optional['Aria2DownloadManager'] = None


class Aria2DownloadManager:
    """
    Manages file downloads via aria2c RPC

    Features:
    - Multi-connection downloads (ARIA2_MAX_CONNECTIONS_PER_FILE, default 16;
      reduced to 1 when a bandwidth limit is set)
    - Progress callbacks for UI updates
    - Resume of partial files (aria2 'continue' option)
    - aria2 failures parsed into user-friendly error messages
    - Cookie and header support for authenticated downloads
    """

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6800,
        secret: str = 'dkidownload_secret_2025'
    ):
        """
        Initialize aria2 RPC client

        The connectivity probe performed here is best-effort: an unreachable
        daemon is logged but does not make construction fail.

        Args:
            host: aria2c RPC host
            port: aria2c RPC port
            secret: RPC secret token

        Raises:
            Exception: Whatever aria2p raises while building the client/API
                objects (logged, then re-raised).
        """
        self.host = host
        self.port = port
        self.secret = secret

        try:
            # Create aria2p client - aria2p adds /jsonrpc automatically
            # Don't include port in host URL, use port parameter instead
            self.client = aria2p.Client(
                host=f"http://{host}",
                port=port,
                secret=secret
            )
            self.api = aria2p.API(self.client)
        except Exception as e:
            logger.error(f"Failed to connect to aria2c RPC: {e}")
            raise

        self._probe_connection()

    def _probe_connection(self) -> None:
        """Issue one RPC round-trip to check connectivity; only logs, never raises."""
        host, port = self.host, self.port
        try:
            self.api.get_stats()
            logger.debug(f"✅ Connected to aria2 RPC on {host}:{port}")
        except Exception as conn_err:
            error_str = str(conn_err)
            if "Expecting value" in error_str or "JSON" in error_str:
                logger.warning(
                    f"aria2 RPC JSON parsing issue (non-fatal): {conn_err}")
                logger.debug(
                    f"✅ aria2 RPC client initialized on {host}:{port} (connection assumed)")
            else:
                # Surface other failures (e.g. daemon not running) instead of
                # hiding them; initialization still proceeds.
                logger.warning(
                    f"aria2 RPC connection test failed (non-fatal): {conn_err}")
                logger.debug(
                    f"✅ aria2 RPC client initialized on {host}:{port}")

    def download_file(
        self,
        url: str,
        dest_path: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[str] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        referer: Optional[str] = None,
        method: str = 'GET',
        post_data: Optional[Dict[str, str]] = None,
        max_download_limit: Optional[str] = None,
        timeout: Optional[float] = None
    ) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Download a file via aria2 (GET ONLY - NO POST SUPPORT)

        IMPORTANT: aria2 ONLY supports GET requests. POST requests will raise ValueError.
        Do NOT pass method='POST' or post_data to this function.

        Args:
            url: Download URL (with query params if needed)
            dest_path: Full local path to save file
            headers: Custom HTTP headers dict
            cookies: Cookie string (format: "name=value; name2=value2")
            progress_callback: Optional callback(downloaded_bytes, total_bytes);
                exceptions it raises are logged and ignored
            referer: Referer header value
            method: HTTP method (MUST be 'GET', others will raise error)
            post_data: NOT SUPPORTED - will raise ValueError if provided
            max_download_limit: Optional bandwidth limit (e.g., '100K', '1M')
            timeout: Optional overall time limit in seconds; when exceeded the
                aria2 job is force-removed and a failure is returned.
                None (default) waits indefinitely.

        Returns:
            Tuple[success: bool, error_message: Optional[str], gid: Optional[str]]
            (gid is only returned on success)

        Raises:
            ValueError: If method is not 'GET' or post_data is provided
        """
        # Validate before the broad ``except`` below so the documented
        # ValueError reaches the caller, and before any directory is created.
        self._validate_request(method, post_data)

        try:
            # Ensure parent directory exists BEFORE aria2 tries to access it
            parent_dir = os.path.dirname(dest_path)
            if parent_dir:
                try:
                    os.makedirs(parent_dir, exist_ok=True)
                    logger.debug(f"Ensured directory exists: {parent_dir}")
                except Exception as mkdir_err:
                    logger.error(
                        f"Failed to create directory {parent_dir}: {mkdir_err}")
                    raise

            # Normalize path for aria2 (convert to absolute path)
            # This prevents aria2 from trying to create directories itself
            abs_dest_path = os.path.abspath(dest_path)
            abs_parent_dir = os.path.dirname(abs_dest_path)

            options = self._build_options(
                abs_dest_path, headers, cookies, referer, max_download_limit)

            logger.debug(
                f"Starting aria2 download: {os.path.basename(abs_dest_path)}")
            logger.debug(f"URL: {url[:100]}...")
            logger.debug(f"Dest: {abs_dest_path}")
            logger.debug(f"Aria2 dir: {abs_parent_dir}")
            logger.debug(f"Method: {method}")

            # Standard GET download with aria2
            download = self.api.add_uris([url], options=options)
            gid = download.gid
            logger.debug(f"Download started with GID: {gid}")

            # Monitor progress
            success = self._wait_for_completion(
                gid=gid,
                progress_callback=progress_callback,
                timeout=timeout
            )

            if success:
                logger.debug(f"✅ Download completed: {abs_dest_path}")
                return True, None, gid

            return False, self._describe_failure(gid), None

        except Exception as e:
            raw_error = str(e)
            parsed_error = parse_aria2_error(raw_error)

            # Special handling for JSON/RPC errors
            if "Expecting value" in raw_error or "JSON" in raw_error:
                error_msg = (
                    f"aria2 RPC communication error: {raw_error}. "
                    "This might be due to aria2c daemon not running or RPC secret mismatch. "
                    f"Check if aria2c process is active and RPC port {self.port} is accessible."
                )
            else:
                error_msg = format_error_for_user(
                    parsed_error, include_technical=True)

            logger.error(
                f"❌ Exception [{parsed_error.category.value}]: {raw_error}", exc_info=True)
            return False, error_msg, None

    @staticmethod
    def _validate_request(method: str, post_data: Optional[Dict[str, str]]) -> None:
        """Raise ValueError unless this is a plain GET request without a body."""
        if not isinstance(method, str) or method.upper() != 'GET':
            error_msg = (
                f"❌ aria2 only supports GET method, received: {method}. "
                f"POST requests are NOT supported. "
                f"This is a critical error - please check your download implementation."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if post_data:
            error_msg = (
                f"❌ aria2 cannot handle POST data. "
                f"POST requests with form data are NOT supported by aria2c RPC. "
                f"This indicates a bug in the calling code."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    @staticmethod
    def _build_options(
        abs_dest_path: str,
        headers: Optional[Dict[str, str]],
        cookies: Optional[str],
        referer: Optional[str],
        max_download_limit: Optional[str]
    ) -> Dict[str, str]:
        """Build the aria2 option dict for one download (absolute paths only)."""
        # Load max connections from environment
        max_connections = os.getenv('ARIA2_MAX_CONNECTIONS_PER_FILE', '16')

        # With throttled bandwidth (e.g., 100KB/s), use only 1 connection
        # to avoid timeout from too many slow connections
        if max_download_limit:
            max_connections = '1'
            logger.debug(
                "Reduced connections to 1 for bandwidth-limited download")

        options = {
            'dir': os.path.dirname(abs_dest_path),
            'out': os.path.basename(abs_dest_path),
            'max-connection-per-server': max_connections,
            'split': max_connections,
            'min-split-size': '1M',
            'continue': 'true',
            'auto-file-renaming': 'false',
            'allow-overwrite': 'true',
            'check-certificate': 'false',  # NAS self-signed cert
        }

        # Apply bandwidth limit if specified (for background downloads)
        if max_download_limit:
            options['max-download-limit'] = max_download_limit
            logger.debug(f"Bandwidth limit: {max_download_limit}")

        if referer:
            options['referer'] = referer

        # Custom headers followed by the cookie, sent to aria2 as a single
        # newline-separated 'header' option value.
        header_lines = [f"{key}: {value}" for key, value in (headers or {}).items()]
        if cookies:
            header_lines.append(f"Cookie: {cookies}")
        if header_lines:
            options['header'] = '\n'.join(header_lines)

        return options

    def _describe_failure(self, gid: str) -> str:
        """Fetch aria2's error for a failed GID, log it, and return a user-facing message."""
        try:
            download = self.api.get_download(gid)
            raw_error = download.error_message or "Unknown error"
        except Exception:
            raw_error = "Download failed"

        # Parse error with structured handler
        parsed_error = parse_aria2_error(raw_error)

        logger.error(
            f"❌ Download failed [{parsed_error.category.value}]: {raw_error}")
        logger.debug(f"💡 {parsed_error.user_message}")
        logger.debug(
            f"📋 Suggested action: {parsed_error.suggested_action}")

        return format_error_for_user(parsed_error, include_technical=False)

    @staticmethod
    def _notify_progress(
        callback: Optional[Callable[[int, int], None]],
        completed: int,
        total: int
    ) -> None:
        """Invoke the progress callback, logging (not propagating) its errors.

        A raising callback must not be treated like an aria2 status error:
        that would re-run it on every poll and could stop the wait loop from
        ever returning.
        """
        if callback is None:
            return
        try:
            callback(completed, total)
        except Exception as cb_err:
            logger.warning(f"Progress callback raised an error (ignored): {cb_err}")

    def _wait_for_completion(
        self,
        gid: str,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        poll_interval: float = 0.5,
        timeout: Optional[float] = None,
        max_consecutive_errors: Optional[int] = 60
    ) -> bool:
        """
        Wait for download to complete and track progress

        Polls aria2 until the download is complete, has failed, or was removed.

        Args:
            gid: aria2 download GID
            progress_callback: Optional callback(downloaded_bytes, total_bytes)
            poll_interval: How often to check status (seconds)
            timeout: Optional overall limit in seconds; when exceeded the
                download is force-removed from aria2 and False is returned.
                None (default) waits indefinitely.
            max_consecutive_errors: Give up (return False) after this many
                status queries fail in a row, e.g. the daemon went away or the
                GID is unknown. None or 0 disables the limit.

        Returns:
            True if download completed successfully
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        last_completed, last_total = 0, 0
        consecutive_errors = 0

        while True:
            try:
                download = self.api.get_download(gid)
            except Exception as e:
                consecutive_errors += 1
                logger.error(f"Error checking download status: {e}")
                if max_consecutive_errors and consecutive_errors >= max_consecutive_errors:
                    logger.error(
                        f"Giving up on download {gid} after "
                        f"{consecutive_errors} consecutive status errors")
                    return False
                time.sleep(poll_interval)
                continue
            consecutive_errors = 0

            if download.is_complete:
                # Final progress callback
                self._notify_progress(
                    progress_callback, download.completed_length, download.total_length)
                return True
            if download.has_failed:
                logger.error(f"Download failed: {download.error_message}")
                return False
            if download.is_removed:
                logger.warning("Download was removed")
                return False

            # Only report progress when the values changed
            completed = download.completed_length
            total = download.total_length
            if (completed, total) != (last_completed, last_total):
                self._notify_progress(progress_callback, completed, total)
                last_completed, last_total = completed, total

            if deadline is not None and time.monotonic() >= deadline:
                logger.error(f"Download {gid} timed out after {timeout}s; cancelling")
                self.cancel_download(gid)
                return False

            time.sleep(poll_interval)

    def get_status(self, gid: str) -> Dict[str, Any]:
        """
        Get current status of a download

        Args:
            gid: aria2 download GID

        Returns:
            Status dict with keys: gid, status, completed, total, speed, progress, error.
            If the query itself fails, status is 'error', numeric fields are 0
            and error holds the exception text.
        """
        try:
            download = self.api.get_download(gid)

            progress = 0
            if download.total_length > 0:
                progress = (download.completed_length /
                            download.total_length) * 100

            return {
                'gid': gid,
                'status': download.status,
                'completed': download.completed_length,
                'total': download.total_length,
                'speed': download.download_speed,
                'progress': round(progress, 2),
                'error': download.error_message if download.has_failed else None
            }
        except Exception as e:
            return {
                'gid': gid,
                'status': 'error',
                'completed': 0,
                'total': 0,
                'speed': 0,
                'progress': 0,
                'error': str(e)
            }

    def cancel_download(self, gid: str) -> bool:
        """
        Cancel an active download

        Args:
            gid: aria2 download GID

        Returns:
            True if cancelled successfully (False on any lookup/removal error)
        """
        try:
            # Remove by GID using force remove
            download = self.api.get_download(gid)
            if download:
                download.remove(force=True)
                logger.debug(f"Cancelled download: {gid}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to cancel download {gid}: {e}")
            return False

    def get_global_stats(self) -> Dict[str, Any]:
        """
        Get global download statistics

        Returns:
            Stats dict with download speed, active downloads, etc.
            All values are 0 if the stats query fails.
        """
        try:
            stats = self.api.get_stats()
            return {
                'download_speed': getattr(stats, 'download_speed', 0),
                'upload_speed': getattr(stats, 'upload_speed', 0),
                'num_active': getattr(stats, 'num_active', 0),
                'num_waiting': getattr(stats, 'num_waiting', 0),
                'num_stopped': getattr(stats, 'num_stopped', 0)
            }
        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {
                'download_speed': 0,
                'upload_speed': 0,
                'num_active': 0,
                'num_waiting': 0,
                'num_stopped': 0
            }


def get_aria2_manager(
    host: str = 'localhost',
    port: int = 6800,
    secret: str = 'dkidownload_secret_2025'
) -> Optional[Aria2DownloadManager]:
    """
    Get or create global aria2 manager instance (singleton)

    The arguments are only used when the instance is first created; later
    calls return the existing instance regardless of the values passed.

    Args:
        host: aria2c RPC host
        port: aria2c RPC port
        secret: RPC secret token

    Returns:
        Aria2DownloadManager instance, or None if constructing it raised
    """
    global _aria2_manager

    if _aria2_manager is None:
        try:
            _aria2_manager = Aria2DownloadManager(
                host=host,
                port=port,
                secret=secret
            )
        except Exception as e:
            logger.error(f"Failed to create aria2 manager: {e}")
            return None

    return _aria2_manager


def reset_aria2_manager():
    """Reset the global manager instance (useful for testing)"""
    global _aria2_manager
    _aria2_manager = None