"""Port scanner implementation.""" import socket import asyncio from typing import List, Dict, Set, Optional, Callable from concurrent.futures import ThreadPoolExecutor import logging from app.config import settings logger = logging.getLogger(__name__) class PortScanner: """Scanner for detecting open ports on hosts.""" # Predefined port ranges for different scan types PORT_RANGES = { 'quick': [21, 22, 23, 25, 53, 80, 110, 143, 443, 445, 3306, 3389, 5432, 8080, 8443], 'standard': list(range(1, 1001)), 'deep': list(range(1, 65536)), } def __init__( self, timeout: int = None, max_workers: int = None, progress_callback: Optional[Callable[[str, int, float], None]] = None ): """ Initialize port scanner. Args: timeout: Socket connection timeout in seconds max_workers: Maximum number of concurrent workers progress_callback: Optional callback for progress updates (host, port, progress) """ self.timeout = timeout or settings.default_scan_timeout self.max_workers = max_workers or settings.max_concurrent_scans self.progress_callback = progress_callback async def scan_host_ports( self, host: str, scan_type: str = 'quick', custom_ports: Optional[List[int]] = None ) -> List[Dict[str, any]]: """ Scan ports on a single host. Args: host: IP address or hostname scan_type: Type of scan ('quick', 'standard', 'deep', or 'custom') custom_ports: Custom port list (required if scan_type is 'custom') Returns: List of dictionaries with port information """ logger.info(f"Starting port scan on {host} (type: {scan_type})") # Determine ports to scan if scan_type == 'custom' and custom_ports: ports = custom_ports elif scan_type in self.PORT_RANGES: ports = self.PORT_RANGES[scan_type] else: ports = self.PORT_RANGES['quick'] # Scan ports open_ports = await self._scan_ports_async(host, ports) logger.info(f"Scan completed on {host}. Found {len(open_ports)} open ports") return open_ports async def _scan_ports_async(self, host: str, ports: List[int]) -> List[Dict[str, any]]: """ Scan multiple ports asynchronously. Args: host: Host to scan ports: List of ports to scan Returns: List of open port information """ open_ports = [] total = len(ports) completed = 0 loop = asyncio.get_event_loop() with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = [] for port in ports: future = loop.run_in_executor(executor, self._check_port, host, port) futures.append((port, future)) # Process results for port, future in futures: try: result = await future if result: open_ports.append(result) logger.debug(f"Found open port {port} on {host}") except Exception as e: logger.debug(f"Error checking port {port} on {host}: {e}") finally: completed += 1 if self.progress_callback: progress = completed / total self.progress_callback(host, port, progress) return sorted(open_ports, key=lambda x: x['port']) def _check_port(self, host: str, port: int) -> Optional[Dict[str, any]]: """ Check if a port is open on a host. Args: host: Host to check port: Port number Returns: Dictionary with port info if open, None otherwise """ try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(self.timeout) result = sock.connect_ex((host, port)) sock.close() if result == 0: return { 'port': port, 'protocol': 'tcp', 'state': 'open', 'service_name': self._guess_service_name(port) } return None except socket.error as e: logger.debug(f"Socket error checking {host}:{port}: {e}") return None except Exception as e: logger.debug(f"Error checking {host}:{port}: {e}") return None def _guess_service_name(self, port: int) -> Optional[str]: """ Guess service name based on well-known ports. Args: port: Port number Returns: Service name or None """ common_services = { 20: 'ftp-data', 21: 'ftp', 22: 'ssh', 23: 'telnet', 25: 'smtp', 53: 'dns', 80: 'http', 110: 'pop3', 143: 'imap', 443: 'https', 445: 'smb', 3306: 'mysql', 3389: 'rdp', 5432: 'postgresql', 5900: 'vnc', 8080: 'http-alt', 8443: 'https-alt', } return common_services.get(port) def parse_port_range(self, port_range: str) -> List[int]: """ Parse port range string to list of ports. Args: port_range: String like "80,443,8000-8100" Returns: List of port numbers """ ports = set() try: for part in port_range.split(','): part = part.strip() if '-' in part: # Range like "8000-8100" start, end = map(int, part.split('-')) if 1 <= start <= end <= 65535: ports.update(range(start, end + 1)) else: # Single port port = int(part) if 1 <= port <= 65535: ports.add(port) return sorted(list(ports)) except ValueError as e: logger.error(f"Error parsing port range '{port_range}': {e}") return []