Initial commit: Werkzeuge-Sammlung

Enthält:
- rdp_client.py: RDP Client mit GUI und Monitor-Auswahl
- rdp.sh: Bash-basierter RDP Client
- teamleader_test/: Network Scanner Fullstack-App
- teamleader_test2/: Network Mapper CLI

Subdirectories mit eigenem Repo wurden ausgeschlossen.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
root
2026-01-28 09:39:24 +01:00
commit cb073786b3
112 changed files with 23543 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""Network Scanner Application Package."""
__version__ = "1.0.0"
__author__ = "DevAgent"

View File

@@ -0,0 +1,13 @@
"""API router initialization."""
from fastapi import APIRouter
from app.api.endpoints import scans, hosts, topology, websocket
api_router = APIRouter()
# Include endpoint routers
api_router.include_router(scans.router, prefix="/scans", tags=["scans"])
api_router.include_router(hosts.router, prefix="/hosts", tags=["hosts"])
api_router.include_router(topology.router, prefix="/topology", tags=["topology"])
api_router.include_router(websocket.router, prefix="/ws", tags=["websocket"])

View File

@@ -0,0 +1 @@
"""API endpoints package."""

View File

@@ -0,0 +1,222 @@
"""Host API endpoints."""
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_
from app.database import get_db
from app.models import Host, Service
from app.schemas import HostResponse, HostDetailResponse, ServiceResponse, NetworkStatistics
from app.services.topology_service import TopologyService
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[HostResponse])
def list_hosts(
status: Optional[str] = Query(None, description="Filter by status (online/offline)"),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
search: Optional[str] = Query(None, description="Search by IP or hostname"),
db: Session = Depends(get_db)
):
"""
List discovered hosts.
Args:
status: Filter by host status
limit: Maximum number of hosts to return
offset: Number of hosts to skip
search: Search query
db: Database session
Returns:
List of hosts
"""
query = db.query(Host)
# Apply filters
if status:
query = query.filter(Host.status == status)
if search:
search_pattern = f"%{search}%"
query = query.filter(
or_(
Host.ip_address.like(search_pattern),
Host.hostname.like(search_pattern)
)
)
# Order by last seen
query = query.order_by(Host.last_seen.desc())
# Apply pagination
hosts = query.limit(limit).offset(offset).all()
return hosts
@router.get("/statistics", response_model=NetworkStatistics)
def get_network_statistics(db: Session = Depends(get_db)):
"""
Get network statistics.
Args:
db: Database session
Returns:
Network statistics
"""
topology_service = TopologyService(db)
stats = topology_service.get_network_statistics()
# Get most common services
from sqlalchemy import func
service_counts = db.query(
Service.service_name,
func.count(Service.id).label('count')
).filter(
Service.service_name.isnot(None)
).group_by(
Service.service_name
).order_by(
func.count(Service.id).desc()
).limit(10).all()
# Get last scan time
from app.models import Scan
last_scan = db.query(Scan).order_by(Scan.started_at.desc()).first()
return NetworkStatistics(
total_hosts=stats['total_hosts'],
online_hosts=stats['online_hosts'],
offline_hosts=stats['offline_hosts'],
total_services=stats['total_services'],
total_scans=db.query(func.count(Scan.id)).scalar() or 0,
last_scan=last_scan.started_at if last_scan else None,
most_common_services=[
{'service_name': s[0], 'count': s[1]}
for s in service_counts
]
)
@router.get("/by-service/{service_name}", response_model=List[HostResponse])
def get_hosts_by_service(
service_name: str,
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db)
):
"""
Get all hosts that provide a specific service.
Args:
service_name: Service name to filter by
limit: Maximum number of hosts to return
offset: Number of hosts to skip
db: Database session
Returns:
List of hosts providing the service
"""
hosts = db.query(Host).join(Service).filter(
Service.service_name == service_name
).distinct().order_by(
Host.last_seen.desc()
).limit(limit).offset(offset).all()
return hosts
@router.get("/{host_id}", response_model=HostDetailResponse)
def get_host_detail(host_id: int, db: Session = Depends(get_db)):
"""
Get detailed information about a specific host.
Args:
host_id: Host ID
db: Database session
Returns:
Detailed host information
"""
host = db.query(Host).filter(Host.id == host_id).first()
if not host:
raise HTTPException(status_code=404, detail=f"Host {host_id} not found")
return host
@router.get("/{host_id}/services", response_model=List[ServiceResponse])
def get_host_services(host_id: int, db: Session = Depends(get_db)):
"""
Get all services for a specific host.
Args:
host_id: Host ID
db: Database session
Returns:
List of services
"""
host = db.query(Host).filter(Host.id == host_id).first()
if not host:
raise HTTPException(status_code=404, detail=f"Host {host_id} not found")
return host.services
@router.delete("/{host_id}")
def delete_host(host_id: int, db: Session = Depends(get_db)):
"""
Delete a host from the database.
Args:
host_id: Host ID
db: Database session
Returns:
Success message
"""
host = db.query(Host).filter(Host.id == host_id).first()
if not host:
raise HTTPException(status_code=404, detail=f"Host {host_id} not found")
db.delete(host)
db.commit()
logger.info(f"Deleted host {host_id} ({host.ip_address})")
return {"message": f"Host {host_id} deleted successfully"}
@router.get("/ip/{ip_address}", response_model=HostResponse)
def get_host_by_ip(ip_address: str, db: Session = Depends(get_db)):
"""
Get host information by IP address.
Args:
ip_address: IP address
db: Database session
Returns:
Host information
"""
host = db.query(Host).filter(Host.ip_address == ip_address).first()
if not host:
raise HTTPException(
status_code=404,
detail=f"Host with IP {ip_address} not found"
)
return host

View File

@@ -0,0 +1,209 @@
"""Scan API endpoints."""
import asyncio
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session
from app.database import get_db
from app.schemas import (
ScanConfigRequest,
ScanResponse,
ScanStatusResponse,
ScanStartResponse,
ScanStatus as ScanStatusEnum
)
from app.services.scan_service import ScanService
from app.api.endpoints.websocket import (
send_scan_progress,
send_host_discovered,
send_scan_completed,
send_scan_failed
)
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/start", response_model=ScanStartResponse, status_code=202)
async def start_scan(
config: ScanConfigRequest,
db: Session = Depends(get_db)
):
"""
Start a new network scan.
Args:
config: Scan configuration
background_tasks: Background task handler
db: Database session
Returns:
Scan start response with scan ID
"""
try:
scan_service = ScanService(db)
# Create scan record
scan = scan_service.create_scan(config)
scan_id = scan.id
# Create progress callback for WebSocket updates
async def progress_callback(update: dict):
"""Send progress updates via WebSocket."""
update_type = update.get('type')
update_scan_id = update.get('scan_id', scan_id)
if update_type == 'scan_progress':
await send_scan_progress(update_scan_id, update.get('progress', 0), update.get('current_host'))
elif update_type == 'host_discovered':
await send_host_discovered(update_scan_id, update.get('host'))
elif update_type == 'scan_completed':
await send_scan_completed(update_scan_id, {'hosts_found': update.get('hosts_found', 0)})
elif update_type == 'scan_failed':
await send_scan_failed(update_scan_id, update.get('error', 'Unknown error'))
# Create background task wrapper that uses a new database session
async def run_scan_task():
from app.database import SessionLocal
scan_db = SessionLocal()
try:
scan_service_bg = ScanService(scan_db)
await scan_service_bg.execute_scan(scan_id, config, progress_callback)
finally:
scan_db.close()
# Create and store the task for this scan
task = asyncio.create_task(run_scan_task())
scan_service.active_scans[scan_id] = task
logger.info(f"Started scan {scan_id} for {config.network_range}")
return ScanStartResponse(
scan_id=scan_id,
message=f"Scan started for network {config.network_range}",
status=ScanStatusEnum.PENDING
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error starting scan: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to start scan")
@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
def get_scan_status(scan_id: int, db: Session = Depends(get_db)):
"""
Get the status of a specific scan.
Args:
scan_id: Scan ID
db: Database session
Returns:
Scan status information
"""
scan_service = ScanService(db)
scan = scan_service.get_scan_status(scan_id)
if not scan:
raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")
# Calculate progress
progress = 0.0
if scan.status == ScanStatusEnum.COMPLETED.value:
progress = 1.0
elif scan.status == ScanStatusEnum.RUNNING.value:
# Estimate progress based on hosts found
# This is a rough estimate; real-time progress comes from WebSocket
if scan.hosts_found > 0:
progress = 0.5 # Host discovery done
return ScanStatusResponse(
id=scan.id,
started_at=scan.started_at,
completed_at=scan.completed_at,
scan_type=scan.scan_type,
network_range=scan.network_range,
status=ScanStatusEnum(scan.status),
hosts_found=scan.hosts_found,
ports_scanned=scan.ports_scanned,
error_message=scan.error_message,
progress=progress,
current_host=None,
estimated_completion=None
)
@router.get("", response_model=List[ScanResponse])
def list_scans(
limit: int = 50,
offset: int = 0,
db: Session = Depends(get_db)
):
"""
List recent scans.
Args:
limit: Maximum number of scans to return
offset: Number of scans to skip
db: Database session
Returns:
List of scans
"""
scan_service = ScanService(db)
scans = scan_service.list_scans(limit=limit, offset=offset)
return [
ScanResponse(
id=scan.id,
started_at=scan.started_at,
completed_at=scan.completed_at,
scan_type=scan.scan_type,
network_range=scan.network_range,
status=ScanStatusEnum(scan.status),
hosts_found=scan.hosts_found,
ports_scanned=scan.ports_scanned,
error_message=scan.error_message
)
for scan in scans
]
@router.delete("/{scan_id}/cancel")
def cancel_scan(scan_id: int, db: Session = Depends(get_db)):
"""
Cancel a running scan.
Args:
scan_id: Scan ID
db: Database session
Returns:
Success message
"""
scan_service = ScanService(db)
# Check if scan exists
scan = scan_service.get_scan_status(scan_id)
if not scan:
raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")
# Check if scan is running
if scan.status not in [ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value]:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel scan in status: {scan.status}"
)
# Attempt to cancel
success = scan_service.cancel_scan(scan_id)
if success:
return {"message": f"Scan {scan_id} cancelled successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to cancel scan")

View File

@@ -0,0 +1,70 @@
"""Topology API endpoints."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.schemas import TopologyResponse
from app.services.topology_service import TopologyService
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=TopologyResponse)
def get_network_topology(
include_offline: bool = Query(False, description="Include offline hosts"),
db: Session = Depends(get_db)
):
"""
Get network topology graph data.
Args:
include_offline: Whether to include offline hosts
db: Database session
Returns:
Topology data with nodes and edges
"""
try:
topology_service = TopologyService(db)
topology = topology_service.generate_topology(include_offline=include_offline)
logger.info(f"Generated topology with {len(topology.nodes)} nodes")
return topology
except Exception as e:
logger.error(f"Error generating topology: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to generate topology")
@router.get("/neighbors/{host_id}")
def get_host_neighbors(host_id: int, db: Session = Depends(get_db)):
"""
Get neighboring hosts for a specific host.
Args:
host_id: Host ID
db: Database session
Returns:
List of neighboring hosts
"""
topology_service = TopologyService(db)
neighbors = topology_service.get_host_neighbors(host_id)
return {
'host_id': host_id,
'neighbors': [
{
'id': h.id,
'ip_address': h.ip_address,
'hostname': h.hostname,
'status': h.status
}
for h in neighbors
]
}

View File

@@ -0,0 +1,222 @@
"""WebSocket endpoint for real-time updates."""
import asyncio
import json
import logging
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from datetime import datetime
logger = logging.getLogger(__name__)
router = APIRouter()
class ConnectionManager:
"""Manager for WebSocket connections."""
def __init__(self):
"""Initialize connection manager."""
self.active_connections: Set[WebSocket] = set()
async def connect(self, websocket: WebSocket):
"""
Accept and register a new WebSocket connection.
Args:
websocket: WebSocket connection
"""
await websocket.accept()
self.active_connections.add(websocket)
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
def disconnect(self, websocket: WebSocket):
"""
Remove a WebSocket connection.
Args:
websocket: WebSocket connection
"""
self.active_connections.discard(websocket)
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
async def send_personal_message(self, message: dict, websocket: WebSocket):
"""
Send a message to a specific WebSocket.
Args:
message: Message to send
websocket: WebSocket connection
"""
try:
await websocket.send_json(message)
except Exception as e:
logger.error(f"Error sending message: {e}")
self.disconnect(websocket)
async def broadcast(self, message: dict):
"""
Broadcast a message to all connected WebSockets.
Args:
message: Message to broadcast
"""
disconnected = set()
for connection in self.active_connections:
try:
await connection.send_json(message)
except Exception as e:
logger.error(f"Error broadcasting to connection: {e}")
disconnected.add(connection)
# Clean up disconnected clients
for connection in disconnected:
self.disconnect(connection)
# Global connection manager instance
manager = ConnectionManager()
@router.websocket("")
async def websocket_endpoint(websocket: WebSocket):
"""
WebSocket endpoint for real-time scan updates.
Args:
websocket: WebSocket connection
"""
await manager.connect(websocket)
try:
# Send welcome message
await manager.send_personal_message({
'type': 'connected',
'message': 'Connected to network scanner',
'timestamp': datetime.utcnow().isoformat()
}, websocket)
# Keep connection alive and handle incoming messages
while True:
try:
# Receive messages from client
data = await websocket.receive_text()
# Parse and handle client messages
try:
message = json.loads(data)
await handle_client_message(message, websocket)
except json.JSONDecodeError:
await manager.send_personal_message({
'type': 'error',
'message': 'Invalid JSON format',
'timestamp': datetime.utcnow().isoformat()
}, websocket)
except WebSocketDisconnect:
break
except Exception as e:
logger.error(f"Error in WebSocket loop: {e}")
break
finally:
manager.disconnect(websocket)
async def handle_client_message(message: dict, websocket: WebSocket):
"""
Handle messages from client.
Args:
message: Client message
websocket: WebSocket connection
"""
message_type = message.get('type')
if message_type == 'ping':
# Respond to ping
await manager.send_personal_message({
'type': 'pong',
'timestamp': datetime.utcnow().isoformat()
}, websocket)
elif message_type == 'subscribe':
# Handle subscription requests
scan_id = message.get('scan_id')
if scan_id:
await manager.send_personal_message({
'type': 'subscribed',
'scan_id': scan_id,
'timestamp': datetime.utcnow().isoformat()
}, websocket)
else:
logger.warning(f"Unknown message type: {message_type}")
async def broadcast_scan_update(scan_id: int, update_type: str, data: dict):
"""
Broadcast scan update to all connected clients.
Args:
scan_id: Scan ID
update_type: Type of update
data: Update data
"""
message = {
'type': update_type,
'scan_id': scan_id,
'data': data,
'timestamp': datetime.utcnow().isoformat()
}
await manager.broadcast(message)
async def send_scan_progress(scan_id: int, progress: float, current_host: str = None):
"""
Send scan progress update.
Args:
scan_id: Scan ID
progress: Progress value (0.0 to 1.0)
current_host: Currently scanning host
"""
await broadcast_scan_update(scan_id, 'scan_progress', {
'progress': progress,
'current_host': current_host
})
async def send_host_discovered(scan_id: int, host_data: dict):
"""
Send host discovered notification.
Args:
scan_id: Scan ID
host_data: Host information
"""
await broadcast_scan_update(scan_id, 'host_discovered', host_data)
async def send_scan_completed(scan_id: int, summary: dict):
"""
Send scan completed notification.
Args:
scan_id: Scan ID
summary: Scan summary
"""
await broadcast_scan_update(scan_id, 'scan_completed', summary)
async def send_scan_failed(scan_id: int, error: str):
"""
Send scan failed notification.
Args:
scan_id: Scan ID
error: Error message
"""
await broadcast_scan_update(scan_id, 'scan_failed', {'error': error})

View File

@@ -0,0 +1,43 @@
"""Configuration management for the network scanner application."""
from typing import List
from pydantic_settings import BaseSettings
from pydantic import Field
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Application
app_name: str = Field(default="Network Scanner", alias="APP_NAME")
app_version: str = Field(default="1.0.0", alias="APP_VERSION")
debug: bool = Field(default=False, alias="DEBUG")
# Database
database_url: str = Field(default="sqlite:///./network_scanner.db", alias="DATABASE_URL")
# Scanning
default_scan_timeout: int = Field(default=3, alias="DEFAULT_SCAN_TIMEOUT")
max_concurrent_scans: int = Field(default=50, alias="MAX_CONCURRENT_SCANS")
enable_nmap: bool = Field(default=True, alias="ENABLE_NMAP")
# Network
default_network_range: str = Field(default="192.168.1.0/24", alias="DEFAULT_NETWORK_RANGE")
scan_private_networks_only: bool = Field(default=True, alias="SCAN_PRIVATE_NETWORKS_ONLY")
# API
api_prefix: str = Field(default="/api", alias="API_PREFIX")
cors_origins: List[str] = Field(default=["http://localhost:3000"], alias="CORS_ORIGINS")
# Logging
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
log_file: str = Field(default="logs/network_scanner.log", alias="LOG_FILE")
class Config:
"""Pydantic configuration."""
env_file = ".env"
case_sensitive = False
# Global settings instance
settings = Settings()

View File

@@ -0,0 +1,41 @@
"""Database configuration and session management."""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
from app.config import settings
# Create database engine
engine = create_engine(
settings.database_url,
connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {},
echo=settings.debug
)
# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Base class for models
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
"""
Dependency function to get database session.
Yields:
Session: SQLAlchemy database session
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_db() -> None:
"""Initialize database tables."""
Base.metadata.create_all(bind=engine)

View File

@@ -0,0 +1,122 @@
"""SQLAlchemy database models."""
from sqlalchemy import Column, Integer, String, DateTime, Float, Text, ForeignKey, Table, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from app.database import Base
# Association table for many-to-many relationship between scans and hosts
scan_hosts = Table(
'scan_hosts',
Base.metadata,
Column('scan_id', Integer, ForeignKey('scans.id', ondelete='CASCADE'), primary_key=True),
Column('host_id', Integer, ForeignKey('hosts.id', ondelete='CASCADE'), primary_key=True)
)
class Scan(Base):
"""Model for scan operations."""
__tablename__ = 'scans'
id = Column(Integer, primary_key=True, index=True)
started_at = Column(DateTime, nullable=False, default=datetime.utcnow)
completed_at = Column(DateTime, nullable=True)
scan_type = Column(String(50), nullable=False, default='quick')
network_range = Column(String(100), nullable=False)
status = Column(String(20), nullable=False, default='pending')
hosts_found = Column(Integer, default=0)
ports_scanned = Column(Integer, default=0)
error_message = Column(Text, nullable=True)
# Relationships
hosts = relationship('Host', secondary=scan_hosts, back_populates='scans')
def __repr__(self) -> str:
return f"<Scan(id={self.id}, network={self.network_range}, status={self.status})>"
class Host(Base):
"""Model for discovered network hosts."""
__tablename__ = 'hosts'
id = Column(Integer, primary_key=True, index=True)
ip_address = Column(String(45), nullable=False, unique=True, index=True)
hostname = Column(String(255), nullable=True)
mac_address = Column(String(17), nullable=True)
first_seen = Column(DateTime, nullable=False, default=datetime.utcnow)
last_seen = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow)
status = Column(String(20), nullable=False, default='online', index=True)
os_guess = Column(String(255), nullable=True)
device_type = Column(String(50), nullable=True)
vendor = Column(String(255), nullable=True)
notes = Column(Text, nullable=True)
# Relationships
services = relationship('Service', back_populates='host', cascade='all, delete-orphan')
scans = relationship('Scan', secondary=scan_hosts, back_populates='hosts')
outgoing_connections = relationship(
'Connection',
foreign_keys='Connection.source_host_id',
back_populates='source_host',
cascade='all, delete-orphan'
)
incoming_connections = relationship(
'Connection',
foreign_keys='Connection.target_host_id',
back_populates='target_host',
cascade='all, delete-orphan'
)
def __repr__(self) -> str:
return f"<Host(id={self.id}, ip={self.ip_address}, hostname={self.hostname})>"
class Service(Base):
"""Model for services running on hosts (open ports)."""
__tablename__ = 'services'
id = Column(Integer, primary_key=True, index=True)
host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False)
port = Column(Integer, nullable=False)
protocol = Column(String(10), nullable=False, default='tcp')
state = Column(String(20), nullable=False, default='open')
service_name = Column(String(100), nullable=True)
service_version = Column(String(255), nullable=True)
banner = Column(Text, nullable=True)
first_seen = Column(DateTime, nullable=False, default=datetime.utcnow)
last_seen = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
host = relationship('Host', back_populates='services')
def __repr__(self) -> str:
return f"<Service(host_id={self.host_id}, port={self.port}, service={self.service_name})>"
class Connection(Base):
"""Model for detected connections between hosts."""
__tablename__ = 'connections'
id = Column(Integer, primary_key=True, index=True)
source_host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False, index=True)
target_host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False, index=True)
connection_type = Column(String(50), nullable=False)
protocol = Column(String(10), nullable=True)
port = Column(Integer, nullable=True)
confidence = Column(Float, nullable=False, default=1.0)
detected_at = Column(DateTime, nullable=False, default=datetime.utcnow)
last_verified = Column(DateTime, nullable=True)
extra_data = Column(JSON, nullable=True)
# Relationships
source_host = relationship('Host', foreign_keys=[source_host_id], back_populates='outgoing_connections')
target_host = relationship('Host', foreign_keys=[target_host_id], back_populates='incoming_connections')
def __repr__(self) -> str:
return f"<Connection(source={self.source_host_id}, target={self.target_host_id}, type={self.connection_type})>"

View File

@@ -0,0 +1,7 @@
"""Network scanner module."""
from app.scanner.network_scanner import NetworkScanner
from app.scanner.port_scanner import PortScanner
from app.scanner.service_detector import ServiceDetector
__all__ = ['NetworkScanner', 'PortScanner', 'ServiceDetector']

View File

@@ -0,0 +1,242 @@
"""Network scanner implementation for host discovery."""
import socket
import ipaddress
import asyncio
from typing import List, Set, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
import logging
from app.config import settings
logger = logging.getLogger(__name__)
class NetworkScanner:
"""Scanner for discovering active hosts on a network."""
# Common ports for host discovery
DISCOVERY_PORTS = [21, 22, 23, 25, 80, 443, 445, 3389, 8080, 8443]
def __init__(
self,
timeout: int = None,
max_workers: int = None,
progress_callback: Optional[Callable[[str, float], None]] = None
):
"""
Initialize network scanner.
Args:
timeout: Socket connection timeout in seconds
max_workers: Maximum number of concurrent workers
progress_callback: Optional callback for progress updates
"""
self.timeout = timeout or settings.default_scan_timeout
self.max_workers = max_workers or settings.max_concurrent_scans
self.progress_callback = progress_callback
async def scan_network(self, network_range: str) -> List[str]:
"""
Scan a network range for active hosts.
Args:
network_range: Network in CIDR notation (e.g., '192.168.1.0/24')
Returns:
List of active IP addresses
"""
logger.info(f"Starting network scan of {network_range}")
try:
network = ipaddress.ip_network(network_range, strict=False)
# Validate private network if restriction enabled
if settings.scan_private_networks_only and not network.is_private:
raise ValueError(f"Network {network_range} is not a private network")
# Generate list of hosts to scan
hosts = [str(ip) for ip in network.hosts()]
total_hosts = len(hosts)
if total_hosts == 0:
# Single host network
hosts = [str(network.network_address)]
total_hosts = 1
logger.info(f"Scanning {total_hosts} hosts in {network_range}")
# Scan hosts concurrently
active_hosts = await self._scan_hosts_async(hosts)
logger.info(f"Scan completed. Found {len(active_hosts)} active hosts")
return active_hosts
except ValueError as e:
logger.error(f"Invalid network range: {e}")
raise
except Exception as e:
logger.error(f"Error during network scan: {e}")
raise
async def _scan_hosts_async(self, hosts: List[str]) -> List[str]:
"""
Scan multiple hosts asynchronously.
Args:
hosts: List of IP addresses to scan
Returns:
List of active hosts
"""
active_hosts: Set[str] = set()
total = len(hosts)
completed = 0
# Use ThreadPoolExecutor for socket operations
loop = asyncio.get_event_loop()
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for host in hosts:
future = loop.run_in_executor(executor, self._check_host, host)
futures.append((host, future))
# Process results as they complete
for host, future in futures:
try:
is_active = await future
if is_active:
active_hosts.add(host)
logger.debug(f"Host {host} is active")
except Exception as e:
logger.debug(f"Error checking host {host}: {e}")
finally:
completed += 1
if self.progress_callback:
progress = completed / total
self.progress_callback(host, progress)
return sorted(list(active_hosts), key=lambda ip: ipaddress.ip_address(ip))
def _check_host(self, ip: str) -> bool:
"""
Check if a host is active by attempting TCP connections.
Args:
ip: IP address to check
Returns:
True if host responds on any discovery port
"""
for port in self.DISCOVERY_PORTS:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
result = sock.connect_ex((ip, port))
sock.close()
if result == 0:
return True
except socket.error:
continue
except Exception as e:
logger.debug(f"Error checking {ip}:{port}: {e}")
continue
return False
def get_local_network_range(self) -> Optional[str]:
"""
Detect local network range.
Returns:
Network range in CIDR notation or None
"""
try:
import netifaces
# Get default gateway interface
gateways = netifaces.gateways()
if 'default' not in gateways or netifaces.AF_INET not in gateways['default']:
return None
default_interface = gateways['default'][netifaces.AF_INET][1]
# Get interface addresses
addrs = netifaces.ifaddresses(default_interface)
if netifaces.AF_INET not in addrs:
return None
# Get IP and netmask
inet_info = addrs[netifaces.AF_INET][0]
ip = inet_info.get('addr')
netmask = inet_info.get('netmask')
if not ip or not netmask:
return None
# Calculate network address
network = ipaddress.ip_network(f"{ip}/{netmask}", strict=False)
return str(network)
except ImportError:
logger.warning("netifaces not available, cannot detect local network")
return None
except Exception as e:
logger.error(f"Error detecting local network: {e}")
return None
def resolve_hostname(self, ip: str) -> Optional[str]:
"""
Resolve IP address to hostname.
Args:
ip: IP address
Returns:
Hostname or None
"""
try:
hostname = socket.gethostbyaddr(ip)[0]
return hostname
except socket.herror:
return None
except Exception as e:
logger.debug(f"Error resolving {ip}: {e}")
return None
def get_mac_address(self, ip: str) -> Optional[str]:
"""
Get MAC address for an IP (requires ARP access).
Args:
ip: IP address
Returns:
MAC address or None
"""
try:
# Try to get MAC from ARP cache
import subprocess
import re
# Platform-specific ARP command
import platform
if platform.system() == 'Windows':
arp_output = subprocess.check_output(['arp', '-a', ip]).decode()
mac_pattern = r'([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})'
else:
arp_output = subprocess.check_output(['arp', '-n', ip]).decode()
mac_pattern = r'([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}'
match = re.search(mac_pattern, arp_output)
if match:
return match.group(0).upper()
return None
except Exception as e:
logger.debug(f"Error getting MAC for {ip}: {e}")
return None

View File

@@ -0,0 +1,260 @@
"""Nmap integration for advanced scanning capabilities."""
import logging
from typing import Optional, Dict, Any, List
import asyncio
logger = logging.getLogger(__name__)
class NmapScanner:
"""Wrapper for python-nmap with safe execution."""
def __init__(self):
"""Initialize nmap scanner."""
self.nmap_available = self._check_nmap_available()
if not self.nmap_available:
logger.warning("nmap is not available on this system")
def _check_nmap_available(self) -> bool:
"""
Check if nmap is available on the system.
Returns:
True if nmap is available
"""
try:
import nmap
nm = nmap.PortScanner()
nm.nmap_version()
return True
except Exception as e:
logger.debug(f"nmap not available: {e}")
return False
async def scan_host(
self,
host: str,
arguments: str = '-sT -T4'
) -> Optional[Dict[str, Any]]:
"""
Scan a host using nmap.
Args:
host: IP address or hostname
arguments: Nmap arguments (default: TCP connect scan, aggressive timing)
Returns:
Scan results dictionary or None
"""
if not self.nmap_available:
logger.warning("Attempted to use nmap but it's not available")
return None
try:
import nmap
# Run nmap scan in thread pool
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
self._run_nmap_scan,
host,
arguments
)
return result
except Exception as e:
logger.error(f"Error running nmap scan on {host}: {e}")
return None
def _run_nmap_scan(self, host: str, arguments: str) -> Optional[Dict[str, Any]]:
"""
Run nmap scan synchronously.
Args:
host: Host to scan
arguments: Nmap arguments
Returns:
Scan results
"""
try:
import nmap
nm = nmap.PortScanner()
# Sanitize host input
if not self._validate_host(host):
logger.error(f"Invalid host: {host}")
return None
# Execute scan
logger.info(f"Running nmap scan: nmap {arguments} {host}")
nm.scan(hosts=host, arguments=arguments)
# Parse results
if host not in nm.all_hosts():
logger.debug(f"No results for {host}")
return None
host_info = nm[host]
# Extract relevant information
result = {
'hostname': host_info.hostname(),
'state': host_info.state(),
'protocols': list(host_info.all_protocols()),
'ports': []
}
# Extract port information
for proto in host_info.all_protocols():
ports = host_info[proto].keys()
for port in ports:
port_info = host_info[proto][port]
result['ports'].append({
'port': port,
'protocol': proto,
'state': port_info['state'],
'service_name': port_info.get('name'),
'service_version': port_info.get('version'),
'service_product': port_info.get('product'),
'extrainfo': port_info.get('extrainfo')
})
# OS detection if available
if 'osmatch' in host_info:
result['os_matches'] = [
{
'name': os['name'],
'accuracy': os['accuracy']
}
for os in host_info['osmatch']
]
return result
except Exception as e:
logger.error(f"Error in _run_nmap_scan for {host}: {e}")
return None
def _validate_host(self, host: str) -> bool:
"""
Validate host input to prevent command injection.
Args:
host: Host string to validate
Returns:
True if valid
"""
import ipaddress
import re
# Try as IP address
try:
ipaddress.ip_address(host)
return True
except ValueError:
pass
# Try as network range
try:
ipaddress.ip_network(host, strict=False)
return True
except ValueError:
pass
# Try as hostname (alphanumeric, dots, hyphens only)
if re.match(r'^[a-zA-Z0-9.-]+$', host):
return True
return False
def get_scan_arguments(
self,
scan_type: str,
service_detection: bool = True,
os_detection: bool = False,
port_range: Optional[str] = None
) -> str:
"""
Generate nmap arguments based on scan configuration.
Args:
scan_type: Type of scan ('quick', 'standard', 'deep')
service_detection: Enable service/version detection
os_detection: Enable OS detection (requires root)
port_range: Custom port range (e.g., '1-1000' or '80,443,8080')
Returns:
Nmap argument string
"""
args = []
# Use TCP connect scan (no root required)
args.append('-sT')
# Port specification
if port_range:
args.append(f'-p {port_range}')
elif scan_type == 'quick':
args.append('--top-ports 100')
elif scan_type == 'standard':
args.append('--top-ports 1000')
elif scan_type == 'deep':
args.append('-p-') # All ports
# Only show open ports
args.append('--open')
# Timing
if scan_type == 'quick':
args.append('-T5') # Insane
elif scan_type == 'deep':
args.append('-T3') # Normal
else:
args.append('-T4') # Aggressive
# Service detection
if service_detection:
args.append('-sV')
# OS detection (requires root)
if os_detection:
args.append('-O')
logger.warning("OS detection requires root privileges")
return ' '.join(args)
async def scan_network_with_nmap(
self,
network: str,
scan_type: str = 'quick'
) -> List[Dict[str, Any]]:
"""
Scan entire network using nmap.
Args:
network: Network in CIDR notation
scan_type: Type of scan
Returns:
List of host results
"""
if not self.nmap_available:
return []
try:
arguments = self.get_scan_arguments(scan_type)
result = await self.scan_host(network, arguments)
if result:
return [result]
return []
except Exception as e:
logger.error(f"Error scanning network {network}: {e}")
return []

View File

@@ -0,0 +1,213 @@
"""Port scanner implementation."""
import socket
import asyncio
from typing import List, Dict, Set, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
import logging
from app.config import settings
logger = logging.getLogger(__name__)
class PortScanner:
"""Scanner for detecting open ports on hosts."""
# Predefined port ranges for different scan types
PORT_RANGES = {
'quick': [21, 22, 23, 25, 53, 80, 110, 143, 443, 445, 3306, 3389, 5432, 8080, 8443],
'standard': list(range(1, 1001)),
'deep': list(range(1, 65536)),
}
def __init__(
self,
timeout: int = None,
max_workers: int = None,
progress_callback: Optional[Callable[[str, int, float], None]] = None
):
"""
Initialize port scanner.
Args:
timeout: Socket connection timeout in seconds
max_workers: Maximum number of concurrent workers
progress_callback: Optional callback for progress updates (host, port, progress)
"""
self.timeout = timeout or settings.default_scan_timeout
self.max_workers = max_workers or settings.max_concurrent_scans
self.progress_callback = progress_callback
async def scan_host_ports(
self,
host: str,
scan_type: str = 'quick',
custom_ports: Optional[List[int]] = None
) -> List[Dict[str, any]]:
"""
Scan ports on a single host.
Args:
host: IP address or hostname
scan_type: Type of scan ('quick', 'standard', 'deep', or 'custom')
custom_ports: Custom port list (required if scan_type is 'custom')
Returns:
List of dictionaries with port information
"""
logger.info(f"Starting port scan on {host} (type: {scan_type})")
# Determine ports to scan
if scan_type == 'custom' and custom_ports:
ports = custom_ports
elif scan_type in self.PORT_RANGES:
ports = self.PORT_RANGES[scan_type]
else:
ports = self.PORT_RANGES['quick']
# Scan ports
open_ports = await self._scan_ports_async(host, ports)
logger.info(f"Scan completed on {host}. Found {len(open_ports)} open ports")
return open_ports
async def _scan_ports_async(self, host: str, ports: List[int]) -> List[Dict[str, any]]:
"""
Scan multiple ports asynchronously.
Args:
host: Host to scan
ports: List of ports to scan
Returns:
List of open port information
"""
open_ports = []
total = len(ports)
completed = 0
loop = asyncio.get_event_loop()
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for port in ports:
future = loop.run_in_executor(executor, self._check_port, host, port)
futures.append((port, future))
# Process results
for port, future in futures:
try:
result = await future
if result:
open_ports.append(result)
logger.debug(f"Found open port {port} on {host}")
except Exception as e:
logger.debug(f"Error checking port {port} on {host}: {e}")
finally:
completed += 1
if self.progress_callback:
progress = completed / total
self.progress_callback(host, port, progress)
return sorted(open_ports, key=lambda x: x['port'])
def _check_port(self, host: str, port: int) -> Optional[Dict[str, any]]:
"""
Check if a port is open on a host.
Args:
host: Host to check
port: Port number
Returns:
Dictionary with port info if open, None otherwise
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
result = sock.connect_ex((host, port))
sock.close()
if result == 0:
return {
'port': port,
'protocol': 'tcp',
'state': 'open',
'service_name': self._guess_service_name(port)
}
return None
except socket.error as e:
logger.debug(f"Socket error checking {host}:{port}: {e}")
return None
except Exception as e:
logger.debug(f"Error checking {host}:{port}: {e}")
return None
def _guess_service_name(self, port: int) -> Optional[str]:
"""
Guess service name based on well-known ports.
Args:
port: Port number
Returns:
Service name or None
"""
common_services = {
20: 'ftp-data',
21: 'ftp',
22: 'ssh',
23: 'telnet',
25: 'smtp',
53: 'dns',
80: 'http',
110: 'pop3',
143: 'imap',
443: 'https',
445: 'smb',
3306: 'mysql',
3389: 'rdp',
5432: 'postgresql',
5900: 'vnc',
8080: 'http-alt',
8443: 'https-alt',
}
return common_services.get(port)
def parse_port_range(self, port_range: str) -> List[int]:
"""
Parse port range string to list of ports.
Args:
port_range: String like "80,443,8000-8100"
Returns:
List of port numbers
"""
ports = set()
try:
for part in port_range.split(','):
part = part.strip()
if '-' in part:
# Range like "8000-8100"
start, end = map(int, part.split('-'))
if 1 <= start <= end <= 65535:
ports.update(range(start, end + 1))
else:
# Single port
port = int(part)
if 1 <= port <= 65535:
ports.add(port)
return sorted(list(ports))
except ValueError as e:
logger.error(f"Error parsing port range '{port_range}': {e}")
return []

View File

@@ -0,0 +1,250 @@
"""Service detection and banner grabbing implementation."""
import socket
import logging
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ServiceDetector:
"""Detector for identifying services running on open ports."""
def __init__(self, timeout: int = 3):
"""
Initialize service detector.
Args:
timeout: Socket timeout in seconds
"""
self.timeout = timeout
def detect_service(self, host: str, port: int) -> Dict[str, Any]:
"""
Detect service on a specific port.
Args:
host: Host IP or hostname
port: Port number
Returns:
Dictionary with service information
"""
service_info = {
'port': port,
'protocol': 'tcp',
'service_name': None,
'service_version': None,
'banner': None
}
# Try banner grabbing
banner = self.grab_banner(host, port)
if banner:
service_info['banner'] = banner
# Try to identify service from banner
service_name, version = self._identify_from_banner(banner, port)
if service_name:
service_info['service_name'] = service_name
if version:
service_info['service_version'] = version
# If no banner, use port-based guess
if not service_info['service_name']:
service_info['service_name'] = self._guess_service_from_port(port)
return service_info
def grab_banner(self, host: str, port: int) -> Optional[str]:
"""
Attempt to grab service banner.
Args:
host: Host IP or hostname
port: Port number
Returns:
Banner string or None
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect((host, port))
# Try to receive banner
try:
banner = sock.recv(1024)
banner_str = banner.decode('utf-8', errors='ignore').strip()
sock.close()
if banner_str:
logger.debug(f"Got banner from {host}:{port}: {banner_str[:100]}")
return banner_str
except socket.timeout:
# Try sending a probe for services that need it
banner_str = self._probe_service(sock, port)
sock.close()
return banner_str
except Exception as e:
logger.debug(f"Error grabbing banner from {host}:{port}: {e}")
return None
def _probe_service(self, sock: socket.socket, port: int) -> Optional[str]:
"""
Send service-specific probe to elicit response.
Args:
sock: Connected socket
port: Port number
Returns:
Response string or None
"""
probes = {
80: b"GET / HTTP/1.0\r\n\r\n",
443: b"GET / HTTP/1.0\r\n\r\n",
8080: b"GET / HTTP/1.0\r\n\r\n",
8443: b"GET / HTTP/1.0\r\n\r\n",
25: b"EHLO test\r\n",
110: b"USER test\r\n",
143: b"A001 CAPABILITY\r\n",
}
probe = probes.get(port, b"\r\n")
try:
sock.send(probe)
response = sock.recv(1024)
return response.decode('utf-8', errors='ignore').strip()
except:
return None
def _identify_from_banner(self, banner: str, port: int) -> tuple[Optional[str], Optional[str]]:
"""
Identify service and version from banner.
Args:
banner: Banner string
port: Port number
Returns:
Tuple of (service_name, version)
"""
banner_lower = banner.lower()
# HTTP servers
if 'http' in banner_lower or port in [80, 443, 8080, 8443]:
if 'apache' in banner_lower:
return self._extract_apache_version(banner)
elif 'nginx' in banner_lower:
return self._extract_nginx_version(banner)
elif 'iis' in banner_lower or 'microsoft' in banner_lower:
return 'IIS', None
else:
return 'HTTP', None
# SSH
if 'ssh' in banner_lower or port == 22:
if 'openssh' in banner_lower:
return self._extract_openssh_version(banner)
return 'SSH', None
# FTP
if 'ftp' in banner_lower or port in [20, 21]:
if 'filezilla' in banner_lower:
return 'FileZilla FTP', None
elif 'proftpd' in banner_lower:
return 'ProFTPD', None
return 'FTP', None
# SMTP
if 'smtp' in banner_lower or 'mail' in banner_lower or port == 25:
if 'postfix' in banner_lower:
return 'Postfix', None
elif 'exim' in banner_lower:
return 'Exim', None
return 'SMTP', None
# MySQL
if 'mysql' in banner_lower or port == 3306:
return 'MySQL', None
# PostgreSQL
if 'postgresql' in banner_lower or port == 5432:
return 'PostgreSQL', None
# Generic identification
if port == 22:
return 'SSH', None
elif port in [80, 8080]:
return 'HTTP', None
elif port in [443, 8443]:
return 'HTTPS', None
return None, None
def _extract_apache_version(self, banner: str) -> tuple[str, Optional[str]]:
"""Extract Apache version from banner."""
import re
match = re.search(r'Apache/?([\d.]+)?', banner, re.IGNORECASE)
if match:
version = match.group(1)
return 'Apache', version
return 'Apache', None
def _extract_nginx_version(self, banner: str) -> tuple[str, Optional[str]]:
"""Extract nginx version from banner."""
import re
match = re.search(r'nginx/?([\d.]+)?', banner, re.IGNORECASE)
if match:
version = match.group(1)
return 'nginx', version
return 'nginx', None
def _extract_openssh_version(self, banner: str) -> tuple[str, Optional[str]]:
"""Extract OpenSSH version from banner."""
import re
match = re.search(r'OpenSSH[_/]?([\d.]+\w*)?', banner, re.IGNORECASE)
if match:
version = match.group(1)
return 'OpenSSH', version
return 'OpenSSH', None
def _guess_service_from_port(self, port: int) -> Optional[str]:
"""
Guess service name from well-known port number.
Args:
port: Port number
Returns:
Service name or None
"""
common_services = {
20: 'ftp-data',
21: 'ftp',
22: 'ssh',
23: 'telnet',
25: 'smtp',
53: 'dns',
80: 'http',
110: 'pop3',
143: 'imap',
443: 'https',
445: 'smb',
993: 'imaps',
995: 'pop3s',
3306: 'mysql',
3389: 'rdp',
5432: 'postgresql',
5900: 'vnc',
6379: 'redis',
8080: 'http-alt',
8443: 'https-alt',
27017: 'mongodb',
}
return common_services.get(port)

View File

@@ -0,0 +1,256 @@
"""Pydantic schemas for API request/response validation."""
from pydantic import BaseModel, Field, IPvAnyAddress, field_validator
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum
class ScanType(str, Enum):
"""Scan type enumeration."""
QUICK = "quick"
STANDARD = "standard"
DEEP = "deep"
CUSTOM = "custom"
class ScanStatus(str, Enum):
"""Scan status enumeration."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class HostStatus(str, Enum):
"""Host status enumeration."""
ONLINE = "online"
OFFLINE = "offline"
SCANNING = "scanning"
class ConnectionType(str, Enum):
"""Connection type enumeration."""
GATEWAY = "gateway"
SAME_SUBNET = "same_subnet"
SERVICE = "service"
INFERRED = "inferred"
# Service schemas
class ServiceBase(BaseModel):
"""Base service schema."""
port: int = Field(..., ge=1, le=65535)
protocol: str = Field(default="tcp", pattern="^(tcp|udp)$")
state: str = Field(default="open")
service_name: Optional[str] = None
service_version: Optional[str] = None
banner: Optional[str] = None
class ServiceCreate(ServiceBase):
"""Schema for creating a service."""
host_id: int
class ServiceResponse(ServiceBase):
"""Schema for service response."""
id: int
host_id: int
first_seen: datetime
last_seen: datetime
class Config:
from_attributes = True
# Host schemas
class HostBase(BaseModel):
"""Base host schema."""
ip_address: str
hostname: Optional[str] = None
mac_address: Optional[str] = None
@field_validator('ip_address')
@classmethod
def validate_ip(cls, v: str) -> str:
"""Validate IP address format."""
import ipaddress
try:
ipaddress.ip_address(v)
return v
except ValueError:
raise ValueError(f"Invalid IP address: {v}")
class HostCreate(HostBase):
"""Schema for creating a host."""
device_type: Optional[str] = None
os_guess: Optional[str] = None
vendor: Optional[str] = None
class HostResponse(HostBase):
"""Schema for host response."""
id: int
first_seen: datetime
last_seen: datetime
status: HostStatus
device_type: Optional[str] = None
os_guess: Optional[str] = None
vendor: Optional[str] = None
notes: Optional[str] = None
services: List[ServiceResponse] = []
class Config:
from_attributes = True
class HostDetailResponse(HostResponse):
"""Detailed host response with connection info."""
outgoing_connections: List['ConnectionResponse'] = []
incoming_connections: List['ConnectionResponse'] = []
# Connection schemas
class ConnectionBase(BaseModel):
"""Base connection schema."""
source_host_id: int
target_host_id: int
connection_type: ConnectionType
protocol: Optional[str] = None
port: Optional[int] = None
confidence: float = Field(default=1.0, ge=0.0, le=1.0)
class ConnectionCreate(ConnectionBase):
"""Schema for creating a connection."""
metadata: Optional[Dict[str, Any]] = None
class ConnectionResponse(ConnectionBase):
"""Schema for connection response."""
id: int
detected_at: datetime
last_verified: Optional[datetime] = None
metadata: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
# Scan schemas
class ScanConfigRequest(BaseModel):
"""Schema for scan configuration request."""
network_range: str
scan_type: ScanType = Field(default=ScanType.QUICK)
port_range: Optional[str] = None
include_service_detection: bool = True
use_nmap: bool = True
@field_validator('network_range')
@classmethod
def validate_network(cls, v: str) -> str:
"""Validate network range format."""
import ipaddress
try:
network = ipaddress.ip_network(v, strict=False)
# Check if it's a private network
if not network.is_private:
raise ValueError("Only private network ranges are allowed")
return v
except ValueError as e:
raise ValueError(f"Invalid network range: {e}")
class ScanResponse(BaseModel):
"""Schema for scan response."""
id: int
started_at: datetime
completed_at: Optional[datetime] = None
scan_type: str
network_range: str
status: ScanStatus
hosts_found: int = 0
ports_scanned: int = 0
error_message: Optional[str] = None
class Config:
from_attributes = True
class ScanStatusResponse(ScanResponse):
"""Schema for detailed scan status response."""
progress: float = Field(default=0.0, ge=0.0, le=1.0)
current_host: Optional[str] = None
estimated_completion: Optional[datetime] = None
class ScanStartResponse(BaseModel):
"""Schema for scan start response."""
scan_id: int
message: str
status: ScanStatus
# Topology schemas
class TopologyNode(BaseModel):
"""Schema for topology graph node."""
id: str
ip: str
hostname: Optional[str]
type: str
status: str
service_count: int
connections: int = 0
class TopologyEdge(BaseModel):
"""Schema for topology graph edge."""
source: str
target: str
type: str = "default"
confidence: float = 0.5
class TopologyResponse(BaseModel):
"""Schema for topology graph response."""
nodes: List[TopologyNode]
edges: List[TopologyEdge]
statistics: Dict[str, Any] = Field(default_factory=dict)
# WebSocket message schemas
class WSMessageType(str, Enum):
"""WebSocket message type enumeration."""
SCAN_STARTED = "scan_started"
SCAN_PROGRESS = "scan_progress"
HOST_DISCOVERED = "host_discovered"
SERVICE_DISCOVERED = "service_discovered"
SCAN_COMPLETED = "scan_completed"
SCAN_FAILED = "scan_failed"
ERROR = "error"
class WSMessage(BaseModel):
"""Schema for WebSocket messages."""
type: WSMessageType
data: Dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.utcnow)
# Statistics schemas
class NetworkStatistics(BaseModel):
"""Schema for network statistics."""
total_hosts: int
online_hosts: int
offline_hosts: int
total_services: int
total_scans: int
last_scan: Optional[datetime] = None
most_common_services: List[Dict[str, Any]] = []
# Rebuild models to resolve forward references
HostDetailResponse.model_rebuild()

View File

@@ -0,0 +1,6 @@
"""Business logic services."""
from app.services.scan_service import ScanService
from app.services.topology_service import TopologyService
__all__ = ['ScanService', 'TopologyService']

View File

@@ -0,0 +1,553 @@
"""Scan service for orchestrating network scanning operations."""
import asyncio
import logging
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from app.models import Scan, Host, Service, Connection
from app.schemas import ScanConfigRequest, ScanStatus as ScanStatusEnum
from app.scanner.network_scanner import NetworkScanner
from app.scanner.port_scanner import PortScanner
from app.scanner.service_detector import ServiceDetector
from app.scanner.nmap_scanner import NmapScanner
from app.config import settings
logger = logging.getLogger(__name__)
class ScanService:
"""Service for managing network scans."""
def __init__(self, db: Session):
"""
Initialize scan service.
Args:
db: Database session
"""
self.db = db
self.active_scans: Dict[int, asyncio.Task] = {}
self.cancel_requested: Dict[int, bool] = {}
def create_scan(self, config: ScanConfigRequest) -> Scan:
"""
Create a new scan record.
Args:
config: Scan configuration
Returns:
Created scan object
"""
scan = Scan(
scan_type=config.scan_type.value,
network_range=config.network_range,
status=ScanStatusEnum.PENDING.value,
started_at=datetime.utcnow()
)
self.db.add(scan)
self.db.commit()
self.db.refresh(scan)
logger.info(f"Created scan {scan.id} for {config.network_range}")
return scan
def cancel_scan(self, scan_id: int) -> bool:
"""
Cancel a running scan.
Args:
scan_id: Scan ID to cancel
Returns:
True if scan was cancelled, False if not found or not running
"""
try:
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
logger.warning(f"Scan {scan_id} not found")
return False
if scan.status not in [ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value]:
logger.warning(f"Scan {scan_id} is not running (status: {scan.status})")
return False
# Mark for cancellation
self.cancel_requested[scan_id] = True
# Cancel the task if it exists
if scan_id in self.active_scans:
task = self.active_scans[scan_id]
task.cancel()
del self.active_scans[scan_id]
# Update scan status
scan.status = ScanStatusEnum.CANCELLED.value
scan.completed_at = datetime.utcnow()
self.db.commit()
logger.info(f"Cancelled scan {scan_id}")
return True
except Exception as e:
logger.error(f"Error cancelling scan {scan_id}: {e}")
self.db.rollback()
return False
async def execute_scan(
self,
scan_id: int,
config: ScanConfigRequest,
progress_callback: Optional[callable] = None
) -> None:
"""
Execute a network scan.
Args:
scan_id: Scan ID
config: Scan configuration
progress_callback: Optional callback for progress updates
"""
scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
logger.error(f"Scan {scan_id} not found")
return
try:
# Initialize cancellation flag
self.cancel_requested[scan_id] = False
# Update scan status
scan.status = ScanStatusEnum.RUNNING.value
self.db.commit()
logger.info(f"Starting scan {scan_id}")
# Check for cancellation
if self.cancel_requested.get(scan_id):
raise asyncio.CancelledError("Scan cancelled by user")
# Initialize scanners
network_scanner = NetworkScanner(
progress_callback=lambda host, progress: self._on_host_progress(
scan_id, host, progress, progress_callback
)
)
# Phase 1: Host Discovery
logger.info(f"Phase 1: Discovering hosts in {config.network_range}")
active_hosts = await network_scanner.scan_network(config.network_range)
scan.hosts_found = len(active_hosts)
self.db.commit()
logger.info(f"Found {len(active_hosts)} active hosts")
# Check for cancellation
if self.cancel_requested.get(scan_id):
raise asyncio.CancelledError("Scan cancelled by user")
# Send progress update
if progress_callback:
await progress_callback({
'type': 'scan_progress',
'scan_id': scan_id,
'progress': 0.3,
'current_host': f"Found {len(active_hosts)} hosts"
})
# Phase 2: Port Scanning and Service Detection
if config.use_nmap and settings.enable_nmap:
await self._scan_with_nmap(scan, active_hosts, config, progress_callback)
else:
await self._scan_with_socket(scan, active_hosts, config, progress_callback)
# Phase 3: Detect Connections
await self._detect_connections(scan, network_scanner)
# Mark scan as completed
scan.status = ScanStatusEnum.COMPLETED.value
scan.completed_at = datetime.utcnow()
self.db.commit()
logger.info(f"Scan {scan_id} completed successfully")
if progress_callback:
await progress_callback({
'type': 'scan_completed',
'scan_id': scan_id,
'hosts_found': scan.hosts_found
})
except asyncio.CancelledError:
logger.info(f"Scan {scan_id} was cancelled")
scan.status = ScanStatusEnum.CANCELLED.value
scan.completed_at = datetime.utcnow()
self.db.commit()
if progress_callback:
await progress_callback({
'type': 'scan_completed',
'scan_id': scan_id,
'hosts_found': scan.hosts_found
})
except Exception as e:
logger.error(f"Error executing scan {scan_id}: {e}", exc_info=True)
scan.status = ScanStatusEnum.FAILED.value
scan.error_message = str(e)
scan.completed_at = datetime.utcnow()
self.db.commit()
if progress_callback:
await progress_callback({
'type': 'scan_failed',
'scan_id': scan_id,
'error': str(e)
})
finally:
# Cleanup
self.cancel_requested.pop(scan_id, None)
self.active_scans.pop(scan_id, None)
async def _scan_with_socket(
self,
scan: Scan,
hosts: list,
config: ScanConfigRequest,
progress_callback: Optional[callable]
) -> None:
"""Scan hosts using socket-based scanning."""
port_scanner = PortScanner(
progress_callback=lambda host, port, progress: self._on_port_progress(
scan.id, host, port, progress, progress_callback
)
)
service_detector = ServiceDetector()
for idx, ip in enumerate(hosts, 1):
try:
# Check for cancellation
if self.cancel_requested.get(scan.id):
logger.info(f"Scan {scan.id} cancelled during port scanning")
raise asyncio.CancelledError("Scan cancelled by user")
logger.info(f"Scanning host {idx}/{len(hosts)}: {ip}")
# Get or create host
host = self._get_or_create_host(ip)
self.db.commit() # Commit to ensure host.id is set
self.db.refresh(host)
# Send host discovered notification
if progress_callback:
await progress_callback({
'type': 'host_discovered',
'scan_id': scan.id,
'host': {
'ip_address': ip,
'status': 'online'
}
})
# Scan ports
custom_ports = None
if config.port_range:
custom_ports = port_scanner.parse_port_range(config.port_range)
open_ports = await port_scanner.scan_host_ports(
ip,
scan_type=config.scan_type.value,
custom_ports=custom_ports
)
scan.ports_scanned += len(open_ports)
# Detect services
if config.include_service_detection:
for port_info in open_ports:
service_info = service_detector.detect_service(ip, port_info['port'])
port_info.update(service_info)
# Store services
self._store_services(host, open_ports)
# Associate host with scan
if host not in scan.hosts:
scan.hosts.append(host)
self.db.commit()
# Send progress update
if progress_callback:
progress = 0.3 + (0.6 * (idx / len(hosts))) # 30-90% for port scanning
await progress_callback({
'type': 'scan_progress',
'scan_id': scan.id,
'progress': progress,
'current_host': f"Scanning {ip} ({idx}/{len(hosts)})"
})
except Exception as e:
logger.error(f"Error scanning host {ip}: {e}")
continue
async def _scan_with_nmap(
self,
scan: Scan,
hosts: list,
config: ScanConfigRequest,
progress_callback: Optional[callable]
) -> None:
"""Scan hosts using nmap."""
nmap_scanner = NmapScanner()
if not nmap_scanner.nmap_available:
logger.warning("Nmap not available, falling back to socket scanning")
await self._scan_with_socket(scan, hosts, config, progress_callback)
return
# Scan each host with nmap
for idx, ip in enumerate(hosts, 1):
try:
logger.info(f"Scanning host {idx}/{len(hosts)} with nmap: {ip}")
# Get or create host
host = self._get_or_create_host(ip)
self.db.commit() # Commit to ensure host.id is set
self.db.refresh(host)
# Build nmap arguments
port_range = config.port_range if config.port_range else None
arguments = nmap_scanner.get_scan_arguments(
scan_type=config.scan_type.value,
service_detection=config.include_service_detection,
port_range=port_range
)
# Execute nmap scan
result = await nmap_scanner.scan_host(ip, arguments)
if result:
# Update hostname if available
if result.get('hostname'):
host.hostname = result['hostname']
# Store services
if result.get('ports'):
self._store_services(host, result['ports'])
scan.ports_scanned += len(result['ports'])
# Store OS information
if result.get('os_matches'):
best_match = max(result['os_matches'], key=lambda x: float(x['accuracy']))
host.os_guess = best_match['name']
# Associate host with scan
if host not in scan.hosts:
scan.hosts.append(host)
self.db.commit()
except Exception as e:
logger.error(f"Error scanning host {ip} with nmap: {e}")
continue
def _get_or_create_host(self, ip: str) -> Host:
"""Get existing host or create new one."""
host = self.db.query(Host).filter(Host.ip_address == ip).first()
if host:
host.last_seen = datetime.utcnow()
host.status = 'online'
else:
host = Host(
ip_address=ip,
status='online',
first_seen=datetime.utcnow(),
last_seen=datetime.utcnow()
)
self.db.add(host)
return host
def _store_services(self, host: Host, services_data: list) -> None:
"""Store or update services for a host."""
for service_info in services_data:
# Check if service already exists
service = self.db.query(Service).filter(
Service.host_id == host.id,
Service.port == service_info['port'],
Service.protocol == service_info.get('protocol', 'tcp')
).first()
if service:
# Update existing service
service.last_seen = datetime.utcnow()
service.state = service_info.get('state', 'open')
if service_info.get('service_name'):
service.service_name = service_info['service_name']
if service_info.get('service_version'):
service.service_version = service_info['service_version']
if service_info.get('banner'):
service.banner = service_info['banner']
else:
# Create new service
service = Service(
host_id=host.id,
port=service_info['port'],
protocol=service_info.get('protocol', 'tcp'),
state=service_info.get('state', 'open'),
service_name=service_info.get('service_name'),
service_version=service_info.get('service_version'),
banner=service_info.get('banner'),
first_seen=datetime.utcnow(),
last_seen=datetime.utcnow()
)
self.db.add(service)
async def _detect_connections(self, scan: Scan, network_scanner: NetworkScanner) -> None:
"""Detect connections between hosts."""
try:
# Get gateway
gateway_ip = network_scanner.get_local_network_range()
if gateway_ip:
gateway_network = gateway_ip.split('/')[0].rsplit('.', 1)[0] + '.1'
# Find or create gateway host
gateway_host = self.db.query(Host).filter(
Host.ip_address == gateway_network
).first()
if gateway_host:
# Connect all hosts to gateway
for host in scan.hosts:
if host.id != gateway_host.id:
self._create_connection(
host.id,
gateway_host.id,
'gateway',
confidence=0.9
)
# Create connections based on services
for host in scan.hosts:
for service in host.services:
# If host has client-type services, it might connect to servers
if service.service_name in ['http', 'https', 'ssh']:
# Find potential servers on the network
for other_host in scan.hosts:
if other_host.id != host.id:
for other_service in other_host.services:
if (other_service.port == service.port and
other_service.service_name in ['http', 'https', 'ssh']):
self._create_connection(
host.id,
other_host.id,
'service',
protocol='tcp',
port=service.port,
confidence=0.5
)
self.db.commit()
except Exception as e:
logger.error(f"Error detecting connections: {e}")
def _create_connection(
self,
source_id: int,
target_id: int,
conn_type: str,
protocol: Optional[str] = None,
port: Optional[int] = None,
confidence: float = 1.0
) -> None:
"""Create a connection if it doesn't exist."""
existing = self.db.query(Connection).filter(
Connection.source_host_id == source_id,
Connection.target_host_id == target_id,
Connection.connection_type == conn_type
).first()
if not existing:
connection = Connection(
source_host_id=source_id,
target_host_id=target_id,
connection_type=conn_type,
protocol=protocol,
port=port,
confidence=confidence,
detected_at=datetime.utcnow()
)
self.db.add(connection)
def _on_host_progress(
self,
scan_id: int,
host: str,
progress: float,
callback: Optional[callable]
) -> None:
"""Handle host discovery progress."""
if callback:
asyncio.create_task(callback({
'type': 'scan_progress',
'scan_id': scan_id,
'current_host': host,
'progress': progress * 0.5 # Host discovery is first 50%
}))
def _on_port_progress(
self,
scan_id: int,
host: str,
port: int,
progress: float,
callback: Optional[callable]
) -> None:
"""Handle port scanning progress."""
if callback:
asyncio.create_task(callback({
'type': 'scan_progress',
'scan_id': scan_id,
'current_host': host,
'current_port': port,
'progress': 0.5 + (progress * 0.5) # Port scanning is second 50%
}))
def get_scan_status(self, scan_id: int) -> Optional[Scan]:
"""Get scan status by ID."""
return self.db.query(Scan).filter(Scan.id == scan_id).first()
def list_scans(self, limit: int = 50, offset: int = 0) -> list:
"""List recent scans."""
return self.db.query(Scan)\
.order_by(Scan.started_at.desc())\
.limit(limit)\
.offset(offset)\
.all()
def cancel_scan(self, scan_id: int) -> bool:
"""Cancel a running scan."""
if scan_id in self.active_scans:
task = self.active_scans[scan_id]
task.cancel()
del self.active_scans[scan_id]
scan = self.get_scan_status(scan_id)
if scan:
scan.status = ScanStatusEnum.CANCELLED.value
scan.completed_at = datetime.utcnow()
self.db.commit()
return True
return False

View File

@@ -0,0 +1,256 @@
"""Topology service for network graph generation."""
import logging
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.models import Host, Service, Connection
from app.schemas import TopologyNode, TopologyEdge, TopologyResponse
logger = logging.getLogger(__name__)
class TopologyService:
"""Service for generating network topology graphs."""
# Node type colors
NODE_COLORS = {
'gateway': '#FF6B6B',
'server': '#4ECDC4',
'workstation': '#45B7D1',
'device': '#96CEB4',
'unknown': '#95A5A6'
}
def __init__(self, db: Session):
"""
Initialize topology service.
Args:
db: Database session
"""
self.db = db
def generate_topology(self, include_offline: bool = False) -> TopologyResponse:
"""
Generate network topology graph.
Args:
include_offline: Include offline hosts
Returns:
Topology response with nodes and edges
"""
logger.info("Generating network topology")
# Get hosts
query = self.db.query(Host)
if not include_offline:
query = query.filter(Host.status == 'online')
hosts = query.all()
# Generate nodes
nodes = []
for host in hosts:
node = self._create_node(host)
nodes.append(node)
# Generate edges from connections
edges = []
connections = self.db.query(Connection).all()
for conn in connections:
# Only include edges if both hosts are in the topology
source_in_topology = any(n.id == str(conn.source_host_id) for n in nodes)
target_in_topology = any(n.id == str(conn.target_host_id) for n in nodes)
if source_in_topology and target_in_topology:
edge = self._create_edge(conn)
edges.append(edge)
# Generate statistics
statistics = self._generate_statistics(hosts, connections)
logger.info(f"Generated topology with {len(nodes)} nodes and {len(edges)} edges")
return TopologyResponse(
nodes=nodes,
edges=edges,
statistics=statistics
)
def _create_node(self, host: Host) -> TopologyNode:
"""
Create a topology node from a host.
Args:
host: Host model
Returns:
TopologyNode
"""
# Determine device type
device_type = self._determine_device_type(host)
# Count connections
connections = self.db.query(Connection).filter(
(Connection.source_host_id == host.id) |
(Connection.target_host_id == host.id)
).count()
return TopologyNode(
id=str(host.id),
ip=host.ip_address,
hostname=host.hostname,
type=device_type,
status=host.status,
service_count=len(host.services),
connections=connections
)
def _determine_device_type(self, host: Host) -> str:
"""
Determine device type based on host information.
Args:
host: Host model
Returns:
Device type string
"""
# Check if explicitly set
if host.device_type:
return host.device_type
# Infer from services
service_names = [s.service_name for s in host.services if s.service_name]
# Check for gateway indicators
if any(s.port == 53 for s in host.services): # DNS server
return 'gateway'
# Check for server indicators
server_services = ['http', 'https', 'ssh', 'smtp', 'mysql', 'postgresql', 'ftp']
if any(svc in service_names for svc in server_services):
if len(host.services) > 5:
return 'server'
# Check for workstation indicators
if any(s.port == 3389 for s in host.services): # RDP
return 'workstation'
# Default to device
if len(host.services) > 0:
return 'device'
return 'unknown'
def _create_edge(self, connection: Connection) -> TopologyEdge:
"""
Create a topology edge from a connection.
Args:
connection: Connection model
Returns:
TopologyEdge
"""
return TopologyEdge(
source=str(connection.source_host_id),
target=str(connection.target_host_id),
type=connection.connection_type or 'default',
confidence=connection.confidence
)
def _generate_statistics(
self,
hosts: List[Host],
connections: List[Connection]
) -> Dict[str, Any]:
"""
Generate statistics about the topology.
Args:
hosts: List of hosts
connections: List of connections
Returns:
Statistics dictionary
"""
# Count isolated nodes (no connections)
isolated = 0
for host in hosts:
conn_count = self.db.query(Connection).filter(
(Connection.source_host_id == host.id) |
(Connection.target_host_id == host.id)
).count()
if conn_count == 0:
isolated += 1
# Calculate average connections
avg_connections = len(connections) / max(len(hosts), 1) if hosts else 0
return {
'total_nodes': len(hosts),
'total_edges': len(connections),
'isolated_nodes': isolated,
'avg_connections': round(avg_connections, 2)
}
def get_host_neighbors(self, host_id: int) -> List[Host]:
"""
Get all hosts connected to a specific host.
Args:
host_id: Host ID
Returns:
List of connected hosts
"""
# Get outgoing connections
outgoing = self.db.query(Connection).filter(
Connection.source_host_id == host_id
).all()
# Get incoming connections
incoming = self.db.query(Connection).filter(
Connection.target_host_id == host_id
).all()
# Collect unique neighbor IDs
neighbor_ids = set()
for conn in outgoing:
neighbor_ids.add(conn.target_host_id)
for conn in incoming:
neighbor_ids.add(conn.source_host_id)
# Get host objects
neighbors = self.db.query(Host).filter(
Host.id.in_(neighbor_ids)
).all()
return neighbors
def get_network_statistics(self) -> Dict[str, Any]:
"""
Get network statistics.
Returns:
Statistics dictionary
"""
total_hosts = self.db.query(func.count(Host.id)).scalar()
online_hosts = self.db.query(func.count(Host.id)).filter(
Host.status == 'online'
).scalar()
total_services = self.db.query(func.count(Service.id)).scalar()
return {
'total_hosts': total_hosts,
'online_hosts': online_hosts,
'offline_hosts': total_hosts - online_hosts,
'total_services': total_services,
'total_connections': self.db.query(func.count(Connection.id)).scalar()
}