"""
Database operations for VPN Access Server.

Handles MySQL connections, connection pooling, and all database operations
for users, MAC addresses, and session management.
"""

import logging
import mysql.connector
from mysql.connector import pooling, Error
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime, timedelta, timezone
from contextlib import contextmanager

from config import config


def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    The pool pins every MySQL session to ``time_zone = '+00:00'``, and all
    timestamps this module writes are naive UTC values. ``datetime.utcnow()``
    produces the same value but is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class DatabaseManager:
    """Manages MySQL database connections and operations.

    A connection pool is created at construction time. Each public method
    borrows one pooled connection for the duration of a single call and
    returns it to the pool afterwards.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._connection_pool = None
        self._initialize_pool()

    def _initialize_pool(self):
        """Initialize the MySQL connection pool from ``config.database``.

        Raises:
            mysql.connector.Error: if the pool cannot be created (the error is
                logged before being re-raised).
        """
        try:
            pool_config = {
                'pool_name': 'vpn_access_pool',
                'pool_size': config.database.pool_size,
                'pool_reset_session': True,
                'host': config.database.host,
                'port': config.database.port,
                'database': config.database.database,
                'user': config.database.username,
                'password': config.database.password,
                'charset': config.database.charset,
                'autocommit': config.database.autocommit,
                'time_zone': '+00:00'  # Use UTC
            }
            self._connection_pool = pooling.MySQLConnectionPool(**pool_config)
            self.logger.info("Database connection pool initialized successfully")
        except Error as e:
            self.logger.error("Failed to initialize database pool: %s", e)
            raise

    @contextmanager
    def get_connection(self):
        """Borrow a pooled connection for the duration of a ``with`` block.

        If any exception escapes the block, the current transaction is rolled
        back and the exception is re-raised. MySQL errors are also logged.
        The connection is always handed back to the pool, even when it has
        been disconnected. Skipping ``close()`` for a dead connection would
        permanently remove it from the pool and eventually exhaust it; the
        pool reconnects stale connections when it hands them out again.

        Yields:
            A pooled MySQL connection.

        Raises:
            mysql.connector.Error: if no connection can be obtained, or if a
                database operation inside the block fails.
        """
        connection = None
        try:
            connection = self._connection_pool.get_connection()
            yield connection
        except Exception as exc:
            if isinstance(exc, Error):
                self.logger.error("Database connection error: %s", exc)
            if connection is not None:
                self._rollback_quietly(connection)
            raise
        finally:
            if connection is not None:
                self._release_connection(connection)

    def _rollback_quietly(self, connection) -> None:
        """Roll back *connection*; log rather than raise so the original error propagates."""
        try:
            connection.rollback()
        except Error as e:
            self.logger.warning("Rollback failed: %s", e)

    def _release_connection(self, connection) -> None:
        """Return *connection* to the pool; log rather than raise on failure."""
        try:
            connection.close()
        except Error as e:
            self.logger.warning("Failed to return connection to pool: %s", e)

    def execute_query(self, query: str, params: Optional[Tuple] = None,
                      fetch: Optional[str] = None) -> Any:
        """
        Execute a single SQL statement on a pooled connection.

        Args:
            query: SQL query string using ``%s`` placeholders.
            params: Query parameters tuple (``None`` means no parameters).
            fetch: ``'one'`` to return the first row, ``'all'`` to return every
                row, or ``None`` to decide from the statement. Statements that
                produce a result set return all rows. Anything else is
                committed and its affected-row count is returned.

        Returns:
            One of the following:
            - a row dict (or ``None``) for ``fetch='one'``;
            - a list of row dicts for ``fetch='all'`` or result-set statements;
            - otherwise, the affected row count.
        """
        with self.get_connection() as connection:
            # buffered=True reads the whole result up front. Fetching only one
            # row of a multi-row result therefore cannot leave unread rows
            # behind, which would make cursor.close() raise.
            cursor = connection.cursor(dictionary=True, buffered=True)
            try:
                cursor.execute(query, params or ())
                if fetch == 'one':
                    return cursor.fetchone()
                # with_rows is true for any statement that produced a result
                # set (SELECT, WITH ..., SHOW, ...). Checking it is more
                # reliable than matching a "SELECT" prefix on the text.
                if fetch == 'all' or cursor.with_rows:
                    return cursor.fetchall()
                connection.commit()
                return cursor.rowcount
            finally:
                cursor.close()

    def _execute_insert(self, query: str, params: Tuple) -> Optional[int]:
        """Run an INSERT, commit it, and return the AUTO_INCREMENT id it generated.

        The id is read from the same cursor (and thus the same connection)
        that ran the INSERT. A follow-up ``SELECT LAST_INSERT_ID()`` is
        unreliable here, because that value is tracked per connection and a
        second borrow may get a different pooled connection.

        Returns:
            The generated id, or ``None`` if no id was generated.
        """
        with self.get_connection() as connection:
            cursor = connection.cursor()
            try:
                cursor.execute(query, params)
                connection.commit()
                return cursor.lastrowid or None
            finally:
                cursor.close()

    # User operations
    def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """Get an active user's details by username, or ``None`` if not found or inactive."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE username = %s AND is_active = TRUE
        """
        return self.execute_query(query, (username,), 'one')

    def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Get an active user's details by ID, or ``None`` if not found or inactive."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE id = %s AND is_active = TRUE
        """
        return self.execute_query(query, (user_id,), 'one')

    # MAC address operations
    def get_user_mac_addresses(self, user_id: int) -> List[str]:
        """Get all MAC addresses for a user, oldest first (empty list if none)."""
        query = """
            SELECT mac
            FROM mac_addresses
            WHERE employee_id = %s
            ORDER BY created_at
        """
        results = self.execute_query(query, (user_id,), 'all')
        return [row['mac'] for row in results] if results else []

    def is_mac_authorized(self, user_id: int, mac_address: str) -> bool:
        """Return ``True`` if *mac_address* is registered for *user_id*.

        The address is compared exactly as stored; no normalization is done here.
        """
        query = """
            SELECT COUNT(*) as count
            FROM mac_addresses
            WHERE employee_id = %s AND mac = %s
        """
        result = self.execute_query(query, (user_id, mac_address), 'one')
        return result['count'] > 0 if result else False

    def add_mac_address(self, user_id: int, mac_address: str) -> bool:
        """Add a MAC address for a user, or refresh ``updated_at`` if it already exists.

        Returns:
            ``True`` on success, ``False`` if a database error occurred (logged).
        """
        query = """
            INSERT INTO mac_addresses (employee_id, mac, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE updated_at = VALUES(updated_at)
        """
        now = _utc_now()
        try:
            self.execute_query(query, (user_id, mac_address, now, now))
            return True
        except Error as e:
            self.logger.error("Failed to add MAC address: %s", e)
            return False

    # Session operations
    def create_session(self, user_id: int) -> Optional[int]:
        """Create a new session record starting now (UTC).

        Returns:
            The new session's id, or ``None`` if the insert failed (logged).
        """
        query = """
            INSERT INTO sessions (employee_id, start_time, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
        """
        now = _utc_now()
        try:
            return self._execute_insert(query, (user_id, now, now, now))
        except Error as e:
            self.logger.error("Failed to create session: %s", e)
            return None

    def end_session(self, session_id: int) -> bool:
        """End an open session, storing its end time and duration in seconds.

        Returns:
            ``True`` if an open session was closed. ``False`` if the session
            does not exist, was already ended, or a database error occurred
            (logged).
        """
        query = """
            UPDATE sessions
            SET end_time = %s,
                duration = TIMESTAMPDIFF(SECOND, start_time, %s),
                updated_at = %s
            WHERE id = %s AND end_time IS NULL
        """
        now = _utc_now()
        try:
            rows_affected = self.execute_query(query, (now, now, now, session_id))
            return rows_affected > 0
        except Error as e:
            self.logger.error("Failed to end session: %s", e)
            return False

    def get_user_daily_session_time(self, user_id: int,
                                    date: Optional[datetime] = None) -> int:
        """
        Get total session time for a user on a specific UTC date (in seconds).

        Only the part of each session that falls inside
        ``[date 00:00, next day 00:00)`` is counted. A session crossing
        midnight is therefore split between the days it spans, and a session
        still open from a previous day counts toward today. Sessions without
        an ``end_time`` are treated as running until now.

        Args:
            user_id: Employee id.
            date: Day to report on. A ``date`` or ``datetime`` is accepted;
                only the calendar day is used. Defaults to today (UTC).

        Returns:
            Total whole seconds, as an ``int``.
        """
        if date is None:
            date = _utc_now().date()
        start_of_day = datetime.combine(date, datetime.min.time())
        end_of_day = start_of_day + timedelta(days=1)
        now = _utc_now()

        # Select every session that overlaps the day, then clip it to the day.
        query = """
            SELECT start_time, end_time
            FROM sessions
            WHERE employee_id = %s
              AND start_time < %s
              AND (end_time IS NULL OR end_time > %s)
        """
        rows = self.execute_query(query, (user_id, end_of_day, start_of_day), 'all')

        total_seconds = 0
        for row in rows or []:
            begin = max(row['start_time'], start_of_day)
            finish = min(row['end_time'] or now, end_of_day)
            if finish > begin:
                # int() truncates fractional seconds, like TIMESTAMPDIFF(SECOND, ...).
                total_seconds += int((finish - begin).total_seconds())
        return total_seconds

    def get_active_session(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Get the most recently started open session for a user, or ``None``."""
        query = """
            SELECT id, employee_id, start_time, created_at
            FROM sessions
            WHERE employee_id = %s AND end_time IS NULL
            ORDER BY start_time DESC
            LIMIT 1
        """
        return self.execute_query(query, (user_id,), 'one')

    def cleanup_old_sessions(self, days_old: int = 90) -> int:
        """Delete session records created more than *days_old* days ago.

        Returns:
            The number of deleted rows, or 0 if a database error occurred (logged).
        """
        cutoff_date = _utc_now() - timedelta(days=days_old)
        query = """
            DELETE FROM sessions
            WHERE created_at < %s
        """
        try:
            return self.execute_query(query, (cutoff_date,))
        except Error as e:
            self.logger.error("Failed to cleanup old sessions: %s", e)
            return 0

    def health_check(self) -> bool:
        """Check database connectivity by running ``SELECT 1``.

        Returns:
            ``True`` if the query succeeded, ``False`` otherwise (logged).
        """
        try:
            with self.get_connection() as connection:
                cursor = connection.cursor()
                try:
                    cursor.execute("SELECT 1")
                    cursor.fetchone()
                finally:
                    cursor.close()
            return True
        except Error as e:
            self.logger.error("Database health check failed: %s", e)
            return False


# Global database manager instance (creates the pool at import time).
db = DatabaseManager()