448 lines
16 KiB
Python
Executable File
448 lines
16 KiB
Python
Executable File
"""
|
|
Aria2 Download Manager
|
|
|
|
High-level interface for downloading files via aria2c RPC
|
|
Handles progress tracking, error handling, and fallback logic
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import Optional, Callable, Dict, Any, Tuple
|
|
from urllib.parse import urlencode
|
|
|
|
import aria2p
|
|
|
|
from .error_parser import parse_aria2_error, format_error_for_user, Aria2ErrorCategory
|
|
|
|
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# Process-wide manager instance; stays None until get_aria2_manager()
# creates one, and is cleared again by reset_aria2_manager().
_aria2_manager: Optional['Aria2DownloadManager'] = None
|
|
|
|
|
|
class Aria2DownloadManager:
    """
    Manages file downloads via aria2c RPC

    Features:
    - Multi-connection downloads (16 connections per file by default,
      overridable with the ARIA2_MAX_CONNECTIONS_PER_FILE environment variable)
    - Progress callbacks for UI updates
    - Cookie and header support for authenticated downloads
    - Structured error messages via the error_parser helpers

    No retry logic is implemented in this class; any retrying is whatever
    the aria2c daemon itself performs.
    """

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6800,
        secret: str = 'dkidownload_secret_2025'
    ):
        """
        Initialize aria2 RPC client

        The connection probe is best-effort: a failing ``get_stats`` call is
        logged but does not prevent the manager from being created. Only a
        failure to build the aria2p client/API objects is re-raised.

        Args:
            host: aria2c RPC host (without scheme or port)
            port: aria2c RPC port
            secret: RPC secret token

        Raises:
            Exception: whatever aria2p raises while constructing the client
        """
        self.host = host
        self.port = port
        self.secret = secret

        try:
            # aria2p appends /jsonrpc itself and takes the port separately,
            # so the host URL must not contain the port.
            self.client = aria2p.Client(
                host=f"http://{host}",
                port=port,
                secret=secret
            )
            self.api = aria2p.API(self.client)

            # Best-effort connection test; failures here are non-fatal.
            try:
                self.api.get_stats()
                logger.debug(f"✅ Connected to aria2 RPC on {host}:{port}")
            except Exception as conn_err:
                error_str = str(conn_err)
                if "Expecting value" in error_str or "JSON" in error_str:
                    logger.warning(
                        f"aria2 RPC JSON parsing issue (non-fatal): {conn_err}")
                    logger.debug(
                        f"✅ aria2 RPC client initialized on {host}:{port} (connection assumed)")
                else:
                    # Surface the probe failure instead of hiding it behind
                    # a success-looking debug line.
                    logger.warning(
                        f"aria2 RPC connection test failed (non-fatal): {conn_err}")
                    logger.debug(
                        f"✅ aria2 RPC client initialized on {host}:{port}")

        except Exception as e:
            logger.error(f"Failed to connect to aria2c RPC: {e}")
            raise

    @staticmethod
    def _validate_request(method: str, post_data: Optional[Dict[str, str]]) -> None:
        """Raise ValueError unless the request is a plain GET without form data."""
        if method.upper() != 'GET':
            error_msg = (
                f"❌ aria2 only supports GET method, received: {method}. "
                f"POST requests are NOT supported. "
                f"This is a critical error - please check your download implementation."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if post_data:
            error_msg = (
                f"❌ aria2 cannot handle POST data. "
                f"POST requests with form data are NOT supported by aria2c RPC. "
                f"This indicates a bug in the calling code."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    @staticmethod
    def _build_options(
        abs_dest_path: str,
        headers: Optional[Dict[str, str]],
        cookies: Optional[str],
        referer: Optional[str],
        max_download_limit: Optional[str]
    ) -> Dict[str, str]:
        """Assemble the aria2 option dict for one download to abs_dest_path."""
        max_connections = os.getenv('ARIA2_MAX_CONNECTIONS_PER_FILE', '16')

        if max_download_limit:
            # With throttled bandwidth (e.g., 100KB/s), use only 1 connection
            # to avoid timeouts from too many slow connections.
            max_connections = '1'
            logger.debug(
                "Reduced connections to 1 for bandwidth-limited download")

        # Passing an absolute dir/out pair keeps aria2 from resolving the
        # destination relative to its own working directory.
        options = {
            'dir': os.path.dirname(abs_dest_path),
            'out': os.path.basename(abs_dest_path),
            'max-connection-per-server': max_connections,
            'split': max_connections,
            'min-split-size': '1M',
            'continue': 'true',
            'auto-file-renaming': 'false',
            'allow-overwrite': 'true',
            # NOTE(security): TLS verification is disabled because the NAS
            # uses a self-signed certificate.
            'check-certificate': 'false',
        }

        if max_download_limit:
            options['max-download-limit'] = max_download_limit
            logger.debug(f"Bandwidth limit: {max_download_limit}")

        # Custom headers first, then the cookie header, newline-separated in
        # a single option value.
        header_lines = [f"{key}: {value}" for key, value in (headers or {}).items()]
        if cookies:
            header_lines.append(f"Cookie: {cookies}")
        if header_lines:
            options['header'] = '\n'.join(header_lines)

        if referer:
            options['referer'] = referer

        return options

    def download_file(
        self,
        url: str,
        dest_path: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[str] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        referer: Optional[str] = None,
        method: str = 'GET',
        post_data: Optional[Dict[str, str]] = None,
        max_download_limit: Optional[str] = None
    ) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Download a file via aria2 (GET ONLY - NO POST SUPPORT)

        Blocks until the download finishes, fails, or is removed.

        IMPORTANT: only GET requests are accepted. Passing another method or
        any post_data is rejected before anything is created on disk or sent
        to aria2. The rejection is logged and reported through the returned
        failure tuple; it does not propagate as an exception.

        Args:
            url: Download URL (with query params if needed)
            dest_path: Full local path to save file
            headers: Custom HTTP headers dict
            cookies: Cookie string (format: "name=value; name2=value2")
            progress_callback: Optional callback(downloaded_bytes, total_bytes)
            referer: Referer header value
            method: HTTP method (must be 'GET', case-insensitive)
            post_data: NOT SUPPORTED - must be empty/None
            max_download_limit: Optional bandwidth limit (e.g., '100K', '1M');
                when set, the download uses a single connection

        Returns:
            Tuple[success: bool, error_message: Optional[str], gid: Optional[str]]
            The gid is only returned on success.
        """
        try:
            # Reject unsupported requests before any side effect.
            self._validate_request(method, post_data)

            # Ensure parent directory exists BEFORE aria2 tries to access it
            parent_dir = os.path.dirname(dest_path)
            if parent_dir:
                try:
                    os.makedirs(parent_dir, exist_ok=True)
                    logger.debug(f"Ensured directory exists: {parent_dir}")
                except Exception as mkdir_err:
                    logger.error(
                        f"Failed to create directory {parent_dir}: {mkdir_err}")
                    raise

            abs_dest_path = os.path.abspath(dest_path)
            options = self._build_options(
                abs_dest_path, headers, cookies, referer, max_download_limit)

            logger.debug(
                f"Starting aria2 download: {os.path.basename(abs_dest_path)}")
            logger.debug(f"URL: {url[:100]}...")
            logger.debug(f"Dest: {abs_dest_path}")
            logger.debug(f"Aria2 dir: {options['dir']}")
            logger.debug(f"Method: {method}")

            download = self.api.add_uris([url], options=options)
            gid = download.gid
            logger.debug(f"Download started with GID: {gid}")

            if self._wait_for_completion(
                gid=gid,
                progress_callback=progress_callback
            ):
                logger.debug(f"✅ Download completed: {abs_dest_path}")
                return True, None, gid

            # Fetch aria2's own error text; fall back to a generic message
            # when the download can no longer be queried.
            try:
                raw_error = self.api.get_download(gid).error_message or "Unknown error"
            except Exception:
                raw_error = "Download failed"

            parsed_error = parse_aria2_error(raw_error)
            logger.error(
                f"❌ Download failed [{parsed_error.category.value}]: {raw_error}")
            logger.debug(f"💡 {parsed_error.user_message}")
            logger.debug(
                f"📋 Suggested action: {parsed_error.suggested_action}")

            error_msg = format_error_for_user(
                parsed_error, include_technical=False)
            return False, error_msg, None

        except Exception as e:
            raw_error = str(e)
            parsed_error = parse_aria2_error(raw_error)

            # JSON decoding failures usually mean the RPC endpoint did not
            # answer with a valid JSON-RPC response.
            if "Expecting value" in raw_error or "JSON" in raw_error:
                error_msg = (
                    f"aria2 RPC communication error: {raw_error}. "
                    "This might be due to aria2c daemon not running or RPC secret mismatch. "
                    f"Check if aria2c process is active and RPC port {self.port} is accessible."
                )
            else:
                error_msg = format_error_for_user(
                    parsed_error, include_technical=True)

            logger.error(
                f"❌ Exception [{parsed_error.category.value}]: {raw_error}", exc_info=True)
            return False, error_msg, None

    @staticmethod
    def _notify_progress(
        progress_callback: Optional[Callable],
        completed: int,
        total: int
    ) -> None:
        """Invoke the progress callback, logging (not propagating) its errors."""
        if not progress_callback:
            return
        try:
            progress_callback(completed, total)
        except Exception as cb_err:
            # A broken UI callback must not stall or fail the download.
            logger.warning(f"Progress callback raised (ignored): {cb_err}")

    def _wait_for_completion(
        self,
        gid: str,
        progress_callback: Optional[Callable] = None,
        poll_interval: float = 0.5,
        max_consecutive_errors: Optional[int] = 30
    ) -> bool:
        """
        Wait for download to complete and track progress

        Polls aria2 every poll_interval seconds. Exceptions raised by the
        progress callback are logged and ignored. Errors while querying the
        download status are retried, but only up to max_consecutive_errors
        times in a row, so an unreachable daemon cannot block forever.

        Args:
            gid: aria2 download GID
            progress_callback: Optional callback(downloaded_bytes, total_bytes)
            poll_interval: How often to check status (seconds)
            max_consecutive_errors: Give up after this many back-to-back
                status-check failures; None disables the limit

        Returns:
            True if download completed successfully, False if it failed, was
            removed, or its status could not be read repeatedly
        """
        last_reported = (0, 0)
        consecutive_errors = 0

        while True:
            try:
                download = self.api.get_download(gid)

                if download.is_complete:
                    # Final progress callback
                    self._notify_progress(
                        progress_callback,
                        download.completed_length,
                        download.total_length)
                    return True

                if download.has_failed:
                    logger.error(f"Download failed: {download.error_message}")
                    return False

                if download.is_removed:
                    logger.warning("Download was removed")
                    return False

                # Only report when the numbers actually changed.
                current = (download.completed_length, download.total_length)
                if current != last_reported:
                    self._notify_progress(progress_callback, *current)
                    last_reported = current

                # A fully successful poll resets the failure streak.
                consecutive_errors = 0

            except Exception as e:
                consecutive_errors += 1
                logger.error(f"Error checking download status: {e}")
                if (max_consecutive_errors is not None
                        and consecutive_errors >= max_consecutive_errors):
                    logger.error(
                        f"Giving up on download {gid} after "
                        f"{consecutive_errors} consecutive status errors")
                    return False

            # Wait before next poll
            time.sleep(poll_interval)

    def get_status(self, gid: str) -> Dict[str, Any]:
        """
        Get current status of a download

        Args:
            gid: aria2 download GID

        Returns:
            Status dict with keys: gid, status, completed, total, speed,
            progress (percentage rounded to 2 decimals), error. If the
            download cannot be queried, a dict with only gid,
            status='error' and error is returned instead.
        """
        try:
            download = self.api.get_download(gid)
            total = download.total_length
            completed = download.completed_length

            # Total is 0 until aria2 knows the file size; avoid dividing by it.
            progress = (completed / total) * 100 if total > 0 else 0

            return {
                'gid': gid,
                'status': download.status,
                'completed': completed,
                'total': total,
                'speed': download.download_speed,
                'progress': round(progress, 2),
                'error': download.error_message if download.has_failed else None
            }
        except Exception as e:
            return {
                'gid': gid,
                'status': 'error',
                'error': str(e)
            }

    def cancel_download(self, gid: str) -> bool:
        """
        Cancel an active download

        Args:
            gid: aria2 download GID

        Returns:
            True if cancelled successfully, False if the download was not
            found or the removal raised
        """
        try:
            download = self.api.get_download(gid)
            if not download:
                return False

            # Force-remove so the download is dropped without a graceful stop.
            download.remove(force=True)
            logger.debug(f"Cancelled download: {gid}")
            return True
        except Exception as e:
            logger.error(f"Failed to cancel download {gid}: {e}")
            return False

    def get_global_stats(self) -> Dict[str, Any]:
        """
        Get global download statistics

        Returns:
            Stats dict with keys download_speed, upload_speed, num_active,
            num_waiting and num_stopped. Missing attributes, or any failure
            to query aria2, yield 0 for the affected values.
        """
        fields = (
            'download_speed',
            'upload_speed',
            'num_active',
            'num_waiting',
            'num_stopped',
        )
        try:
            stats = self.api.get_stats()
            return {name: getattr(stats, name, 0) for name in fields}
        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {name: 0 for name in fields}
|
|
|
|
|
|
def get_aria2_manager(
    host: str = 'localhost',
    port: int = 6800,
    secret: str = 'dkidownload_secret_2025'
) -> Optional[Aria2DownloadManager]:
    """
    Return the shared aria2 manager, creating it on first use (singleton)

    The connection arguments are only used when no instance exists yet; an
    already-created manager is returned as-is regardless of the arguments.
    A failed creation is logged and leaves the slot empty, so a later call
    tries again.

    Args:
        host: aria2c RPC host
        port: aria2c RPC port
        secret: RPC secret token

    Returns:
        Aria2DownloadManager instance or None if connection failed
    """
    global _aria2_manager

    # Fast path: an instance was already built by an earlier call.
    if _aria2_manager is not None:
        return _aria2_manager

    try:
        _aria2_manager = Aria2DownloadManager(host=host, port=port, secret=secret)
    except Exception as e:
        logger.error(f"Failed to create aria2 manager: {e}")
        return None

    return _aria2_manager
|
|
|
|
|
|
def reset_aria2_manager():
    """Drop the cached global manager so the next lookup builds a new one (useful for testing)"""
    # Rebinding through the module namespace is equivalent to a
    # `global` declaration followed by an assignment.
    globals()['_aria2_manager'] = None
|