"""WebSocket endpoint for real-time updates."""

import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

logger = logging.getLogger(__name__)

router = APIRouter()


def _utc_timestamp() -> str:
    """
    Return the current UTC time as a naive ISO-8601 string.

    The value has the same shape as ``datetime.utcnow().isoformat()`` (no
    UTC offset suffix), so the wire format of outgoing messages is
    unchanged, but it avoids ``utcnow``, which is deprecated since
    Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class ConnectionManager:
    """Manager for WebSocket connections."""

    def __init__(self):
        """Initialize connection manager."""
        # Every accepted socket that has not yet been removed via disconnect().
        self.active_connections: Set[WebSocket] = set()

    async def connect(self, websocket: WebSocket) -> None:
        """
        Accept and register a new WebSocket connection.

        Args:
            websocket: WebSocket connection
        """
        await websocket.accept()
        self.active_connections.add(websocket)
        logger.info(
            "WebSocket connected. Total connections: %d",
            len(self.active_connections),
        )

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Remove a WebSocket connection.

        Removing a socket that is not registered is a no-op, so this can be
        called more than once for the same connection.

        Args:
            websocket: WebSocket connection
        """
        self.active_connections.discard(websocket)
        logger.info(
            "WebSocket disconnected. Total connections: %d",
            len(self.active_connections),
        )

    async def send_personal_message(self, message: dict, websocket: WebSocket) -> None:
        """
        Send a message to a specific WebSocket.

        A failed send is logged and the socket is unregistered; the error is
        not propagated to the caller.

        Args:
            message: Message to send
            websocket: WebSocket connection
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error("Error sending message: %s", e)
            self.disconnect(websocket)

    async def broadcast(self, message: dict) -> None:
        """
        Broadcast a message to all connected WebSockets.

        Connections whose send fails are logged and unregistered once the
        loop has finished.

        Args:
            message: Message to broadcast
        """
        disconnected = set()

        # Iterate over a snapshot: every await below yields to the event loop,
        # where another task may connect or disconnect a client. Mutating the
        # live set mid-iteration would raise
        # "RuntimeError: Set changed size during iteration".
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error("Error broadcasting to connection: %s", e)
                disconnected.add(connection)

        # Clean up disconnected clients
        for connection in disconnected:
            self.disconnect(connection)


# Global connection manager instance
manager = ConnectionManager()


@router.websocket("")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time scan updates.

    Registers the connection, sends a welcome message, then handles incoming
    text frames until the client disconnects or an unexpected error occurs.
    The connection is always unregistered on exit.

    Args:
        websocket: WebSocket connection
    """
    await manager.connect(websocket)

    try:
        # Send welcome message
        await manager.send_personal_message({
            'type': 'connected',
            'message': 'Connected to network scanner',
            'timestamp': _utc_timestamp()
        }, websocket)

        # Keep connection alive and handle incoming messages
        while True:
            try:
                # Receive messages from client
                data = await websocket.receive_text()

                # Only the decode step is guarded by the JSON error handler;
                # errors raised while handling the message fall through to
                # the generic handler below.
                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    await manager.send_personal_message({
                        'type': 'error',
                        'message': 'Invalid JSON format',
                        'timestamp': _utc_timestamp()
                    }, websocket)
                else:
                    await handle_client_message(message, websocket)

            except WebSocketDisconnect:
                break
            except Exception as e:
                logger.error("Error in WebSocket loop: %s", e)
                break

    finally:
        manager.disconnect(websocket)


async def handle_client_message(message: dict, websocket: WebSocket):
    """
    Handle messages from client.

    Recognized message types:
        - 'ping': answered with a 'pong' message.
        - 'subscribe': acknowledged with a 'subscribed' message echoing the
          given 'scan_id'; a request without a scan_id is only logged.
    Any other type is logged as unknown.

    Args:
        message: Client message (decoded JSON; expected to be an object)
        websocket: WebSocket connection
    """
    # Valid JSON is not necessarily an object (e.g. a list or a bare string).
    # Report that to the client instead of failing on `.get`, which would
    # make the endpoint drop the connection.
    if not isinstance(message, dict):
        await manager.send_personal_message({
            'type': 'error',
            'message': 'Message must be a JSON object',
            'timestamp': _utc_timestamp()
        }, websocket)
        return

    message_type = message.get('type')

    if message_type == 'ping':
        # Respond to ping
        await manager.send_personal_message({
            'type': 'pong',
            'timestamp': _utc_timestamp()
        }, websocket)
    elif message_type == 'subscribe':
        # Handle subscription requests
        scan_id = message.get('scan_id')
        if scan_id:
            await manager.send_personal_message({
                'type': 'subscribed',
                'scan_id': scan_id,
                'timestamp': _utc_timestamp()
            }, websocket)
        else:
            logger.warning("Subscribe request missing scan_id")
    else:
        logger.warning("Unknown message type: %s", message_type)


async def broadcast_scan_update(scan_id: int, update_type: str, data: dict):
    """
    Broadcast scan update to all connected clients.

    Args:
        scan_id: Scan ID
        update_type: Type of update (sent as the message 'type')
        data: Update data
    """
    message = {
        'type': update_type,
        'scan_id': scan_id,
        'data': data,
        'timestamp': _utc_timestamp()
    }
    await manager.broadcast(message)


async def send_scan_progress(scan_id: int, progress: float, current_host: Optional[str] = None):
    """
    Send scan progress update.

    Args:
        scan_id: Scan ID
        progress: Progress value (0.0 to 1.0)
        current_host: Currently scanning host
    """
    await broadcast_scan_update(scan_id, 'scan_progress', {
        'progress': progress,
        'current_host': current_host
    })


async def send_host_discovered(scan_id: int, host_data: dict):
    """
    Send host discovered notification.

    Args:
        scan_id: Scan ID
        host_data: Host information
    """
    await broadcast_scan_update(scan_id, 'host_discovered', host_data)


async def send_scan_completed(scan_id: int, summary: dict):
    """
    Send scan completed notification.

    Args:
        scan_id: Scan ID
        summary: Scan summary
    """
    await broadcast_scan_update(scan_id, 'scan_completed', summary)


async def send_scan_failed(scan_id: int, error: str):
    """
    Send scan failed notification.

    Args:
        scan_id: Scan ID
        error: Error message
    """
    await broadcast_scan_update(scan_id, 'scan_failed', {'error': error})