"""Network scanning helpers for the LAN graph tool.""" from __future__ import annotations import asyncio import dataclasses import ipaddress import json import re import socket import subprocess import time from dataclasses import dataclass, field from pathlib import Path from typing import Iterable, List, Optional, Set, Tuple NEIGHBOR_PATTERN = re.compile(r"^(?P\d+\.\d+\.\d+\.\d+)") @dataclass class HostNode: ip: str dns_name: Optional[str] reachable: bool last_seen: float via_ssh: bool = False comment: Optional[str] = None def to_dict(self) -> dict: return { "ip": self.ip, "dns_name": self.dns_name, "reachable": self.reachable, "last_seen": self.last_seen, "via_ssh": self.via_ssh, "comment": self.comment, } @dataclass class ConnectionEdge: source: str target: str relation: str def to_dict(self) -> dict: return { "source": self.source, "target": self.target, "relation": self.relation, } @dataclass class ScanResult: cidr: str gateway: str nodes: List[HostNode] edges: List[ConnectionEdge] generated_at: float def to_dict(self) -> dict: return { "cidr": self.cidr, "gateway": self.gateway, "generated_at": self.generated_at, "nodes": [node.to_dict() for node in self.nodes], "edges": [edge.to_dict() for edge in self.edges], } @dataclass class _HostProbe: ip: str reachable: bool dns_name: Optional[str] neighbors: Set[str] via_ssh: bool def parse_neighbor_ips(raw_output: str) -> Set[str]: ips: Set[str] = set() for line in raw_output.splitlines(): match = NEIGHBOR_PATTERN.match(line.strip()) if match: ips.add(match.group("ip")) return ips class NetworkScanner: def __init__(self, ssh_user: Optional[str] = None, ssh_key_path: Optional[str] = None): self.ssh_user = ssh_user self.ssh_key_path = ssh_key_path self._ping_concurrency = 64 self._ssh_concurrency = 10 self._ssh_semaphore: Optional[asyncio.Semaphore] = None async def scan( self, cidr: Optional[str] = None, concurrency: Optional[int] = None, ssh_timeout: float = 5.0, ) -> ScanResult: if concurrency: self._ping_concurrency = concurrency network, gateway, local_ip = self._discover_network(cidr) host_ips = list(network.hosts()) semaphore = asyncio.Semaphore(self._ping_concurrency) self._ssh_semaphore = asyncio.Semaphore(self._ssh_concurrency) probes = [] for host in host_ips: probes.append(self._probe_host(str(host), semaphore, ssh_timeout)) probe_results = await asyncio.gather(*probes) nodes: List[HostNode] = [] node_map: dict[str, HostNode] = {} edges: List[ConnectionEdge] = [] timestamp = time.time() # add gateway and scanner host nodes first gateway_node = HostNode( ip=gateway, dns_name=None, reachable=True, last_seen=timestamp, comment="default gateway", ) node_map[gateway] = gateway_node local_node = HostNode( ip=local_ip, dns_name=None, reachable=True, last_seen=timestamp, comment="this scanner", ) node_map[local_ip] = local_node final_edges: List[ConnectionEdge] = [] seen_edges: Set[Tuple[str, str, str]] = set() def append_edge(source: str, target: str, relation: str) -> None: key = (source, target, relation) if key in seen_edges: return seen_edges.add(key) final_edges.append(ConnectionEdge(source=source, target=target, relation=relation)) for probe in probe_results: if not probe.reachable: continue node = HostNode( ip=probe.ip, dns_name=probe.dns_name, reachable=True, last_seen=timestamp, via_ssh=probe.via_ssh, ) node_map[probe.ip] = node if probe.ip != local_ip: append_edge(local_ip, probe.ip, "scan") if probe.ip != gateway: append_edge(gateway, probe.ip, "gateway") for probe in probe_results: if not probe.reachable or not probe.neighbors: continue source = probe.ip for neighbor in probe.neighbors: if neighbor not in node_map: continue append_edge(source, neighbor, "neighbor") result = ScanResult( cidr=str(network), gateway=gateway, nodes=list(node_map.values()), edges=final_edges, generated_at=timestamp, ) return result async def _probe_host(self, ip: str, semaphore: asyncio.Semaphore, ssh_timeout: float) -> _HostProbe: async with semaphore: alive = await self._ping(ip) name = await asyncio.to_thread(self._resolve_dns, ip) if alive else None neighbors: Set[str] = set() via_ssh = False if alive: neighbors, via_ssh = await self._collect_neighbors(ip, ssh_timeout) return _HostProbe(ip=ip, reachable=alive, dns_name=name, neighbors=neighbors, via_ssh=via_ssh) async def _ping(self, ip: str) -> bool: process = await asyncio.create_subprocess_exec( "ping", "-c", "1", "-W", "1", ip, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, ) await process.communicate() return process.returncode == 0 def _resolve_dns(self, ip: str) -> Optional[str]: try: return socket.gethostbyaddr(ip)[0] except OSError: return None async def _collect_neighbors(self, ip: str, ssh_timeout: float) -> Tuple[Set[str], bool]: ssh_command = ["ssh", "-o", "BatchMode=yes", "-o", f"ConnectTimeout={ssh_timeout}"] if self.ssh_key_path: ssh_command.extend(["-i", str(self.ssh_key_path)]) target = f"{self.ssh_user}@{ip}" if self.ssh_user else ip ssh_command.append(target) ssh_command.extend(["ip", "-4", "neigh", "show"]) if self._ssh_semaphore is None: self._ssh_semaphore = asyncio.Semaphore(self._ssh_concurrency) async with self._ssh_semaphore: process = await asyncio.create_subprocess_exec( *ssh_command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, ) stdout, _ = await process.communicate() if process.returncode != 0: return set(), False neighbors = parse_neighbor_ips(stdout.decode("utf-8", errors="ignore")) return neighbors, True def _discover_network(self, cidr_override: Optional[str]) -> Tuple[ipaddress.IPv4Network, str, str]: if cidr_override is not None: network = ipaddress.IPv4Network(cidr_override, strict=False) else: network = self._without_override() gateway = self._default_gateway() local_ip = self._local_ip_from_network(network) return network, gateway, local_ip def _without_override(self) -> ipaddress.IPv4Network: output = subprocess.run( ["ip", "-o", "-4", "addr", "show", "scope", "global"], capture_output=True, text=True, check=False, ) for line in output.stdout.splitlines(): parts = line.split() if len(parts) < 4: continue iface = parts[1] if iface == "lo": continue inet = parts[3] return ipaddress.IPv4Network(inet, strict=False) raise RuntimeError("Unable to determine local IPv4 network") def _default_gateway(self) -> str: output = subprocess.run( ["ip", "route", "show", "default"], capture_output=True, text=True, check=False, ) for line in output.stdout.splitlines(): parts = line.split() if "via" in parts: via_index = parts.index("via") return parts[via_index + 1] raise RuntimeError("Unable to determine default gateway") def _local_ip_from_network(self, network: ipaddress.IPv4Network) -> str: output = subprocess.run( ["ip", "-o", "-4", "addr", "show", "scope", "global"], capture_output=True, text=True, check=False, ) for line in output.stdout.splitlines(): parts = line.split() if len(parts) < 4: continue iface = parts[1] if iface == "lo": continue inet = parts[3] if ipaddress.IPv4Address(inet.split("/")[0]) in network: return inet.split("/")[0] raise RuntimeError("Unable to determine scanner IP within network")