"""Pydantic schemas for API request/response validation."""
import ipaddress
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

# IPvAnyAddress is not referenced in this module; it is kept so that any
# code importing it from here keeps working.
from pydantic import BaseModel, ConfigDict, Field, IPvAnyAddress, field_validator  # noqa: F401


class ScanType(str, Enum):
    """Scan type enumeration."""

    QUICK = "quick"
    STANDARD = "standard"
    DEEP = "deep"
    CUSTOM = "custom"


class ScanStatus(str, Enum):
    """Scan status enumeration."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class HostStatus(str, Enum):
    """Host status enumeration."""

    ONLINE = "online"
    OFFLINE = "offline"
    SCANNING = "scanning"


class ConnectionType(str, Enum):
    """Connection type enumeration."""

    GATEWAY = "gateway"
    SAME_SUBNET = "same_subnet"
    SERVICE = "service"
    INFERRED = "inferred"


# Service schemas
class ServiceBase(BaseModel):
    """Fields shared by the service create and response schemas."""

    # Valid TCP/UDP port range; 0 is rejected.
    port: int = Field(..., ge=1, le=65535)
    # Only the literal strings "tcp" and "udp" are accepted.
    protocol: str = Field(default="tcp", pattern="^(tcp|udp)$")
    state: str = Field(default="open")
    service_name: Optional[str] = None
    service_version: Optional[str] = None
    banner: Optional[str] = None


class ServiceCreate(ServiceBase):
    """Schema for creating a service."""

    host_id: int


class ServiceResponse(ServiceBase):
    """Schema for service response.

    ``from_attributes`` lets this schema be built from arbitrary objects
    (e.g. ORM rows) via ``model_validate(obj)``.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    host_id: int
    first_seen: datetime
    last_seen: datetime


# Host schemas
class HostBase(BaseModel):
    """Fields shared by the host create and response schemas."""

    ip_address: str
    hostname: Optional[str] = None
    mac_address: Optional[str] = None

    @field_validator("ip_address")
    @classmethod
    def validate_ip(cls, v: str) -> str:
        """Ensure ``ip_address`` parses as an IPv4 or IPv6 address.

        The input string is returned unchanged (it is not normalised).

        Raises:
            ValueError: If ``v`` is not a valid IP address.
        """
        try:
            ipaddress.ip_address(v)
        except ValueError:
            # Suppress the stdlib exception context; the message below is
            # the one surfaced to API clients.
            raise ValueError(f"Invalid IP address: {v}") from None
        return v


class HostCreate(HostBase):
    """Schema for creating a host."""

    device_type: Optional[str] = None
    os_guess: Optional[str] = None
    vendor: Optional[str] = None


class HostResponse(HostBase):
    """Schema for host response, including the host's services."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    first_seen: datetime
    last_seen: datetime
    status: HostStatus
    device_type: Optional[str] = None
    os_guess: Optional[str] = None
    vendor: Optional[str] = None
    notes: Optional[str] = None
    services: List[ServiceResponse] = Field(default_factory=list)


class HostDetailResponse(HostResponse):
    """Detailed host response with connection info.

    ``ConnectionResponse`` is defined later in this module, so it is
    referenced as a string and resolved by ``model_rebuild()`` at the
    bottom of the file.
    """

    outgoing_connections: List["ConnectionResponse"] = Field(default_factory=list)
    incoming_connections: List["ConnectionResponse"] = Field(default_factory=list)


# Connection schemas
class ConnectionBase(BaseModel):
    """Fields shared by the connection create and response schemas."""

    source_host_id: int
    target_host_id: int
    connection_type: ConnectionType
    protocol: Optional[str] = None
    port: Optional[int] = None
    # Confidence score constrained to the closed interval [0.0, 1.0].
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)


class ConnectionCreate(ConnectionBase):
    """Schema for creating a connection."""

    metadata: Optional[Dict[str, Any]] = None


class ConnectionResponse(ConnectionBase):
    """Schema for connection response."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    detected_at: datetime
    last_verified: Optional[datetime] = None
    # NOTE(review): with from_attributes=True this reads ``obj.metadata``.
    # If the source object is a SQLAlchemy declarative model, ``metadata``
    # is the reserved table MetaData, not a JSON column — confirm the ORM
    # attribute name and add a validation_alias if it differs.
    metadata: Optional[Dict[str, Any]] = None


# Scan schemas
class ScanConfigRequest(BaseModel):
    """Schema for scan configuration request."""

    network_range: str
    scan_type: ScanType = Field(default=ScanType.QUICK)
    port_range: Optional[str] = None
    include_service_detection: bool = True
    use_nmap: bool = True

    @field_validator("network_range")
    @classmethod
    def validate_network(cls, v: str) -> str:
        """Ensure ``network_range`` is a parseable, private IP network.

        Host bits are allowed (``strict=False``), so e.g. "192.168.1.5/24"
        is accepted. The input string is returned unchanged.

        Raises:
            ValueError: If ``v`` cannot be parsed as an IP network, or if
                the parsed network is not private.
        """
        try:
            network = ipaddress.ip_network(v, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid network range: {e}") from e

        # Policy check deliberately sits outside the try block so its
        # message is not re-wrapped as "Invalid network range: ..." by the
        # parse-error handler above.
        if not network.is_private:
            raise ValueError("Only private network ranges are allowed")
        return v


class ScanResponse(BaseModel):
    """Schema for scan response."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    started_at: datetime
    completed_at: Optional[datetime] = None
    scan_type: str
    network_range: str
    status: ScanStatus
    hosts_found: int = 0
    ports_scanned: int = 0
    error_message: Optional[str] = None


class ScanStatusResponse(ScanResponse):
    """Schema for detailed scan status response."""

    # Fraction complete, constrained to [0.0, 1.0].
    progress: float = Field(default=0.0, ge=0.0, le=1.0)
    current_host: Optional[str] = None
    estimated_completion: Optional[datetime] = None


class ScanStartResponse(BaseModel):
    """Schema for scan start response."""

    scan_id: int
    message: str
    status: ScanStatus


# Topology schemas
class TopologyNode(BaseModel):
    """Schema for topology graph node."""

    id: str
    ip: str
    # Defaults to None so a missing hostname is accepted, consistent with
    # the other Optional fields in this module (Pydantic v2 otherwise
    # treats an Optional field without a default as required).
    hostname: Optional[str] = None
    type: str
    status: str
    service_count: int
    connections: int = 0


class TopologyEdge(BaseModel):
    """Schema for topology graph edge."""

    source: str
    target: str
    type: str = "default"
    confidence: float = 0.5


class TopologyResponse(BaseModel):
    """Schema for topology graph response."""

    nodes: List[TopologyNode]
    edges: List[TopologyEdge]
    statistics: Dict[str, Any] = Field(default_factory=dict)


# WebSocket message schemas
class WSMessageType(str, Enum):
    """WebSocket message type enumeration."""

    SCAN_STARTED = "scan_started"
    SCAN_PROGRESS = "scan_progress"
    HOST_DISCOVERED = "host_discovered"
    SERVICE_DISCOVERED = "service_discovered"
    SCAN_COMPLETED = "scan_completed"
    SCAN_FAILED = "scan_failed"
    ERROR = "error"


class WSMessage(BaseModel):
    """Schema for WebSocket messages."""

    type: WSMessageType
    data: Dict[str, Any]
    # Timezone-aware UTC creation time. datetime.utcnow() is deprecated
    # (Python 3.12) and yields a naive value that serializes without an
    # offset, which clients may misread as local time.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


# Statistics schemas
class NetworkStatistics(BaseModel):
    """Schema for network statistics."""

    total_hosts: int
    online_hosts: int
    offline_hosts: int
    total_services: int
    total_scans: int
    last_scan: Optional[datetime] = None
    most_common_services: List[Dict[str, Any]] = Field(default_factory=list)


# Rebuild models to resolve forward references
HostDetailResponse.model_rebuild()