"""Scan API endpoints."""
import asyncio
import logging
from typing import List, Set

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session

from app.database import get_db
from app.schemas import (
    ScanConfigRequest,
    ScanResponse,
    ScanStatusResponse,
    ScanStartResponse,
    ScanStatus as ScanStatusEnum,
)
from app.services.scan_service import ScanService
from app.api.endpoints.websocket import (
    send_scan_progress,
    send_host_discovered,
    send_scan_completed,
    send_scan_failed,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Strong references to in-flight scan tasks. The asyncio event loop keeps only
# weak references to tasks, so a task that is not referenced elsewhere may be
# garbage-collected before it finishes. Each task removes itself from this set
# when it completes (see _track_scan_task).
_running_scan_tasks: Set[asyncio.Task] = set()


def _track_scan_task(scan_id: int, task: asyncio.Task) -> None:
    """Keep ``task`` alive until it finishes and log it if it fails.

    Args:
        scan_id: ID of the scan the task is running (used in the log message).
        task: The asyncio task executing the scan.
    """
    _running_scan_tasks.add(task)

    def _on_done(done: asyncio.Task) -> None:
        _running_scan_tasks.discard(done)
        if done.cancelled():
            return
        # Calling exception() marks the error as retrieved, replacing asyncio's
        # "Task exception was never retrieved" GC-time warning with an explicit
        # log entry; the task's own state is left untouched.
        exc = done.exception()
        if exc is not None:
            logger.error("Scan %s task failed", scan_id, exc_info=exc)

    task.add_done_callback(_on_done)


def _common_scan_fields(scan) -> dict:
    """Return the response fields shared by ScanResponse and ScanStatusResponse.

    Args:
        scan: Scan record returned by ScanService; must expose the attributes
            read below. ``scan.status`` must be a valid ScanStatusEnum value.

    Returns:
        Dict of keyword arguments for the response schema constructors.
    """
    return {
        "id": scan.id,
        "started_at": scan.started_at,
        "completed_at": scan.completed_at,
        "scan_type": scan.scan_type,
        "network_range": scan.network_range,
        "status": ScanStatusEnum(scan.status),
        "hosts_found": scan.hosts_found,
        "ports_scanned": scan.ports_scanned,
        "error_message": scan.error_message,
    }


@router.post("/start", response_model=ScanStartResponse, status_code=202)
async def start_scan(
    config: ScanConfigRequest,
    db: Session = Depends(get_db)
):
    """
    Start a new network scan.

    The scan record is created with the request's session. The scan itself
    then runs as an asyncio task that opens its own database session instead
    of reusing ``db``. Progress updates from the scan are forwarded to
    WebSocket clients.

    Args:
        config: Scan configuration
        db: Database session

    Returns:
        Scan start response with scan ID (status PENDING)

    Raises:
        HTTPException: 400 if a ``ValueError`` is raised while starting the
            scan; 500 for any other error while starting the scan.
    """
    try:
        scan_service = ScanService(db)

        # Create scan record
        scan = scan_service.create_scan(config)
        scan_id = scan.id

        async def progress_callback(update: dict):
            """Forward one scan update to WebSocket clients.

            ``update['type']`` selects the notification; unrecognised types are
            ignored. ``update['scan_id']`` defaults to this scan's ID.
            """
            update_type = update.get('type')
            update_scan_id = update.get('scan_id', scan_id)

            if update_type == 'scan_progress':
                await send_scan_progress(
                    update_scan_id,
                    update.get('progress', 0),
                    update.get('current_host'),
                )
            elif update_type == 'host_discovered':
                await send_host_discovered(update_scan_id, update.get('host'))
            elif update_type == 'scan_completed':
                await send_scan_completed(
                    update_scan_id,
                    {'hosts_found': update.get('hosts_found', 0)},
                )
            elif update_type == 'scan_failed':
                await send_scan_failed(
                    update_scan_id, update.get('error', 'Unknown error')
                )

        async def run_scan_task():
            """Execute the scan with a dedicated session, always closing it."""
            # Imported here rather than at module level, as in the original code.
            from app.database import SessionLocal
            scan_db = SessionLocal()
            try:
                scan_service_bg = ScanService(scan_db)
                await scan_service_bg.execute_scan(scan_id, config, progress_callback)
            finally:
                scan_db.close()

        task = asyncio.create_task(run_scan_task())
        _track_scan_task(scan_id, task)
        # NOTE(review): ``scan_service`` is a per-request instance, and
        # cancel_scan() builds a fresh ScanService. The task is therefore only
        # found for cancellation if ``active_scans`` is shared across
        # instances (e.g. a class attribute) — confirm in ScanService.
        scan_service.active_scans[scan_id] = task

        logger.info("Started scan %s for %s", scan_id, config.network_range)

        return ScanStartResponse(
            scan_id=scan_id,
            message=f"Scan started for network {config.network_range}",
            status=ScanStatusEnum.PENDING
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error starting scan: %s", e)
        raise HTTPException(status_code=500, detail="Failed to start scan") from e


@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
def get_scan_status(scan_id: int, db: Session = Depends(get_db)):
    """
    Get the status of a specific scan.

    Args:
        scan_id: Scan ID
        db: Database session

    Returns:
        Scan status information. ``progress`` is a coarse estimate only:
        1.0 when completed, 0.5 when running with at least one host found,
        otherwise 0.0.

    Raises:
        HTTPException: 404 if the scan does not exist.
    """
    scan_service = ScanService(db)
    scan = scan_service.get_scan_status(scan_id)

    if not scan:
        raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")

    # Rough estimate; real-time progress comes from WebSocket.
    progress = 0.0
    if scan.status == ScanStatusEnum.COMPLETED.value:
        progress = 1.0
    elif scan.status == ScanStatusEnum.RUNNING.value and (scan.hosts_found or 0) > 0:
        # Hosts have been found, so host discovery is treated as done.
        # ``or 0`` tolerates a NULL hosts_found instead of raising TypeError.
        progress = 0.5

    return ScanStatusResponse(
        **_common_scan_fields(scan),
        progress=progress,
        current_host=None,
        estimated_completion=None
    )


@router.get("", response_model=List[ScanResponse])
def list_scans(
    limit: int = 50,
    offset: int = 0,
    db: Session = Depends(get_db)
):
    """
    List recent scans.

    Args:
        limit: Maximum number of scans to return (must be >= 0)
        offset: Number of scans to skip (must be >= 0)
        db: Database session

    Returns:
        List of scans

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative. Negative
            values would otherwise reach the database, where behaviour is
            backend-dependent (ignored or a server error).
    """
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=400, detail="limit and offset must be non-negative"
        )

    scan_service = ScanService(db)
    scans = scan_service.list_scans(limit=limit, offset=offset)
    return [ScanResponse(**_common_scan_fields(scan)) for scan in scans]


@router.delete("/{scan_id}/cancel")
def cancel_scan(scan_id: int, db: Session = Depends(get_db)):
    """
    Cancel a pending or running scan.

    Args:
        scan_id: Scan ID
        db: Database session

    Returns:
        Success message

    Raises:
        HTTPException: 404 if the scan does not exist; 400 if it is not
            PENDING or RUNNING; 500 if ScanService reports the cancel failed.
    """
    scan_service = ScanService(db)

    scan = scan_service.get_scan_status(scan_id)
    if not scan:
        raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")

    cancellable = (ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value)
    if scan.status not in cancellable:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot cancel scan in status: {scan.status}"
        )

    # NOTE(review): this is a plain ``def`` endpoint, which FastAPI runs in a
    # worker thread. If ScanService.cancel_scan() calls Task.cancel() directly,
    # that is not thread-safe per the asyncio docs. Confirm it uses
    # loop.call_soon_threadsafe, or make this endpoint ``async def``.
    if not scan_service.cancel_scan(scan_id):
        raise HTTPException(status_code=500, detail="Failed to cancel scan")

    return {"message": f"Scan {scan_id} cancelled successfully"}