"""
Utility functions for VPN Access Server.

Shared helpers for logging setup, MAC address validation and normalization,
OpenVPN environment-variable processing, password hashing, and other common
operations.
"""
|
||
|
|
|
||
|
|
import hashlib
import hmac
import logging
import os
import re
import sys
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple
|
||
|
|
|
||
|
|
|
||
|
|
def setup_logging(log_level: str = 'INFO', log_file: Optional[str] = None) -> logging.Logger:
    """
    Set up logging configuration for the VPN Access Server.

    Configures the shared ``'vpn_access_server'`` logger with a console
    handler writing to stdout and, optionally, a file handler. Calling this
    function repeatedly is safe: a handler is only added when an equivalent
    one is not already attached, so log lines are not duplicated.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            case-insensitive. Unknown names fall back to INFO and a warning
            is logged.
        log_file: Optional log file path. Its parent directory is created if
            needed. If the file cannot be opened, a warning is logged and the
            logger keeps console output only.

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger('vpn_access_server')

    # Resolve the level by name. A plain getattr() raises on unknown names
    # and would also return non-level attributes such as BASIC_FORMAT.
    level = getattr(logging, str(log_level).upper(), None)
    level_is_valid = isinstance(level, int)
    logger.setLevel(level if level_is_valid else logging.INFO)

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler. Skip it if a stdout stream handler is already attached.
    # FileHandler subclasses StreamHandler, so it is excluded explicitly.
    has_console = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        and getattr(h, 'stream', None) is sys.stdout
        for h in logger.handlers
    )
    if not has_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    # File handler if specified
    if log_file:
        try:
            # Resolve to an absolute path. This lets us compare against
            # FileHandler.baseFilename (also absolute). It also guarantees a
            # non-empty dirname: for a bare name like 'vpn.log', dirname() is ''
            # and os.makedirs('') would raise.
            target = os.path.abspath(log_file)
            already_attached = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == target
                for h in logger.handlers
            )
            if not already_attached:
                # Ensure log directory exists
                os.makedirs(os.path.dirname(target), exist_ok=True)
                file_handler = logging.FileHandler(target)
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
        except (OSError, ValueError) as e:
            # Best-effort: file logging is optional; keep console output.
            logger.warning(f"Could not set up file logging: {e}")

    if not level_is_valid:
        logger.warning(f"Unknown log level {log_level!r}; using INFO")

    return logger
|
||
|
|
|
||
|
|
|
||
|
|
def validate_mac_address(mac: str) -> bool:
    """
    Check whether a string looks like a MAC address.

    ':' and '-' characters are ignored wherever they appear, and letter case
    is ignored. What remains must be exactly 12 hexadecimal digits.

    Args:
        mac: MAC address string

    Returns:
        True if valid MAC address format, False otherwise (including for an
        empty string or None)
    """
    if not mac:
        return False

    hex_digits = mac.lower().replace(':', '').replace('-', '')
    return len(hex_digits) == 12 and set(hex_digits) <= set('0123456789abcdef')
|
||
|
|
|
||
|
|
|
||
|
|
def normalize_mac_address(mac: str) -> Optional[str]:
    """
    Normalize a MAC address to colon-separated lowercase form.

    The output looks like ``aa:bb:cc:dd:ee:ff``. Any input that
    validate_mac_address() accepts is normalized: ':' and '-' separators are
    dropped and the 12 hex digits are regrouped in pairs joined by ':'.

    Args:
        mac: MAC address string, with ':' or '-' separators or none at all

    Returns:
        Normalized lowercase MAC address, or None if invalid
    """
    if validate_mac_address(mac):
        hex_digits = re.sub(r'[:-]', '', mac.lower())
        # Pair even-indexed digits with the following odd-indexed digit.
        octets = zip(hex_digits[0::2], hex_digits[1::2])
        return ':'.join(first + second for first, second in octets)
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def get_env_vars() -> Dict[str, str]:
    """
    Get OpenVPN environment variables.

    Looks up a fixed list of variable names in the process environment.
    Only variables set to a non-empty value are kept.

    Returns:
        Dictionary mapping each present variable name to its value, in the
        order of the lookup list below
    """
    # Common OpenVPN environment variables
    wanted = (
        'username', 'password', 'common_name', 'trusted_ip', 'trusted_port',
        'untrusted_ip', 'untrusted_port', 'ifconfig_pool_remote_ip',
        'script_type', 'time_ascii', 'time_unix', 'time_duration',
        'CLIENT_MAC',  # Custom variable set by client
    )

    lookups = ((name, os.getenv(name)) for name in wanted)
    return {name: value for name, value in lookups if value}
|
||
|
|
|
||
|
|
|
||
|
|
def hash_password(password: str, salt: Optional[str] = None) -> Tuple[str, str]:
    """
    Hash password using SHA-256 with salt.

    The digest is ``sha256((password + salt).encode()).hexdigest()``. When no
    salt is supplied, a fresh one is generated from 32 bytes of
    ``os.urandom`` and returned as 64 hex characters.

    SECURITY NOTE(review): a single salted SHA-256 pass is fast to
    brute-force. A key-derivation function (``hashlib.pbkdf2_hmac`` or
    ``hashlib.scrypt``) is the usual choice for password storage. Changing
    the scheme requires migrating stored hashes, because verify_password
    recomputes hashes through this function.

    Args:
        password: Plain text password
        salt: Optional salt (generated if not provided)

    Returns:
        Tuple of (hashed_password, salt)
    """
    effective_salt = os.urandom(32).hex() if salt is None else salt
    digest = hashlib.sha256((password + effective_salt).encode())
    return digest.hexdigest(), effective_salt
|
||
|
|
|
||
|
|
|
||
|
|
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
    """
    Verify password against hash.

    Re-hashes ``password`` with ``salt`` via hash_password() and compares the
    result with the stored hash in constant time. The run time therefore does
    not reveal how many leading characters of the hash matched.

    Args:
        password: Plain text password to verify
        hashed_password: Stored hash
        salt: Password salt

    Returns:
        True if password matches, False otherwise (including when the stored
        hash is not a string)
    """
    if not isinstance(hashed_password, str):
        return False

    computed_hash, _ = hash_password(password, salt)
    # Compare bytes: compare_digest raises TypeError for non-ASCII str input.
    return hmac.compare_digest(computed_hash.encode(), hashed_password.encode())
|
||
|
|
|
||
|
|
|
||
|
|
def format_duration(seconds: int) -> str:
    """
    Format duration in seconds to human-readable format.

    Args:
        seconds: Duration in seconds. Fractional values are truncated toward
            zero, so 90.7 renders as "1m 30s".

    Returns:
        Space-separated non-zero hour/minute/second components
        (e.g. "2h 30m 45s"). Returns "0s" for zero or negative input.
    """
    # Truncate to whole seconds. Floats would otherwise leak into the output
    # as "1.0m 30.700000000000003s".
    total = int(seconds)
    if total < 0:
        return "0s"

    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)

    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}m")
    if secs > 0 or not parts:  # Always show seconds if no other parts
        parts.append(f"{secs}s")

    return " ".join(parts)
|
||
|
|
|
||
|
|
|
||
|
|
def safe_exit(code: int, message: Optional[str] = None, logger: Optional[logging.Logger] = None):
    """
    Safely exit with proper logging.

    Emits ``message`` if it is non-empty, then terminates via
    ``sys.exit(code)``. With a logger, the message goes to ``logger.info``
    when ``code`` is 0 and to ``logger.error`` otherwise. Without a logger,
    it is printed to stdout when ``code`` is 0 and to stderr otherwise.

    Args:
        code: Exit code (0 for success, non-zero for failure)
        message: Optional exit message
        logger: Optional logger instance

    Raises:
        SystemExit: always, carrying ``code``
    """
    if message:
        succeeded = code == 0
        if logger:
            if succeeded:
                logger.info(message)
            else:
                logger.error(message)
        else:
            print(message, file=sys.stdout if succeeded else sys.stderr)

    sys.exit(code)
|
||
|
|
|
||
|
|
|
||
|
|
def get_client_info(env_vars: Dict[str, str]) -> Dict[str, Any]:
    """
    Extract and validate client information from environment variables.

    Args:
        env_vars: Dictionary of environment variables

    Returns:
        Dictionary with validated client information:
          - 'username', 'password', 'common_name', 'trusted_ip',
            'untrusted_ip': whitespace-stripped strings ('' when missing
            or None)
          - 'mac_address': normalized CLIENT_MAC, or None when it is
            missing or invalid
          - 'timestamp': naive datetime holding the current UTC time
          - 'is_valid': True only when username, password and a valid MAC
            address are all present
    """
    def _field(name: str) -> str:
        # `or ''` also covers keys that are present but mapped to None.
        return (env_vars.get(name) or '').strip()

    client_info = {
        'username': _field('username'),
        # NOTE(review): stripping also removes leading/trailing spaces that may
        # be part of a real password. Confirm this matches how hashes are made.
        'password': _field('password'),
        'common_name': _field('common_name'),
        'trusted_ip': _field('trusted_ip'),
        'untrusted_ip': _field('untrusted_ip'),
        'mac_address': None,
        # datetime.utcnow() is deprecated since Python 3.12. This produces the
        # same naive UTC value.
        'timestamp': datetime.now(timezone.utc).replace(tzinfo=None),
    }

    # Process MAC address
    raw_mac = _field('CLIENT_MAC')
    if raw_mac:
        client_info['mac_address'] = normalize_mac_address(raw_mac)

    # Validate required fields
    client_info['is_valid'] = bool(
        client_info['username'] and
        client_info['password'] and
        client_info['mac_address']
    )

    return client_info
|