"""Topology service for network graph generation."""
import logging
from collections import Counter
from typing import List, Dict, Any, Optional

from sqlalchemy.orm import Session
from sqlalchemy import func

from app.models import Host, Service, Connection
from app.schemas import TopologyNode, TopologyEdge, TopologyResponse

logger = logging.getLogger(__name__)


class TopologyService:
    """Service for generating network topology graphs."""

    # Node type colors
    NODE_COLORS = {
        'gateway': '#FF6B6B',
        'server': '#4ECDC4',
        'workstation': '#45B7D1',
        'device': '#96CEB4',
        'unknown': '#95A5A6'
    }

    # Service names that hint a host is a server (only decisive together
    # with a high service count; see _determine_device_type).
    _SERVER_SERVICES = frozenset(
        {'http', 'https', 'ssh', 'smtp', 'mysql', 'postgresql', 'ftp'}
    )

    def __init__(self, db: Session):
        """
        Initialize topology service.

        Args:
            db: Database session
        """
        self.db = db

    def generate_topology(self, include_offline: bool = False) -> TopologyResponse:
        """
        Generate network topology graph.

        Args:
            include_offline: Include hosts whose status is not 'online'

        Returns:
            Topology response with nodes, edges and statistics. Only
            connections whose endpoints are both included hosts become edges;
            statistics are computed over all stored connections.
        """
        logger.info("Generating network topology")

        query = self.db.query(Host)
        if not include_offline:
            query = query.filter(Host.status == 'online')
        hosts = query.all()

        connections = self.db.query(Connection).all()
        # One pass over the loaded connections replaces a COUNT query per host.
        degree = self._connection_counts(connections)

        nodes = [
            self._create_node(host, connection_count=degree[host.id])
            for host in hosts
        ]

        # Set membership instead of scanning every node for every connection.
        node_ids = {node.id for node in nodes}
        edges = [
            self._create_edge(conn)
            for conn in connections
            if str(conn.source_host_id) in node_ids
            and str(conn.target_host_id) in node_ids
        ]

        statistics = self._generate_statistics(hosts, connections)

        logger.info(
            "Generated topology with %d nodes and %d edges", len(nodes), len(edges)
        )

        return TopologyResponse(
            nodes=nodes,
            edges=edges,
            statistics=statistics
        )

    @staticmethod
    def _connection_counts(connections: List[Connection]) -> Dict[Any, int]:
        """
        Count, per host id, how many connections touch that host.

        A connection is counted once for its source and once for its target;
        a self-loop (source == target) is counted only once, matching a
        ``source == id OR target == id`` row count.

        Args:
            connections: Connections to tally

        Returns:
            Counter mapping host id -> number of touching connections
            (missing ids read as 0).
        """
        counts: Counter = Counter()
        for conn in connections:
            counts[conn.source_host_id] += 1
            if conn.target_host_id != conn.source_host_id:
                counts[conn.target_host_id] += 1
        return counts

    def _create_node(
        self, host: Host, connection_count: Optional[int] = None
    ) -> TopologyNode:
        """
        Create a topology node from a host.

        Args:
            host: Host model
            connection_count: Precomputed number of connections touching this
                host. When omitted, it is counted with a database query.

        Returns:
            TopologyNode
        """
        device_type = self._determine_device_type(host)

        if connection_count is None:
            connection_count = self.db.query(Connection).filter(
                (Connection.source_host_id == host.id) |
                (Connection.target_host_id == host.id)
            ).count()

        return TopologyNode(
            id=str(host.id),
            ip=host.ip_address,
            hostname=host.hostname,
            type=device_type,
            status=host.status,
            service_count=len(host.services),
            connections=connection_count
        )

    def _determine_device_type(self, host: Host) -> str:
        """
        Determine device type based on host information.

        An explicit ``host.device_type`` wins. Otherwise the type is inferred
        from the host's services, in priority order: any service on port 53 ->
        'gateway'; a known server service plus more than five services ->
        'server'; any service on port 3389 (RDP) -> 'workstation'; any
        services at all -> 'device'; none -> 'unknown'.

        Args:
            host: Host model

        Returns:
            Device type string
        """
        if host.device_type:
            return host.device_type

        services = host.services
        ports = {s.port for s in services}
        service_names = {s.service_name for s in services if s.service_name}

        # Port 53 (DNS) is treated as a gateway indicator.
        if 53 in ports:
            return 'gateway'

        # A server-like service alone is not enough; also require > 5 services.
        if service_names & self._SERVER_SERVICES and len(services) > 5:
            return 'server'

        if 3389 in ports:  # RDP
            return 'workstation'

        return 'device' if services else 'unknown'

    def _create_edge(self, connection: Connection) -> TopologyEdge:
        """
        Create a topology edge from a connection.

        Args:
            connection: Connection model

        Returns:
            TopologyEdge (type falls back to 'default' when unset)
        """
        return TopologyEdge(
            source=str(connection.source_host_id),
            target=str(connection.target_host_id),
            type=connection.connection_type or 'default',
            confidence=connection.confidence
        )

    def _generate_statistics(
        self,
        hosts: List[Host],
        connections: List[Connection]
    ) -> Dict[str, Any]:
        """
        Generate statistics about the topology.

        Args:
            hosts: List of hosts
            connections: List of connections; isolation is judged against
                this list, so pass every stored connection to match the
                database view.

        Returns:
            Statistics dictionary with total_nodes, total_edges,
            isolated_nodes and avg_connections (connections per host,
            rounded to 2 decimals; 0 when there are no hosts).
        """
        degree = self._connection_counts(connections)
        isolated = sum(1 for host in hosts if degree[host.id] == 0)

        avg_connections = len(connections) / len(hosts) if hosts else 0

        return {
            'total_nodes': len(hosts),
            'total_edges': len(connections),
            'isolated_nodes': isolated,
            'avg_connections': round(avg_connections, 2)
        }

    def get_host_neighbors(self, host_id: int) -> List[Host]:
        """
        Get all hosts connected to a specific host, in either direction.

        Args:
            host_id: Host ID

        Returns:
            List of connected hosts (empty when the host has no connections)
        """
        # Only the opposite endpoint's id is needed, not whole Connection rows.
        outgoing = self.db.query(Connection.target_host_id).filter(
            Connection.source_host_id == host_id
        )
        incoming = self.db.query(Connection.source_host_id).filter(
            Connection.target_host_id == host_id
        )
        neighbor_ids = {row[0] for row in outgoing} | {row[0] for row in incoming}

        # Skip an empty IN () query when there is nothing to fetch.
        if not neighbor_ids:
            return []

        return self.db.query(Host).filter(
            Host.id.in_(neighbor_ids)
        ).all()

    def get_network_statistics(self) -> Dict[str, Any]:
        """
        Get network-wide counts straight from the database.

        Returns:
            Statistics dictionary with total/online/offline host counts and
            total service and connection counts ("offline" means any status
            other than 'online').
        """
        total_hosts = self.db.query(func.count(Host.id)).scalar()
        online_hosts = self.db.query(func.count(Host.id)).filter(
            Host.status == 'online'
        ).scalar()
        total_services = self.db.query(func.count(Service.id)).scalar()
        total_connections = self.db.query(func.count(Connection.id)).scalar()

        return {
            'total_hosts': total_hosts,
            'online_hosts': online_hosts,
            'offline_hosts': total_hosts - online_hosts,
            'total_services': total_services,
            'total_connections': total_connections
        }