#!/usr/bin/env python3
"""
Session management script for VPN Access Server.

This script is called by OpenVPN via client-connect and client-disconnect
directives. It handles session tracking, time enforcement, and cleanup.

Environment variable script_type determines the action:
- "client-connect": Start new session
- "client-disconnect": End session and update duration
- "cleanup": Delete old session records (manual / cron mode)
"""

import logging
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple

# Add the access module to the Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from config import config
from db import db
from utils import setup_logging, get_env_vars, get_client_info, safe_exit, format_duration

# Default logger so the functions below are usable when this module is
# imported; main() rebinds it to the logger returned by setup_logging().
logger = logging.getLogger(__name__)


def _elapsed_seconds(start_time: datetime, now: Optional[datetime] = None) -> int:
    """
    Return the whole seconds elapsed between start_time and now.

    Naive datetimes are treated as UTC (the original code compared against
    datetime.utcnow()), so naive and timezone-aware values can both be used
    without raising TypeError.

    Args:
        start_time: Moment the session started
        now: Reference time; defaults to the current UTC time

    Returns:
        Elapsed time in seconds, truncated to an int
    """
    if now is None:
        now = datetime.now(timezone.utc)
    if start_time.tzinfo is None:
        start_time = start_time.replace(tzinfo=timezone.utc)
    if now.tzinfo is None:
        now = now.replace(tzinfo=timezone.utc)
    return int((now - start_time).total_seconds())


def start_session(user_id: int, username: str) -> bool:
    """
    Start a new session for the user.

    If the user already has an active session it is ended first; a failure
    to end it is logged but does not block the new session.

    Args:
        user_id: User ID from database
        username: Username for logging

    Returns:
        True if session started successfully, False otherwise
    """
    try:
        # Check if user already has an active session
        active_session = db.get_active_session(user_id)
        if active_session:
            stale_id = active_session['id']
            logger.warning(f"User {username} already has an active session (ID: {stale_id})")
            # End the previous session before starting a new one. Best effort:
            # report a failure instead of silently ignoring it, but still
            # let the user connect.
            if not db.end_session(stale_id):
                logger.error(f"Failed to end previous session {stale_id} for user {username}")

        # Create new session
        session_id = db.create_session(user_id)
        if session_id:
            logger.info(f"Session started for user {username} (Session ID: {session_id})")
            return True

        logger.error(f"Failed to create session for user {username}")
        return False

    except Exception as e:
        logger.error(f"Error starting session for user {username}: {e}")
        return False


def end_session(user_id: int, username: str) -> bool:
    """
    End the active session for the user.

    Args:
        user_id: User ID from database
        username: Username for logging

    Returns:
        True if session ended successfully (or there was no active session),
        False otherwise
    """
    try:
        # Find active session
        active_session = db.get_active_session(user_id)
        if not active_session:
            logger.warning(f"No active session found for user {username}")
            return True  # Not an error, session might have been cleaned up

        # End the session
        session_id = active_session['id']
        if not db.end_session(session_id):
            logger.error(f"Failed to end session {session_id} for user {username}")
            return False

        # The duration is only used for logging. The session is already
        # closed at this point, so a problem computing it must not turn a
        # successful end into a reported failure.
        try:
            duration = format_duration(_elapsed_seconds(active_session['start_time']))
        except (KeyError, TypeError, AttributeError, ValueError) as e:
            logger.warning(f"Could not compute duration for session {session_id}: {e}")
            duration = "unknown"

        logger.info(f"Session ended for user {username} "
                    f"(Session ID: {session_id}, Duration: {duration})")
        return True

    except Exception as e:
        logger.error(f"Error ending session for user {username}: {e}")
        return False


def _resolve_client(event: str) -> Tuple[int, str, Dict[str, Any]]:
    """
    Read the client details from the environment and look up the user.

    Calls safe_exit(1, ...) when the username is missing or the user is not
    in the database (safe_exit is presumed to terminate the process).

    Args:
        event: Event name used in the error message, e.g. "client-connect"

    Returns:
        Tuple of (user_id, username, client_info)
    """
    # Get environment variables
    env_vars = get_env_vars()
    client_info = get_client_info(env_vars)
    username = client_info['username']

    # Validate client information
    if not username:
        safe_exit(1, f"Username not provided in {event}", logger)

    # Get user from database
    user = db.get_user_by_username(username)
    if not user:
        safe_exit(1, f"User not found: {username}", logger)

    return user['id'], username, client_info


def handle_client_connect():
    """Handle client-connect event: exit 0 if a session was started, 1 otherwise."""
    try:
        user_id, username, client_info = _resolve_client("client-connect")

        logger.info(f"Client connect event for user: {username} "
                    f"from IP: {client_info['untrusted_ip']}")

        # Start session
        if start_session(user_id, username):
            safe_exit(0, f"Session started for user: {username}", logger)
        else:
            safe_exit(1, f"Failed to start session for user: {username}", logger)

    except Exception as e:
        logger.error(f"Error in client-connect handler: {e}")
        safe_exit(1, f"Client-connect error: {e}", logger)


def handle_client_disconnect():
    """Handle client-disconnect event: exit 0 if the session was ended, 1 otherwise."""
    try:
        user_id, username, _client_info = _resolve_client("client-disconnect")

        logger.info(f"Client disconnect event for user: {username}")

        # End session
        if end_session(user_id, username):
            safe_exit(0, f"Session ended for user: {username}", logger)
        else:
            safe_exit(1, f"Failed to end session for user: {username}", logger)

    except Exception as e:
        logger.error(f"Error in client-disconnect handler: {e}")
        safe_exit(1, f"Client-disconnect error: {e}", logger)


def cleanup_old_sessions() -> bool:
    """
    Clean up old session records.

    The retention period in days is read from the SESSION_CLEANUP_DAYS
    environment variable (default: 90).

    Returns:
        True if the cleanup ran, False if it failed (the error is logged)
    """
    try:
        cleanup_days = int(os.getenv('SESSION_CLEANUP_DAYS', '90'))
        deleted_count = db.cleanup_old_sessions(cleanup_days)
        logger.info(f"Cleaned up {deleted_count} old session records (older than {cleanup_days} days)")
        return True
    except Exception as e:
        logger.error(f"Error during session cleanup: {e}")
        return False


def main():
    """
    Main session management function.

    Sets up logging, validates configuration and database connectivity, then
    dispatches on the script_type environment variable. Exit code 2 is used
    for configuration, connectivity, unknown-script-type and internal errors.
    """
    global logger

    # Initialize logging
    logger = setup_logging(
        log_level=config.server.log_level,
        log_file=config.server.log_file
    )

    try:
        # Validate configuration
        if not config.validate():
            safe_exit(2, "Configuration validation failed", logger)

        # Check database connectivity
        if not db.health_check():
            safe_exit(2, "Database connection failed", logger)

        # Get script type from environment
        script_type = os.getenv('script_type', '').lower()

        if script_type == 'client-connect':
            handle_client_connect()
        elif script_type == 'client-disconnect':
            handle_client_disconnect()
        elif script_type == 'cleanup':
            # Manual cleanup mode (can be run via cron). Report a failed
            # cleanup through the exit code instead of always exiting 0.
            if cleanup_old_sessions():
                safe_exit(0, "Session cleanup completed", logger)
            else:
                safe_exit(1, "Session cleanup failed", logger)
        else:
            safe_exit(2, f"Unknown script type: {script_type}", logger)

    except KeyboardInterrupt:
        safe_exit(1, "Session management interrupted", logger)
    except Exception as e:
        logger.error(f"Unexpected error in session management: {e}")
        safe_exit(2, f"Internal error: {e}", logger)


if __name__ == "__main__":
    main()