VPN/access/utils.py
2025-09-27 23:06:32 +07:00

244 lines
6.3 KiB
Python

"""
Utility functions for VPN Access Server.
Shared utility functions for logging, MAC address validation,
environment variable processing, and other common operations.
"""
import hashlib
import hmac
import logging
import os
import re
import sys
from datetime import datetime, timezone
from typing import Any, Dict, NoReturn, Optional, Tuple
def setup_logging(log_level: str = 'INFO', log_file: Optional[str] = None) -> logging.Logger:
    """
    Set up logging configuration for the VPN Access Server.

    Configures the 'vpn_access_server' logger with a stdout handler and,
    optionally, a file handler. Calling this again replaces (and closes)
    the handlers attached to that logger instead of stacking duplicates,
    so every record is emitted once per destination.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            case-insensitive. Unknown names fall back to INFO and a warning
            is logged.
        log_file: Optional log file path. Missing parent directories are
            created; a bare filename is opened relative to the current
            working directory. Failure to open the file is logged as a
            warning and console logging continues.

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger('vpn_access_server')

    # Resolve the level name; getattr may return a non-int attribute
    # (e.g. logging.BASIC_FORMAT), so check the type explicitly.
    level = getattr(logging, str(log_level).upper(), None)
    invalid_level = not isinstance(level, int)
    if invalid_level:
        level = logging.INFO
    logger.setLevel(level)

    # Remove handlers from a previous call; otherwise each call adds another
    # console handler and every message is printed multiple times.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if invalid_level:
        logger.warning("Unknown log level %r; defaulting to INFO", log_level)

    # File handler if specified (best effort: failures are logged, not raised)
    if log_file:
        try:
            # os.path.dirname('vpn.log') == '' and os.makedirs('') raises,
            # so only create a directory when the path actually has one.
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except Exception as e:
            logger.warning(f"Could not set up file logging: {e}")

    return logger
# Six hex octets, either unseparated or separated by one consistent ':' or '-'.
# The back-reference \1 forces every separator to equal the first one.
_MAC_PATTERN = re.compile(r'[0-9a-f]{2}([:-]?)[0-9a-f]{2}(?:\1[0-9a-f]{2}){4}')


def validate_mac_address(mac: str) -> bool:
    """
    Validate MAC address format.

    Accepted forms (case-insensitive):
        aa:bb:cc:dd:ee:ff, aa-bb-cc-dd-ee-ff, aabbccddeeff

    Separators must appear between every pair of hex digits and must all be
    the same character, so malformed strings such as "a:abbccddeeff" or
    "aa:bb-cc:dd:ee:ff" are rejected.

    Args:
        mac: MAC address string

    Returns:
        True if valid MAC address format, False otherwise (including for
        empty or non-string input)
    """
    if not isinstance(mac, str):
        return False
    # fullmatch (not match + '$') so a trailing newline is not accepted.
    return _MAC_PATTERN.fullmatch(mac.lower()) is not None
def normalize_mac_address(mac: str) -> Optional[str]:
    """
    Convert a MAC address into lowercase, colon-separated form.

    Args:
        mac: MAC address string, with ':' / '-' separators or none at all

    Returns:
        The address as 'xx:xx:xx:xx:xx:xx' (lowercase hex), or None when
        validate_mac_address rejects the input
    """
    if validate_mac_address(mac):
        digits = mac.lower().replace(':', '').replace('-', '')
        # Zipping an iterator with itself yields consecutive character pairs.
        cursor = iter(digits)
        return ':'.join(first + second for first, second in zip(cursor, cursor))
    return None
def get_env_vars() -> Dict[str, str]:
    """
    Collect the OpenVPN-related environment variables that are set.

    Variables that are unset or set to an empty string are omitted.

    Returns:
        Mapping of variable name to value, in the order listed below
    """
    wanted = (
        'username', 'password', 'common_name', 'trusted_ip', 'trusted_port',
        'untrusted_ip', 'untrusted_port', 'ifconfig_pool_remote_ip',
        'script_type', 'time_ascii', 'time_unix', 'time_duration',
        'CLIENT_MAC',  # custom variable supplied by the client
    )
    lookups = zip(wanted, map(os.environ.get, wanted))
    return {name: value for name, value in lookups if value}
def hash_password(password: str, salt: Optional[str] = None) -> Tuple[str, str]:
    """
    Compute a salted SHA-256 hex digest of a password.

    The digest is sha256((password + salt) encoded as UTF-8).

    Args:
        password: Plain text password
        salt: Salt string to reuse; when None, a fresh 64-character hex salt
            is drawn from os.urandom(32)

    Returns:
        Tuple of (hex digest, salt actually used)

    NOTE(review): one SHA-256 round is fast to brute-force. Consider
    hashlib.pbkdf2_hmac or hashlib.scrypt, but only together with a
    migration plan for already-stored hashes.
    """
    chosen_salt = salt if salt is not None else os.urandom(32).hex()
    digest = hashlib.sha256((password + chosen_salt).encode('utf-8'))
    return digest.hexdigest(), chosen_salt
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
    """
    Verify a password against a stored salted hash.

    Args:
        password: Plain text password to verify
        hashed_password: Stored hex digest produced by hash_password
        salt: Salt that was used to produce hashed_password

    Returns:
        True if the password matches, False otherwise (including when the
        stored hash is not a string)
    """
    if not isinstance(hashed_password, str):
        return False
    computed_hash, _ = hash_password(password, salt)
    # Constant-time comparison: a plain == can leak, through response timing,
    # how many leading characters matched. Compare bytes so non-ASCII stored
    # values do not make compare_digest raise TypeError.
    return hmac.compare_digest(
        computed_hash.encode('utf-8'),
        hashed_password.encode('utf-8'),
    )
def format_duration(seconds: int) -> str:
    """
    Render a number of seconds as a compact "Xh Ym Zs" string.

    Zero-valued hour/minute components are omitted. Seconds are shown when
    non-zero, or when nothing else would be printed. Negative input yields "0s".

    Args:
        seconds: Duration in seconds

    Returns:
        Human-readable duration, e.g. "2h 30m 45s"
    """
    if seconds < 0:
        return "0s"
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    pieces = [f"{amount}{unit}" for amount, unit in ((hours, "h"), (minutes, "m")) if amount > 0]
    if secs > 0 or not pieces:
        pieces.append(f"{secs}s")
    return " ".join(pieces)
def safe_exit(code: int, message: Optional[str] = None, logger: Optional[logging.Logger] = None) -> NoReturn:
    """
    Report an optional message, then terminate via sys.exit(code).

    With a logger, the message is logged at INFO for code 0 and at ERROR
    otherwise. Without a logger, it is printed to stdout for code 0 and to
    stderr otherwise.

    Args:
        code: Exit code (0 for success, non-zero for failure)
        message: Optional exit message; empty/None means nothing is reported
        logger: Optional logger instance

    Raises:
        SystemExit: always, carrying ``code``
    """
    if message:
        if logger is not None:
            emit = logger.info if code == 0 else logger.error
            emit(message)
        else:
            print(message, file=sys.stderr if code != 0 else sys.stdout)
    sys.exit(code)
def get_client_info(env_vars: Dict[str, str]) -> Dict[str, Any]:
    """
    Extract and validate client information from environment variables.

    Args:
        env_vars: Dictionary of environment variables (e.g. from get_env_vars);
            missing keys or None values are treated as empty strings

    Returns:
        Dictionary with keys username, password, common_name, trusted_ip,
        untrusted_ip (whitespace-stripped strings), mac_address (normalized
        MAC or None), timestamp (naive datetime in UTC) and is_valid (True
        only when username, password and a valid MAC are all present)
    """
    def _field(name: str) -> str:
        # `or ''` also covers keys explicitly mapped to None.
        return (env_vars.get(name) or '').strip()

    client_info = {
        'username': _field('username'),
        # NOTE(review): stripping alters passwords that legitimately start or
        # end with whitespace; confirm this matches how passwords are stored.
        'password': _field('password'),
        'common_name': _field('common_name'),
        'trusted_ip': _field('trusted_ip'),
        'untrusted_ip': _field('untrusted_ip'),
        'mac_address': None,
        # Same naive-UTC value as the deprecated datetime.utcnow().
        'timestamp': datetime.now(timezone.utc).replace(tzinfo=None),
    }

    # Process MAC address
    raw_mac = _field('CLIENT_MAC')
    if raw_mac:
        client_info['mac_address'] = normalize_mac_address(raw_mac)

    # Validate required fields
    client_info['is_valid'] = bool(
        client_info['username'] and
        client_info['password'] and
        client_info['mac_address']
    )
    return client_info