260 lines
9.1 KiB
Python
260 lines
9.1 KiB
Python
"""
Database operations for VPN Access Server.

Handles MySQL connections, connection pooling, and all database operations
for users, MAC addresses, and session management.
"""
|
|
|
|
import logging
|
|
import mysql.connector
|
|
from mysql.connector import pooling, Error
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
from datetime import datetime, timedelta, timezone
|
|
from contextlib import contextmanager
|
|
|
|
from config import config
|
|
|
|
|
|
class DatabaseManager:
    """Manages MySQL database connections and operations.

    All queries go through a ``mysql.connector`` connection pool that is
    created when the manager is constructed. Timestamps written by this
    class are produced by :meth:`_now` so that every method reads and
    writes on the same clock.
    """

    # Fixed UTC+7 offset used for every timestamp this class writes.
    # NOTE(review): the pool sets the MySQL session time_zone to '+00:00'
    # while Python-side timestamps use UTC+7; this assumes the connector
    # sends the wall-clock fields of aware datetimes unchanged -- confirm.
    _LOCAL_TZ = timezone(timedelta(hours=7))

    def __init__(self):
        """Create the logger and open the connection pool.

        Raises:
            mysql.connector.Error: if the pool cannot be initialized.
        """
        self.logger = logging.getLogger(__name__)
        self._connection_pool = None
        self._initialize_pool()

    def _now(self) -> datetime:
        """Return the current time on the clock used for stored timestamps."""
        return datetime.now(self._LOCAL_TZ)

    def _initialize_pool(self):
        """Initialize the MySQL connection pool from ``config.database``.

        Raises:
            mysql.connector.Error: re-raised after logging when pool
                creation fails.
        """
        try:
            pool_config = {
                'pool_name': 'vpn_access_pool',
                'pool_size': config.database.pool_size,
                'pool_reset_session': True,
                'host': config.database.host,
                'port': config.database.port,
                'database': config.database.database,
                'user': config.database.username,
                'password': config.database.password,
                'charset': config.database.charset,
                'autocommit': config.database.autocommit,
                'time_zone': '+00:00',  # Use UTC
            }

            self._connection_pool = pooling.MySQLConnectionPool(**pool_config)
            self.logger.info("Database connection pool initialized successfully")

        except Error as e:
            self.logger.error(f"Failed to initialize database pool: {e}")
            raise

    @contextmanager
    def get_connection(self):
        """Context manager that checks a connection out of the pool.

        Yields:
            A pooled connection. It is closed on exit if it is still
            connected.

        Raises:
            mysql.connector.Error: re-raised after logging; the connection
                (if one was obtained) is rolled back first.
        """
        connection = None
        try:
            connection = self._connection_pool.get_connection()
            yield connection
        except Error as e:
            self.logger.error(f"Database connection error: {e}")
            if connection:
                connection.rollback()
            raise
        finally:
            if connection and connection.is_connected():
                connection.close()

    def execute_query(self, query: str, params: Optional[Tuple] = None,
                      fetch: Optional[str] = None) -> Any:
        """Execute a database query.

        Args:
            query: SQL query string.
            params: Query parameters tuple.
            fetch: 'one', 'all', or None. With None, a statement starting
                with SELECT returns all rows; anything else is committed.

        Returns:
            A single row dict (or None) for ``fetch='one'``, a list of row
            dicts for ``fetch='all'`` or a bare SELECT, otherwise the
            affected row count.
        """
        with self.get_connection() as connection:
            cursor = connection.cursor(dictionary=True)
            try:
                cursor.execute(query, params or ())

                if fetch == 'one':
                    return cursor.fetchone()
                elif fetch == 'all':
                    return cursor.fetchall()
                elif query.strip().upper().startswith('SELECT'):
                    return cursor.fetchall()
                else:
                    # Only the write path commits; reads never do.
                    connection.commit()
                    return cursor.rowcount

            finally:
                cursor.close()

    # User operations
    def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """Get details of an active user by username, or None if not found."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE username = %s AND is_active = TRUE
        """
        return self.execute_query(query, (username,), 'one')

    def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Get details of an active user by ID, or None if not found."""
        query = """
            SELECT id, username, employee_name, employee_email, is_active,
                   session_limit, created_at, updated_at
            FROM employees
            WHERE id = %s AND is_active = TRUE
        """
        return self.execute_query(query, (user_id,), 'one')

    def get_user_password(self, username: str) -> Optional[Dict[str, Any]]:
        """Get the stored PASSWORD column for an active user.

        Returns:
            A one-column row dict, or None when no active user matches.
        """
        query = """
            SELECT PASSWORD FROM VPN.EMPLOYEES
            WHERE VPN.EMPLOYEES.USERNAME = %s
            AND VPN.EMPLOYEES.IS_ACTIVE = TRUE
        """
        return self.execute_query(query, (username,), 'one')

    # MAC address operations
    def get_user_mac_addresses(self, user_id: int) -> List[str]:
        """Get all MAC addresses for a user, ordered by creation time."""
        query = """
            SELECT mac FROM mac_addresses
            WHERE employee_id = %s
            ORDER BY created_at
        """
        results = self.execute_query(query, (user_id,), 'all')
        return [row['mac'] for row in results] if results else []

    def is_mac_authorized(self, user_id: int, mac_address: str) -> bool:
        """Check if MAC address is authorized for user."""
        query = """
            SELECT COUNT(*) as count FROM mac_addresses
            WHERE employee_id = %s AND mac = %s
        """
        result = self.execute_query(query, (user_id, mac_address), 'one')
        return result['count'] > 0 if result else False

    def add_mac_address(self, user_id: int, mac_address: str) -> bool:
        """Add a MAC address for a user.

        An existing (duplicate-key) row only has ``updated_at`` refreshed.

        Returns:
            True on success, False if the database reported an error.
        """
        query = """
            INSERT INTO mac_addresses (employee_id, mac, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE updated_at = VALUES(updated_at)
        """
        now = self._now()
        try:
            self.execute_query(query, (user_id, mac_address, now, now))
            return True
        except Error as e:
            self.logger.error(f"Failed to add MAC address: {e}")
            return False

    # Session operations
    def create_session(self, user_id: int) -> Optional[int]:
        """Create a new session record.

        Returns:
            The new session's ID, or None if the insert failed.
        """
        query = """
            INSERT INTO sessions (employee_id, start_time, created_at, updated_at)
            VALUES (%s, %s, %s, %s)
        """
        now = self._now()
        try:
            # The generated ID is tracked per connection, so it must be read
            # from the same connection/cursor that ran the INSERT; asking a
            # second pooled connection for LAST_INSERT_ID() is unreliable.
            with self.get_connection() as connection:
                cursor = connection.cursor()
                try:
                    cursor.execute(query, (user_id, now, now, now))
                    connection.commit()
                    return cursor.lastrowid or None
                finally:
                    cursor.close()
        except Error as e:
            self.logger.error(f"Failed to create session: {e}")
            return None

    def end_session(self, session_id: int) -> bool:
        """End a session and calculate duration.

        Returns:
            True if an open session row was updated, False otherwise
            (including when the database reported an error).
        """
        query = """
            UPDATE sessions
            SET end_time = %s,
                duration = TIMESTAMPDIFF(SECOND, start_time, %s),
                updated_at = %s
            WHERE id = %s AND end_time IS NULL
        """
        now = self._now()
        try:
            rows_affected = self.execute_query(query, (now, now, now, session_id))
            return rows_affected > 0
        except Error as e:
            self.logger.error(f"Failed to end session: {e}")
            return False

    def get_user_daily_session_time(self, user_id: int,
                                    date: Optional[datetime] = None) -> int:
        """Get total session time for a user on a specific date (in seconds).

        Args:
            user_id: Employee whose sessions are summed.
            date: Day to report on; defaults to today on the clock used
                for stored timestamps.

        Returns:
            Total seconds. Sessions still open are counted up to now.
        """
        now = self._now()
        if date is None:
            date = now.date()

        start_of_day = datetime.combine(date, datetime.min.time())
        end_of_day = start_of_day + timedelta(days=1)

        query = """
            SELECT COALESCE(SUM(
                CASE
                    WHEN end_time IS NULL THEN TIMESTAMPDIFF(SECOND, start_time, %s)
                    ELSE duration
                END
            ), 0) as total_seconds
            FROM sessions
            WHERE employee_id = %s
            AND start_time >= %s
            AND start_time < %s
        """

        # Open sessions are measured against the same clock that wrote
        # start_time; mixing in UTC here would skew them by the offset.
        result = self.execute_query(
            query, (now, user_id, start_of_day, end_of_day), 'one'
        )
        return int(result['total_seconds']) if result else 0

    def get_active_session(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Get the most recently started open session for a user, if any."""
        query = """
            SELECT id, employee_id, start_time, created_at
            FROM sessions
            WHERE employee_id = %s AND end_time IS NULL
            ORDER BY start_time DESC
            LIMIT 1
        """
        return self.execute_query(query, (user_id,), 'one')

    def cleanup_old_sessions(self, days_old: int = 90) -> int:
        """Delete session records created more than ``days_old`` days ago.

        Returns:
            Number of rows deleted, or 0 if the database reported an error.
        """
        # Same clock as the created_at values written by create_session.
        cutoff_date = self._now() - timedelta(days=days_old)
        query = """
            DELETE FROM sessions
            WHERE created_at < %s
        """
        try:
            return self.execute_query(query, (cutoff_date,))
        except Error as e:
            self.logger.error(f"Failed to cleanup old sessions: {e}")
            return 0

    def health_check(self) -> bool:
        """Check database connectivity by running ``SELECT 1``."""
        try:
            with self.get_connection() as connection:
                cursor = connection.cursor()
                try:
                    cursor.execute("SELECT 1")
                    cursor.fetchone()
                finally:
                    cursor.close()
                return True
        except Error as e:
            self.logger.error(f"Database health check failed: {e}")
            return False
|
|
|
|
|
|
# Global database manager instance.
# Constructing it initializes the connection pool, so importing this module
# connects to the configured database (and raises if that fails).
db = DatabaseManager()