# VPN/access/db.py
# Source snapshot: 251 lines, 8.7 KiB, last modified 2025-09-27 16:06:32 +00:00.
"""
Database operations for VPN Access Server.
Handles MySQL connections, connection pooling, and all database operations
for users, MAC addresses, and session management.
"""
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any, Tuple

import mysql.connector
from mysql.connector import pooling, Error

from config import config
class DatabaseManager:
    """Manages MySQL database connections and operations.

    Wraps a ``mysql.connector`` connection pool built from ``config.database``
    and exposes helpers for the ``employees``, ``mac_addresses`` and
    ``sessions`` tables. Timestamps written by this class are naive datetimes
    expressed in UTC; the pool pins each session's ``time_zone`` to ``+00:00``
    so the server interprets them the same way.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._connection_pool = None
        self._initialize_pool()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Replacement for the deprecated ``datetime.utcnow()``. tzinfo is
        stripped so the values stay identical to the naive UTC timestamps this
        class has always written.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _initialize_pool(self):
        """Create the MySQL connection pool from ``config.database``.

        Raises:
            mysql.connector.Error: if the pool cannot be created. The error is
                logged and re-raised so construction fails loudly.
        """
        pool_config = {
            'pool_name': 'vpn_access_pool',
            'pool_size': config.database.pool_size,
            'pool_reset_session': True,
            'host': config.database.host,
            'port': config.database.port,
            'database': config.database.database,
            'user': config.database.username,
            'password': config.database.password,
            'charset': config.database.charset,
            'autocommit': config.database.autocommit,
            'time_zone': '+00:00',  # Use UTC
        }
        try:
            self._connection_pool = pooling.MySQLConnectionPool(**pool_config)
        except Error as e:
            self.logger.error("Failed to initialize database pool: %s", e)
            raise
        self.logger.info("Database connection pool initialized successfully")

    @contextmanager
    def get_connection(self):
        """Borrow a pooled connection for the duration of a ``with`` block.

        If a ``mysql.connector.Error`` escapes the block, the open transaction
        is rolled back and the error is re-raised. The connection is always
        released back to the pool on exit.

        Yields:
            A pooled MySQL connection.

        Raises:
            mysql.connector.Error: pool exhaustion, connection failure, or any
                database error raised inside the ``with`` block.
        """
        connection = None
        try:
            connection = self._connection_pool.get_connection()
            yield connection
        except Error as e:
            self.logger.error("Database connection error: %s", e)
            if connection is not None:
                try:
                    connection.rollback()
                except Error as rollback_error:
                    # A dead connection cannot roll back; don't let that
                    # secondary failure mask the original error below.
                    self.logger.warning("Rollback failed: %s", rollback_error)
            raise
        finally:
            if connection is not None:
                # Always close: for a pooled connection close() hands it back to
                # the pool. Skipping it when is_connected() is False leaked the
                # pool slot, eventually exhausting the pool.
                try:
                    connection.close()
                except Error as close_error:
                    self.logger.warning("Failed to release connection: %s", close_error)

    def execute_query(self, query: str, params: Optional[Tuple] = None, fetch: Optional[str] = None) -> Any:
        """Execute a single SQL statement on a pooled connection.

        Args:
            query: SQL query string with ``%s`` placeholders.
            params: Values bound to the placeholders (default: none).
            fetch: ``'one'`` returns the first row and ``'all'`` returns every
                row. With ``None``, the choice depends on the statement:
                statements that produce a result set (SELECT, SHOW, ...) return
                all rows; any other statement is committed and its affected row
                count is returned.

        Returns:
            A row dict (or ``None``), a list of row dicts, or an ``int`` row count.
        """
        with self.get_connection() as connection:
            # Buffered, so fetch='one' on a multi-row result does not leave
            # unread rows behind (which would make cursor.close() raise).
            cursor = connection.cursor(dictionary=True, buffered=True)
            try:
                cursor.execute(query, params or ())
                if fetch == 'one':
                    return cursor.fetchone()
                if fetch == 'all' or cursor.with_rows:
                    return cursor.fetchall()
                connection.commit()
                return cursor.rowcount
            finally:
                cursor.close()

    # User operations
    def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """Return the active employee row with this username, or ``None``."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE username = %s AND is_active = TRUE
        """
        return self.execute_query(query, (username,), 'one')

    def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the active employee row with this id, or ``None``."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE id = %s AND is_active = TRUE
        """
        return self.execute_query(query, (user_id,), 'one')

    # MAC address operations
    def get_user_mac_addresses(self, user_id: int) -> List[str]:
        """Return the user's registered MAC addresses, oldest first."""
        query = """
            SELECT mac FROM mac_addresses
            WHERE employee_id = %s
            ORDER BY created_at
        """
        results = self.execute_query(query, (user_id,), 'all')
        return [row['mac'] for row in results] if results else []

    def is_mac_authorized(self, user_id: int, mac_address: str) -> bool:
        """Return True if ``mac_address`` is registered for ``user_id``.

        The comparison is done by the database exactly as stored; no
        normalisation of case or separators is applied here.
        """
        query = """
            SELECT COUNT(*) as count FROM mac_addresses
            WHERE employee_id = %s AND mac = %s
        """
        result = self.execute_query(query, (user_id, mac_address), 'one')
        return result['count'] > 0 if result else False

    def add_mac_address(self, user_id: int, mac_address: str) -> bool:
        """Register a MAC address for a user (upsert on duplicate key).

        Returns:
            True on success, False if a database error occurred (logged).
        """
        query = """
            INSERT INTO mac_addresses (employee_id, mac, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE updated_at = VALUES(updated_at)
        """
        now = self._utcnow()
        try:
            self.execute_query(query, (user_id, mac_address, now, now))
            return True
        except Error as e:
            self.logger.error("Failed to add MAC address: %s", e)
            return False

    # Session operations
    def create_session(self, user_id: int) -> Optional[int]:
        """Insert a new open session for ``user_id``.

        Returns:
            The new session id, or None if the insert failed (logged).
        """
        query = """
            INSERT INTO sessions (employee_id, start_time, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
        """
        now = self._utcnow()
        try:
            # The id must be read on the SAME connection that ran the INSERT:
            # LAST_INSERT_ID()/lastrowid are per-connection, and a second pooled
            # connection would report an unrelated (or zero) id.
            with self.get_connection() as connection:
                cursor = connection.cursor()
                try:
                    cursor.execute(query, (user_id, now, now, now))
                    session_id = cursor.lastrowid
                finally:
                    cursor.close()
                connection.commit()
            return session_id if session_id else None
        except Error as e:
            self.logger.error("Failed to create session: %s", e)
            return None

    def end_session(self, session_id: int) -> bool:
        """Close an open session, storing its end time and duration in seconds.

        Returns:
            True if an open session was closed; False if none matched (already
            ended or unknown id) or a database error occurred (logged).
        """
        query = """
            UPDATE sessions
            SET end_time = %s,
                duration = TIMESTAMPDIFF(SECOND, start_time, %s),
                updated_at = %s
            WHERE id = %s AND end_time IS NULL
        """
        now = self._utcnow()
        try:
            rows_affected = self.execute_query(query, (now, now, now, session_id))
            return rows_affected > 0
        except Error as e:
            self.logger.error("Failed to end session: %s", e)
            return False

    def get_user_daily_session_time(self, user_id: int, date: Optional[datetime] = None) -> int:
        """Total session time in seconds for sessions started on one UTC day.

        Sessions still open are counted up to the current time. Attribution is
        by start time: a session that began the previous day is not counted,
        and one starting on this day counts its full duration.

        Args:
            user_id: Employee id.
            date: Day to report, as a ``date`` or ``datetime`` (any time part is
                ignored). Defaults to today in UTC.

        Returns:
            Total seconds as an ``int`` (0 when there are no sessions).
        """
        # A single "now" keeps the default day and the open-session cut-off
        # consistent, even when the call straddles midnight.
        now = self._utcnow()
        if date is None:
            date = now.date()
        start_of_day = datetime.combine(date, datetime.min.time())
        end_of_day = start_of_day + timedelta(days=1)
        query = """
            SELECT COALESCE(SUM(
                CASE
                    WHEN end_time IS NULL THEN TIMESTAMPDIFF(SECOND, start_time, %s)
                    ELSE duration
                END
            ), 0) as total_seconds
            FROM sessions
            WHERE employee_id = %s
              AND start_time >= %s
              AND start_time < %s
        """
        result = self.execute_query(query, (now, user_id, start_of_day, end_of_day), 'one')
        # MySQL returns SUM() over integers as DECIMAL; honour the int contract.
        return int(result['total_seconds']) if result else 0

    def get_active_session(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the user's most recently started open session, or ``None``."""
        query = """
            SELECT id, employee_id, start_time, created_at
            FROM sessions
            WHERE employee_id = %s AND end_time IS NULL
            ORDER BY start_time DESC
            LIMIT 1
        """
        return self.execute_query(query, (user_id,), 'one')

    def cleanup_old_sessions(self, days_old: int = 90) -> int:
        """Delete session rows created more than ``days_old`` days ago.

        NOTE(review): the filter is only on created_at, so sessions still open
        (end_time IS NULL) past the cutoff are deleted too. Confirm this is
        intended.

        Returns:
            Number of rows deleted, or 0 on a database error (logged).
        """
        cutoff_date = self._utcnow() - timedelta(days=days_old)
        query = """
            DELETE FROM sessions
            WHERE created_at < %s
        """
        try:
            return self.execute_query(query, (cutoff_date,))
        except Error as e:
            self.logger.error("Failed to cleanup old sessions: %s", e)
            return 0

    def health_check(self) -> bool:
        """Return True if a trivial query round-trips to the database."""
        try:
            with self.get_connection() as connection:
                cursor = connection.cursor()
                try:
                    cursor.execute("SELECT 1")
                    cursor.fetchone()
                finally:
                    cursor.close()
            return True
        except Error as e:
            self.logger.error("Database health check failed: %s", e)
            return False
# Global database manager instance, shared by every importer of this module.
# NOTE(review): DatabaseManager() builds the connection pool right here at import
# time, and pool-creation errors (mysql.connector.Error) are re-raised, so
# importing this module fails when the database config is bad. mysql-connector
# pools presumably connect eagerly, so an unreachable server would fail too;
# confirm, and consider lazy construction if import-time I/O is a problem.
db = DatabaseManager()