ge-tool/backend/services/aria2/download_manager.py
2025-12-10 13:41:43 +07:00

448 lines
16 KiB
Python
Executable File

"""
Aria2 Download Manager
High-level interface for downloading files via aria2c RPC
Handles progress tracking, error handling, and fallback logic
"""
import os
import time
import logging
import threading
from pathlib import Path
from typing import Optional, Callable, Dict, Any, Tuple
from urllib.parse import urlencode
import aria2p
from .error_parser import parse_aria2_error, format_error_for_user, Aria2ErrorCategory
# Logger for this module; records are emitted under the module's dotted name.
logger = logging.getLogger(__name__)

# Process-wide cached manager: filled lazily by get_aria2_manager() and
# cleared again by reset_aria2_manager().
_aria2_manager: Optional["Aria2DownloadManager"] = None
class Aria2DownloadManager:
    """
    Manages file downloads via aria2c RPC

    Features:
    - Segmented downloads (connections per file taken from the
      ARIA2_MAX_CONNECTIONS_PER_FILE env var, default 16; forced to 1 when a
      bandwidth limit is set)
    - Progress callbacks for UI updates
    - Cookie and header support for authenticated downloads
    - aria2 error strings translated into user-facing messages via error_parser
    """

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6800,
        secret: str = 'dkidownload_secret_2025'
    ):
        """
        Initialize the aria2 RPC client and probe the connection.

        The probe is best-effort: a failed probe is logged but does not raise,
        so a manager can be built even if the daemon is not answering yet.

        Args:
            host: aria2c RPC host (no scheme; "http://" is prepended)
            port: aria2c RPC port
            secret: RPC secret token
        Raises:
            Exception: re-raised if the aria2p Client/API objects cannot be built
        """
        # NOTE(review): the default secret is hard-coded in source; prefer
        # supplying it from configuration.
        self.host = host
        self.port = port
        self.secret = secret
        try:
            # aria2p appends /jsonrpc itself; the port is passed separately
            # rather than being embedded in the host URL.
            self.client = aria2p.Client(
                host=f"http://{host}",
                port=port,
                secret=secret
            )
            self.api = aria2p.API(self.client)
        except Exception as e:
            logger.error(f"Failed to connect to aria2c RPC: {e}")
            raise
        self._probe_connection()

    def _probe_connection(self) -> None:
        """Best-effort connectivity check against the RPC endpoint; never raises."""
        try:
            self.api.get_stats()
        except Exception as conn_err:
            error_str = str(conn_err)
            if "Expecting value" in error_str or "JSON" in error_str:
                logger.warning(
                    f"aria2 RPC JSON parsing issue (non-fatal): {conn_err}")
                logger.debug(
                    f"✅ aria2 RPC client initialized on {self.host}:{self.port} (connection assumed)")
            else:
                # Surface real connection failures instead of reporting success.
                logger.warning(
                    f"aria2 RPC connection test failed on {self.host}:{self.port} "
                    f"(non-fatal): {conn_err}")
        else:
            logger.debug(f"✅ Connected to aria2 RPC on {self.host}:{self.port}")

    def download_file(
        self,
        url: str,
        dest_path: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[str] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        referer: Optional[str] = None,
        method: str = 'GET',
        post_data: Optional[Dict[str, str]] = None,
        max_download_limit: Optional[str] = None
    ) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Download a file via aria2 (GET ONLY - NO POST SUPPORT) and block until
        it finishes.

        IMPORTANT: aria2 ONLY supports GET requests. POST requests raise ValueError.
        Do NOT pass method='POST' or post_data to this function.

        Args:
            url: Download URL (with query params if needed)
            dest_path: Full local path to save file (parent dirs are created)
            headers: Custom HTTP headers dict
            cookies: Cookie string (format: "name=value; name2=value2")
            progress_callback: Optional callback(downloaded_bytes, total_bytes)
            referer: Referer header value
            method: HTTP method (MUST be 'GET', others raise ValueError)
            post_data: NOT SUPPORTED - raises ValueError if provided
            max_download_limit: Optional bandwidth limit (e.g., '100K', '1M')
        Returns:
            Tuple[success, error_message, gid]. gid is only set on success;
            error_message is only set on failure.
        Raises:
            ValueError: If method is not 'GET' or post_data is provided
        """
        # Validate before the try block so the documented ValueError reaches
        # the caller instead of being turned into a (False, msg, None) tuple,
        # and so no directory is created for a request that can never run.
        self._validate_request(method, post_data)

        try:
            abs_dest_path = os.path.abspath(dest_path)
            abs_parent_dir = os.path.dirname(abs_dest_path)
            # Ensure the target directory exists before aria2 writes into it.
            self._ensure_directory(abs_parent_dir)

            options = self._build_options(
                abs_dest_path, headers, cookies, referer, max_download_limit)

            logger.debug(
                f"Starting aria2 download: {os.path.basename(abs_dest_path)}")
            logger.debug(f"URL: {url[:100]}...")
            logger.debug(f"Dest: {abs_dest_path}")
            logger.debug(f"Aria2 dir: {abs_parent_dir}")
            logger.debug(f"Method: {method}")

            download = self.api.add_uris([url], options=options)
            gid = download.gid
            logger.debug(f"Download started with GID: {gid}")

            if self._wait_for_completion(gid=gid, progress_callback=progress_callback):
                logger.debug(f"✅ Download completed: {abs_dest_path}")
                return True, None, gid
            return False, self._describe_failure(gid), None
        except Exception as e:
            raw_error = str(e)
            parsed_error = parse_aria2_error(raw_error)
            if "Expecting value" in raw_error or "JSON" in raw_error:
                # An unparsable RPC reply points at the daemon/endpoint, not
                # at the download itself.
                error_msg = (
                    f"aria2 RPC communication error: {raw_error}. "
                    "This might be due to aria2c daemon not running or RPC secret mismatch. "
                    f"Check if aria2c process is active and RPC port {self.port} is accessible."
                )
            else:
                error_msg = format_error_for_user(
                    parsed_error, include_technical=True)
            logger.error(
                f"❌ Exception [{parsed_error.category.value}]: {raw_error}", exc_info=True)
            return False, error_msg, None

    @staticmethod
    def _validate_request(method: str, post_data: Optional[Dict[str, str]]) -> None:
        """Raise ValueError for requests aria2 cannot perform (non-GET or a body)."""
        if method.upper() != 'GET':
            error_msg = (
                f"❌ aria2 only supports GET method, received: {method}. "
                f"POST requests are NOT supported. "
                f"This is a critical error - please check your download implementation."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        if post_data:
            error_msg = (
                f"❌ aria2 cannot handle POST data. "
                f"POST requests with form data are NOT supported by aria2c RPC. "
                f"This indicates a bug in the calling code."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    @staticmethod
    def _ensure_directory(directory: str) -> None:
        """Create directory (and parents) if missing; log and re-raise on failure."""
        try:
            os.makedirs(directory, exist_ok=True)
            logger.debug(f"Ensured directory exists: {directory}")
        except OSError as mkdir_err:
            logger.error(
                f"Failed to create directory {directory}: {mkdir_err}")
            raise

    def _build_options(
        self,
        abs_dest_path: str,
        headers: Optional[Dict[str, str]],
        cookies: Optional[str],
        referer: Optional[str],
        max_download_limit: Optional[str]
    ) -> Dict[str, str]:
        """Assemble the per-download aria2 option dict (all values are strings)."""
        max_connections = os.getenv('ARIA2_MAX_CONNECTIONS_PER_FILE', '16')
        if max_download_limit:
            # With throttled bandwidth (e.g., 100KB/s), use only 1 connection
            # to avoid timeouts from too many slow connections.
            max_connections = '1'
            logger.debug(
                "Reduced connections to 1 for bandwidth-limited download")

        options = {
            'dir': os.path.dirname(abs_dest_path),
            'out': os.path.basename(abs_dest_path),
            'max-connection-per-server': max_connections,
            'split': max_connections,
            'min-split-size': '1M',
            'continue': 'true',
            'auto-file-renaming': 'false',
            'allow-overwrite': 'true',
            # NAS self-signed cert. NOTE(review): this disables TLS
            # verification for every download made through this method.
            'check-certificate': 'false',
        }

        if max_download_limit:
            options['max-download-limit'] = max_download_limit
            logger.debug(f"Bandwidth limit: {max_download_limit}")

        # Custom headers plus the Cookie header are sent as one
        # newline-separated string in aria2's "header" option.
        header_lines = [f"{key}: {value}" for key, value in (headers or {}).items()]
        if cookies:
            header_lines.append(f"Cookie: {cookies}")
        if header_lines:
            options['header'] = '\n'.join(header_lines)

        if referer:
            options['referer'] = referer
        return options

    def _describe_failure(self, gid: str) -> str:
        """Fetch aria2's error for gid, log it, and return a user-facing message."""
        try:
            raw_error = self.api.get_download(gid).error_message or "Unknown error"
        except Exception:
            raw_error = "Download failed"
        parsed_error = parse_aria2_error(raw_error)
        logger.error(
            f"❌ Download failed [{parsed_error.category.value}]: {raw_error}")
        logger.debug(f"💡 {parsed_error.user_message}")
        logger.debug(
            f"📋 Suggested action: {parsed_error.suggested_action}")
        return format_error_for_user(parsed_error, include_technical=False)

    @staticmethod
    def _invoke_progress_callback(
        progress_callback: Optional[Callable[[int, int], None]],
        completed: int,
        total: int
    ) -> None:
        """Call progress_callback if set; its exceptions are logged, never propagated."""
        if progress_callback is None:
            return
        try:
            progress_callback(completed, total)
        except Exception:
            # A broken UI callback must not stall or fail the download loop.
            logger.exception("Progress callback raised; ignoring")

    def _wait_for_completion(
        self,
        gid: str,
        progress_callback: Optional[Callable[[int, int], None]] = None,
        poll_interval: float = 0.5,
        max_consecutive_errors: Optional[int] = 20
    ) -> bool:
        """
        Poll aria2 until the download finishes, reporting progress.

        Args:
            gid: aria2 download GID
            progress_callback: Optional callback(downloaded_bytes, total_bytes);
                called only when the values change, plus once on completion
            poll_interval: How often to check status (seconds)
            max_consecutive_errors: Give up (return False) after this many
                status queries fail in a row; None retries forever
        Returns:
            True if download completed successfully; False if it failed, was
            removed, or its status could not be read
        """
        last_completed = 0
        last_total = 0
        consecutive_errors = 0
        while True:
            try:
                download = self.api.get_download(gid)
            except Exception as e:
                consecutive_errors += 1
                logger.error(f"Error checking download status: {e}")
                if (max_consecutive_errors is not None
                        and consecutive_errors >= max_consecutive_errors):
                    logger.error(
                        f"Giving up on download {gid} after "
                        f"{consecutive_errors} consecutive status errors")
                    return False
                time.sleep(poll_interval)
                continue
            consecutive_errors = 0

            if download.is_complete:
                self._invoke_progress_callback(
                    progress_callback, download.completed_length, download.total_length)
                return True
            if download.has_failed:
                logger.error(f"Download failed: {download.error_message}")
                return False
            if download.is_removed:
                logger.warning("Download was removed")
                return False

            completed = download.completed_length
            total = download.total_length
            # Only report when something changed, to avoid flooding the UI.
            if progress_callback and (completed != last_completed or total != last_total):
                self._invoke_progress_callback(progress_callback, completed, total)
                last_completed = completed
                last_total = total

            time.sleep(poll_interval)

    def get_status(self, gid: str) -> Dict[str, Any]:
        """
        Get current status of a download

        Args:
            gid: aria2 download GID
        Returns:
            Dict with keys gid, status, completed, total, speed, progress
            (percent, 2 decimals) and error (set only when the download failed).
            If the status query itself raises, returns only gid,
            status='error' and error.
        """
        try:
            download = self.api.get_download(gid)
            progress = 0
            if download.total_length > 0:
                progress = (download.completed_length /
                            download.total_length) * 100
            return {
                'gid': gid,
                'status': download.status,
                'completed': download.completed_length,
                'total': download.total_length,
                'speed': download.download_speed,
                'progress': round(progress, 2),
                'error': download.error_message if download.has_failed else None
            }
        except Exception as e:
            return {
                'gid': gid,
                'status': 'error',
                'error': str(e)
            }

    def cancel_download(self, gid: str) -> bool:
        """
        Cancel an active download (force remove)

        Args:
            gid: aria2 download GID
        Returns:
            True if the remove call completed; False if the download was not
            found or any RPC error occurred (errors are logged)
        """
        try:
            download = self.api.get_download(gid)
            if download:
                download.remove(force=True)
                logger.debug(f"Cancelled download: {gid}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to cancel download {gid}: {e}")
            return False

    def get_global_stats(self) -> Dict[str, Any]:
        """
        Get global download statistics

        Returns:
            Dict with download_speed, upload_speed, num_active, num_waiting and
            num_stopped; all zeros if the stats query fails (error is logged)
        """
        try:
            stats = self.api.get_stats()
            return {
                'download_speed': getattr(stats, 'download_speed', 0),
                'upload_speed': getattr(stats, 'upload_speed', 0),
                'num_active': getattr(stats, 'num_active', 0),
                'num_waiting': getattr(stats, 'num_waiting', 0),
                'num_stopped': getattr(stats, 'num_stopped', 0)
            }
        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {
                'download_speed': 0,
                'upload_speed': 0,
                'num_active': 0,
                'num_waiting': 0,
                'num_stopped': 0
            }
def get_aria2_manager(
    host: str = 'localhost',
    port: int = 6800,
    secret: str = 'dkidownload_secret_2025'
) -> Optional[Aria2DownloadManager]:
    """
    Return the shared Aria2DownloadManager, building it on first use.

    The connection arguments only matter for the call that actually creates
    the instance; once one is cached it is returned as-is. If construction
    raises, the error is logged, nothing is cached, and None is returned, so
    a later call will try again.

    Args:
        host: aria2c RPC host
        port: aria2c RPC port
        secret: RPC secret token
    Returns:
        The cached Aria2DownloadManager, or None if it could not be created
    """
    # NOTE(review): no lock guards this check-then-create, so concurrent
    # first calls may each construct an instance.
    global _aria2_manager
    if _aria2_manager is not None:
        return _aria2_manager
    try:
        instance = Aria2DownloadManager(host=host, port=port, secret=secret)
    except Exception as e:
        logger.error(f"Failed to create aria2 manager: {e}")
        return None
    _aria2_manager = instance
    return _aria2_manager
def reset_aria2_manager():
    """
    Forget the cached module-level aria2 manager.

    The next get_aria2_manager() call will construct a fresh instance.
    Mainly useful in tests.
    """
    global _aria2_manager
    _aria2_manager = None