"""Scan service for orchestrating network scanning operations."""

import asyncio
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set

from sqlalchemy.orm import Session

from app.models import Scan, Host, Service, Connection
from app.schemas import ScanConfigRequest, ScanStatus as ScanStatusEnum
from app.scanner.network_scanner import NetworkScanner
from app.scanner.port_scanner import PortScanner
from app.scanner.service_detector import ServiceDetector
from app.scanner.nmap_scanner import NmapScanner
from app.config import settings

logger = logging.getLogger(__name__)

# Async callback receiving one event payload: a dict whose 'type' key is e.g.
# 'scan_progress', 'host_discovered', 'scan_completed' or 'scan_failed'.
ProgressCallback = Callable[[Dict[str, Any]], Awaitable[Any]]

# Service names considered when inferring host-to-host service connections.
_LINKABLE_SERVICES = frozenset({'http', 'https', 'ssh'})


class ScanService:
    """Service for managing network scans.

    Creates ``Scan`` rows, executes them (host discovery, port/service
    scanning, connection inference) and handles cancellation. All reads and
    writes go through the session passed to ``__init__``.

    Progress values sent to the progress callback share one scale: host
    discovery covers 0.0-0.3 and per-host port scanning covers 0.3-0.9.
    """

    # Overall progress reached when host discovery finishes.
    _DISCOVERY_END = 0.3
    # Share of overall progress allotted to port scanning (0.3 -> 0.9).
    _PORT_SCAN_SPAN = 0.6

    def __init__(self, db: Session):
        """
        Initialize scan service.

        Args:
            db: Database session used for all reads and writes
        """
        self.db = db
        # Tasks running execute_scan, keyed by scan ID; cancel_scan cancels
        # them. Nothing in this class adds entries -- presumably the code that
        # schedules execute_scan registers its task here (confirm with caller).
        self.active_scans: Dict[int, asyncio.Task] = {}
        # Cooperative cancellation flags checked between scan phases and
        # hosts. These are per-instance: a flag set by cancel_scan() is only
        # seen by execute_scan() running on the same ScanService instance.
        self.cancel_requested: Dict[int, bool] = {}
        # Strong references to fire-and-forget callback tasks. The event loop
        # only keeps weak references, so unreferenced tasks may be
        # garbage-collected before they finish.
        self._background_tasks: Set[asyncio.Task] = set()

    def create_scan(self, config: ScanConfigRequest) -> Scan:
        """
        Create a new scan record in PENDING state and commit it.

        Args:
            config: Scan configuration

        Returns:
            Created scan object (refreshed, so its ID is populated)
        """
        scan = Scan(
            scan_type=config.scan_type.value,
            network_range=config.network_range,
            status=ScanStatusEnum.PENDING.value,
            started_at=datetime.utcnow()
        )
        self.db.add(scan)
        self.db.commit()
        self.db.refresh(scan)

        logger.info(f"Created scan {scan.id} for {config.network_range}")
        return scan

    def cancel_scan(self, scan_id: int) -> bool:
        """
        Cancel a pending or running scan.

        Sets the cooperative cancellation flag, cancels the registered task
        (if any) and marks the scan CANCELLED. Scans in any other state are
        left untouched.

        Args:
            scan_id: Scan ID to cancel

        Returns:
            True if scan was cancelled; False if not found, not pending or
            running, or the database update failed (it is rolled back)
        """
        try:
            scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
            if not scan:
                logger.warning(f"Scan {scan_id} not found")
                return False

            if scan.status not in [ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value]:
                logger.warning(f"Scan {scan_id} is not running (status: {scan.status})")
                return False

            # Mark for cancellation; execute_scan checks this between phases.
            self.cancel_requested[scan_id] = True

            # Cancel the task if one was registered for this scan.
            task = self.active_scans.pop(scan_id, None)
            if task is not None:
                task.cancel()

            scan.status = ScanStatusEnum.CANCELLED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            logger.info(f"Cancelled scan {scan_id}")
            return True

        except Exception as e:
            logger.error(f"Error cancelling scan {scan_id}: {e}")
            self.db.rollback()
            return False

    async def execute_scan(
        self,
        scan_id: int,
        config: ScanConfigRequest,
        progress_callback: Optional[ProgressCallback] = None
    ) -> None:
        """
        Execute a network scan end to end.

        Phases:
        1. host discovery;
        2. port/service scanning (nmap when requested and enabled in
           settings, socket scanning otherwise);
        3. connection inference.

        The outcome is recorded on the scan row: COMPLETED; CANCELLED (for
        cancel requests and task cancellation); or FAILED with
        ``error_message``. Scan errors are recorded, not re-raised.

        Args:
            scan_id: Scan ID
            config: Scan configuration
            progress_callback: Optional async callback for progress events
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found")
            return

        # A scan cancelled before it started must stay cancelled instead of
        # being flipped back to RUNNING and scanned anyway.
        if scan.status == ScanStatusEnum.CANCELLED.value:
            logger.info(f"Scan {scan_id} was cancelled before it started")
            self.cancel_requested.pop(scan_id, None)
            self.active_scans.pop(scan_id, None)
            return

        try:
            # Initialize the flag without discarding an earlier cancel request.
            self.cancel_requested.setdefault(scan_id, False)

            scan.status = ScanStatusEnum.RUNNING.value
            self.db.commit()
            logger.info(f"Starting scan {scan_id}")

            if self.cancel_requested.get(scan_id):
                raise asyncio.CancelledError("Scan cancelled by user")

            network_scanner = NetworkScanner(
                progress_callback=lambda host, progress: self._on_host_progress(
                    scan_id, host, progress, progress_callback
                )
            )

            # Phase 1: Host Discovery
            logger.info(f"Phase 1: Discovering hosts in {config.network_range}")
            active_hosts = await network_scanner.scan_network(config.network_range)
            scan.hosts_found = len(active_hosts)
            self.db.commit()
            logger.info(f"Found {len(active_hosts)} active hosts")

            if self.cancel_requested.get(scan_id):
                raise asyncio.CancelledError("Scan cancelled by user")

            if progress_callback:
                await progress_callback({
                    'type': 'scan_progress',
                    'scan_id': scan_id,
                    'progress': self._DISCOVERY_END,
                    'current_host': f"Found {len(active_hosts)} hosts"
                })

            # Phase 2: Port Scanning and Service Detection
            if config.use_nmap and settings.enable_nmap:
                await self._scan_with_nmap(scan, active_hosts, config, progress_callback)
            else:
                await self._scan_with_socket(scan, active_hosts, config, progress_callback)

            # Phase 3: Detect Connections
            await self._detect_connections(scan, network_scanner)

            scan.status = ScanStatusEnum.COMPLETED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()
            logger.info(f"Scan {scan_id} completed successfully")

            if progress_callback:
                await progress_callback({
                    'type': 'scan_completed',
                    'scan_id': scan_id,
                    'hosts_found': scan.hosts_found
                })

        except asyncio.CancelledError:
            logger.info(f"Scan {scan_id} was cancelled")
            scan.status = ScanStatusEnum.CANCELLED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            # NOTE: cancellation is announced with type 'scan_completed';
            # there is no separate cancelled event type.
            if progress_callback:
                await progress_callback({
                    'type': 'scan_completed',
                    'scan_id': scan_id,
                    'hosts_found': scan.hosts_found
                })

        except Exception as e:
            logger.error(f"Error executing scan {scan_id}: {e}", exc_info=True)
            # The failure may have come from the session itself (e.g. a failed
            # flush); roll back so the status update below can be committed.
            self.db.rollback()
            scan.status = ScanStatusEnum.FAILED.value
            scan.error_message = str(e)
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            if progress_callback:
                await progress_callback({
                    'type': 'scan_failed',
                    'scan_id': scan_id,
                    'error': str(e)
                })

        finally:
            self.cancel_requested.pop(scan_id, None)
            self.active_scans.pop(scan_id, None)

    async def _emit_host_discovered(
        self,
        scan_id: int,
        ip: str,
        progress_callback: Optional[ProgressCallback]
    ) -> None:
        """Send a 'host_discovered' event for *ip* if a callback is set."""
        if progress_callback:
            await progress_callback({
                'type': 'host_discovered',
                'scan_id': scan_id,
                'host': {
                    'ip_address': ip,
                    'status': 'online'
                }
            })

    async def _emit_host_progress(
        self,
        scan_id: int,
        ip: str,
        idx: int,
        total: int,
        progress_callback: Optional[ProgressCallback]
    ) -> None:
        """Send a 'scan_progress' event after host *idx* of *total* (0.3-0.9 band)."""
        if progress_callback:
            progress = self._DISCOVERY_END + self._PORT_SCAN_SPAN * (idx / total)
            await progress_callback({
                'type': 'scan_progress',
                'scan_id': scan_id,
                'progress': progress,
                'current_host': f"Scanning {ip} ({idx}/{total})"
            })

    async def _scan_with_socket(
        self,
        scan: Scan,
        hosts: List[str],
        config: ScanConfigRequest,
        progress_callback: Optional[ProgressCallback]
    ) -> None:
        """
        Scan hosts using socket-based port scanning.

        For each host, in order:
        - upsert its Host row and emit 'host_discovered';
        - scan its ports, optionally running service detection;
        - store the services and link the host to the scan;
        - emit 'scan_progress'.

        A failure on one host is logged, rolled back and skipped.

        Raises:
            asyncio.CancelledError: If cancellation was requested for the scan.
            Whatever ``PortScanner.parse_port_range`` raises for an invalid
            ``config.port_range`` (the port list is parsed once, up front).
        """
        port_scanner = PortScanner(
            progress_callback=lambda host, port, progress: self._on_port_progress(
                scan.id, host, port, progress, progress_callback
            )
        )
        service_detector = ServiceDetector()

        # The requested port list is identical for every host; parse it once.
        custom_ports = (
            port_scanner.parse_port_range(config.port_range)
            if config.port_range else None
        )

        total = len(hosts)
        for idx, ip in enumerate(hosts, 1):
            if self.cancel_requested.get(scan.id):
                logger.info(f"Scan {scan.id} cancelled during port scanning")
                raise asyncio.CancelledError("Scan cancelled by user")

            try:
                logger.info(f"Scanning host {idx}/{total}: {ip}")

                host = self._get_or_create_host(ip)
                self.db.commit()  # Commit to ensure host.id is set
                self.db.refresh(host)

                await self._emit_host_discovered(scan.id, ip, progress_callback)

                open_ports = await port_scanner.scan_host_ports(
                    ip,
                    scan_type=config.scan_type.value,
                    custom_ports=custom_ports
                )
                # NOTE: despite its name, ports_scanned accumulates the number of
                # entries returned by scan_host_ports, not the ports probed.
                scan.ports_scanned = (scan.ports_scanned or 0) + len(open_ports)

                if config.include_service_detection:
                    for port_info in open_ports:
                        port_info.update(
                            service_detector.detect_service(ip, port_info['port'])
                        )

                self._store_services(host, open_ports)

                if host not in scan.hosts:
                    scan.hosts.append(host)
                self.db.commit()

                await self._emit_host_progress(scan.id, ip, idx, total, progress_callback)

            except asyncio.CancelledError:
                # Cancellation must never be treated as a per-host failure.
                raise
            except Exception as e:
                logger.error(f"Error scanning host {ip}: {e}")
                # Discard this host's uncommitted changes; a failed flush would
                # otherwise poison the session for every remaining host.
                self.db.rollback()

    async def _scan_with_nmap(
        self,
        scan: Scan,
        hosts: List[str],
        config: ScanConfigRequest,
        progress_callback: Optional[ProgressCallback]
    ) -> None:
        """
        Scan hosts using nmap; fall back to socket scanning if nmap is missing.

        For each host, in order:
        - upsert its Host row and emit 'host_discovered';
        - run nmap;
        - store the hostname, services and highest-accuracy OS match;
        - link the host to the scan and emit 'scan_progress'.

        A failure on one host is logged, rolled back and skipped.

        Raises:
            asyncio.CancelledError: If cancellation was requested for the scan.
            Whatever ``NmapScanner.get_scan_arguments`` raises for the config
            (arguments are built once, up front).
        """
        nmap_scanner = NmapScanner()

        if not nmap_scanner.nmap_available:
            logger.warning("Nmap not available, falling back to socket scanning")
            await self._scan_with_socket(scan, hosts, config, progress_callback)
            return

        # The nmap arguments depend only on the config; build them once.
        arguments = nmap_scanner.get_scan_arguments(
            scan_type=config.scan_type.value,
            service_detection=config.include_service_detection,
            port_range=config.port_range if config.port_range else None
        )

        total = len(hosts)
        for idx, ip in enumerate(hosts, 1):
            if self.cancel_requested.get(scan.id):
                logger.info(f"Scan {scan.id} cancelled during nmap scanning")
                raise asyncio.CancelledError("Scan cancelled by user")

            try:
                logger.info(f"Scanning host {idx}/{total} with nmap: {ip}")

                host = self._get_or_create_host(ip)
                self.db.commit()  # Commit to ensure host.id is set
                self.db.refresh(host)

                await self._emit_host_discovered(scan.id, ip, progress_callback)

                result = await nmap_scanner.scan_host(ip, arguments)
                if result:
                    if result.get('hostname'):
                        host.hostname = result['hostname']

                    if result.get('ports'):
                        self._store_services(host, result['ports'])
                        # NOTE: counts returned port entries, as in the socket path.
                        scan.ports_scanned = (scan.ports_scanned or 0) + len(result['ports'])

                    if result.get('os_matches'):
                        best_match = max(
                            result['os_matches'],
                            key=lambda match: float(match['accuracy'])
                        )
                        host.os_guess = best_match['name']

                if host not in scan.hosts:
                    scan.hosts.append(host)
                self.db.commit()

                await self._emit_host_progress(scan.id, ip, idx, total, progress_callback)

            except asyncio.CancelledError:
                # Cancellation must never be treated as a per-host failure.
                raise
            except Exception as e:
                logger.error(f"Error scanning host {ip} with nmap: {e}")
                # Discard this host's uncommitted changes so the session stays
                # usable for the remaining hosts.
                self.db.rollback()

    def _get_or_create_host(self, ip: str) -> Host:
        """
        Return the Host row for *ip*, creating it if absent, marked online.

        Updates ``last_seen``/``status`` or adds a new row to the session;
        does not commit.
        """
        now = datetime.utcnow()
        host = self.db.query(Host).filter(Host.ip_address == ip).first()
        if host:
            host.last_seen = now
            host.status = 'online'
        else:
            host = Host(
                ip_address=ip,
                status='online',
                first_seen=now,
                last_seen=now
            )
            self.db.add(host)
        return host

    def _store_services(self, host: Host, services_data: list) -> None:
        """
        Upsert Service rows for *host* keyed by (port, protocol); does not commit.

        Each item must have a 'port' key. 'protocol' defaults to 'tcp' and
        'state' to 'open'. Existing rows only overwrite name/version/banner
        when the new value is truthy.
        """
        now = datetime.utcnow()
        for service_info in services_data:
            protocol = service_info.get('protocol', 'tcp')
            service = self.db.query(Service).filter(
                Service.host_id == host.id,
                Service.port == service_info['port'],
                Service.protocol == protocol
            ).first()

            if service:
                service.last_seen = now
                service.state = service_info.get('state', 'open')
                if service_info.get('service_name'):
                    service.service_name = service_info['service_name']
                if service_info.get('service_version'):
                    service.service_version = service_info['service_version']
                if service_info.get('banner'):
                    service.banner = service_info['banner']
            else:
                service = Service(
                    host_id=host.id,
                    port=service_info['port'],
                    protocol=protocol,
                    state=service_info.get('state', 'open'),
                    service_name=service_info.get('service_name'),
                    service_version=service_info.get('service_version'),
                    banner=service_info.get('banner'),
                    first_seen=now,
                    last_seen=now
                )
                self.db.add(service)

    async def _detect_connections(self, scan: Scan, network_scanner: NetworkScanner) -> None:
        """
        Infer connections between the scan's hosts and commit them.

        Two heuristics are applied:
        * gateway: each scanned host is linked to the host at
          ``<local network prefix>.1`` (derived from
          ``network_scanner.get_local_network_range()``), if that host already
          exists in the database (confidence 0.9);
        * service: host A is linked to host B when both expose an http, https
          or ssh service on the same port (confidence 0.5).

        Errors are logged and the session is rolled back.
        """
        try:
            # NOTE: the gateway is derived from the *local* network range, not
            # from scan.network_range.
            local_range = network_scanner.get_local_network_range()
            if local_range:
                gateway_address = local_range.split('/')[0].rsplit('.', 1)[0] + '.1'
                gateway_host = self.db.query(Host).filter(
                    Host.ip_address == gateway_address
                ).first()

                if gateway_host:
                    for host in scan.hosts:
                        if host.id != gateway_host.id:
                            self._create_connection(
                                host.id, gateway_host.id, 'gateway', confidence=0.9
                            )

            # Index hosts (in scan.hosts order) by the port of each linkable
            # service. This avoids re-scanning every other host's services for
            # every service while producing the same connections in the same
            # order.
            hosts_by_port: Dict[int, List[Host]] = {}
            for host in scan.hosts:
                for service in host.services:
                    if service.service_name in _LINKABLE_SERVICES:
                        bucket = hosts_by_port.setdefault(service.port, [])
                        if not bucket or bucket[-1] is not host:
                            bucket.append(host)

            for host in scan.hosts:
                for service in host.services:
                    if service.service_name not in _LINKABLE_SERVICES:
                        continue
                    for other_host in hosts_by_port.get(service.port, ()):
                        if other_host.id != host.id:
                            self._create_connection(
                                host.id,
                                other_host.id,
                                'service',
                                protocol='tcp',
                                port=service.port,
                                confidence=0.5
                            )

            self.db.commit()

        except Exception as e:
            logger.error(f"Error detecting connections: {e}")
            # Drop partially-added connections so the session stays usable for
            # the final status commit in execute_scan.
            self.db.rollback()

    def _create_connection(
        self,
        source_id: int,
        target_id: int,
        conn_type: str,
        protocol: Optional[str] = None,
        port: Optional[int] = None,
        confidence: float = 1.0
    ) -> None:
        """
        Add a Connection unless one with the same (source, target, type) exists.

        Does not commit. The existence check queries the session, so earlier
        pending additions are only seen if the session autoflushes (assumed
        default; confirm session configuration).
        """
        existing = self.db.query(Connection).filter(
            Connection.source_host_id == source_id,
            Connection.target_host_id == target_id,
            Connection.connection_type == conn_type
        ).first()

        if not existing:
            connection = Connection(
                source_host_id=source_id,
                target_host_id=target_id,
                connection_type=conn_type,
                protocol=protocol,
                port=port,
                confidence=confidence,
                detected_at=datetime.utcnow()
            )
            self.db.add(connection)

    def _spawn_callback(self, callback: ProgressCallback, payload: Dict[str, Any]) -> None:
        """
        Schedule ``callback(payload)`` on the running loop without awaiting it.

        A strong reference is kept until the task finishes so it cannot be
        garbage-collected mid-execution. Requires a running event loop.
        """
        task = asyncio.create_task(callback(payload))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _on_host_progress(
        self,
        scan_id: int,
        host: str,
        progress: float,
        callback: Optional[ProgressCallback]
    ) -> None:
        """Forward host-discovery progress, mapped onto the 0.0-0.3 band."""
        # NOTE(review): assumes NetworkScanner reports progress as a 0..1 fraction.
        if callback:
            self._spawn_callback(callback, {
                'type': 'scan_progress',
                'scan_id': scan_id,
                'current_host': host,
                'progress': progress * self._DISCOVERY_END
            })

    def _on_port_progress(
        self,
        scan_id: int,
        host: str,
        port: int,
        progress: float,
        callback: Optional[ProgressCallback]
    ) -> None:
        """Forward port-scanning progress, mapped onto the 0.3-0.9 band."""
        # NOTE(review): assumes PortScanner reports a 0..1 fraction; if it is
        # per-host rather than overall, these events will not be monotonic.
        if callback:
            self._spawn_callback(callback, {
                'type': 'scan_progress',
                'scan_id': scan_id,
                'current_host': host,
                'current_port': port,
                'progress': self._DISCOVERY_END + progress * self._PORT_SCAN_SPAN
            })

    def get_scan_status(self, scan_id: int) -> Optional[Scan]:
        """Return the scan with *scan_id*, or None if it does not exist."""
        return self.db.query(Scan).filter(Scan.id == scan_id).first()

    def list_scans(self, limit: int = 50, offset: int = 0) -> list:
        """Return up to *limit* scans, newest ``started_at`` first, skipping *offset*."""
        return (
            self.db.query(Scan)
            .order_by(Scan.started_at.desc())
            .limit(limit)
            .offset(offset)
            .all()
        )