Initial commit: tool collection

Contains:
- rdp_client.py: RDP client with GUI and monitor selection
- rdp.sh: Bash-based RDP client
- teamleader_test/: network scanner full-stack app
- teamleader_test2/: network mapper CLI

Subdirectories with their own repository were excluded.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
teamleader_test/app/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Network Scanner Application Package."""

__version__ = "1.0.0"
__author__ = "DevAgent"
teamleader_test/app/api/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""API router initialization."""

from fastapi import APIRouter

from app.api.endpoints import scans, hosts, topology, websocket

api_router = APIRouter()

# Include endpoint routers
api_router.include_router(scans.router, prefix="/scans", tags=["scans"])
api_router.include_router(hosts.router, prefix="/hosts", tags=["hosts"])
api_router.include_router(topology.router, prefix="/topology", tags=["topology"])
api_router.include_router(websocket.router, prefix="/ws", tags=["websocket"])
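
Not shown in this commit is the application entry point that mounts api_router. A minimal sketch of such a module, assuming the conventional FastAPI wiring and the settings, init_db, and CORS values defined later in this commit (the module name app/main.py and the uvicorn command are assumptions):

# Hypothetical app/main.py: wires the API router into a FastAPI app.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api import api_router
from app.config import settings
from app.database import init_db

app = FastAPI(title=settings.app_name, version=settings.app_version)

# Allow the frontend origins configured in Settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create tables on startup and expose all endpoint routers under /api
init_db()
app.include_router(api_router, prefix=settings.api_prefix)

Started with, for example, uvicorn app.main:app --reload (uvicorn itself is not pinned by this commit).
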
teamleader_test/app/api/endpoints/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""API endpoints package."""
teamleader_test/app/api/endpoints/hosts.py (new file, 222 lines)
@@ -0,0 +1,222 @@
"""Host API endpoints."""

import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_

from app.database import get_db
from app.models import Host, Service
from app.schemas import HostResponse, HostDetailResponse, ServiceResponse, NetworkStatistics
from app.services.topology_service import TopologyService

logger = logging.getLogger(__name__)

router = APIRouter()


@router.get("", response_model=List[HostResponse])
def list_hosts(
    status: Optional[str] = Query(None, description="Filter by status (online/offline)"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: Optional[str] = Query(None, description="Search by IP or hostname"),
    db: Session = Depends(get_db)
):
    """
    List discovered hosts.

    Args:
        status: Filter by host status
        limit: Maximum number of hosts to return
        offset: Number of hosts to skip
        search: Search query
        db: Database session

    Returns:
        List of hosts
    """
    query = db.query(Host)

    # Apply filters
    if status:
        query = query.filter(Host.status == status)

    if search:
        search_pattern = f"%{search}%"
        query = query.filter(
            or_(
                Host.ip_address.like(search_pattern),
                Host.hostname.like(search_pattern)
            )
        )

    # Order by last seen
    query = query.order_by(Host.last_seen.desc())

    # Apply pagination
    hosts = query.limit(limit).offset(offset).all()

    return hosts


@router.get("/statistics", response_model=NetworkStatistics)
def get_network_statistics(db: Session = Depends(get_db)):
    """
    Get network statistics.

    Args:
        db: Database session

    Returns:
        Network statistics
    """
    topology_service = TopologyService(db)
    stats = topology_service.get_network_statistics()

    # Get most common services
    from sqlalchemy import func
    service_counts = db.query(
        Service.service_name,
        func.count(Service.id).label('count')
    ).filter(
        Service.service_name.isnot(None)
    ).group_by(
        Service.service_name
    ).order_by(
        func.count(Service.id).desc()
    ).limit(10).all()

    # Get last scan time
    from app.models import Scan
    last_scan = db.query(Scan).order_by(Scan.started_at.desc()).first()

    return NetworkStatistics(
        total_hosts=stats['total_hosts'],
        online_hosts=stats['online_hosts'],
        offline_hosts=stats['offline_hosts'],
        total_services=stats['total_services'],
        total_scans=db.query(func.count(Scan.id)).scalar() or 0,
        last_scan=last_scan.started_at if last_scan else None,
        most_common_services=[
            {'service_name': s[0], 'count': s[1]}
            for s in service_counts
        ]
    )


@router.get("/by-service/{service_name}", response_model=List[HostResponse])
def get_hosts_by_service(
    service_name: str,
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db)
):
    """
    Get all hosts that provide a specific service.

    Args:
        service_name: Service name to filter by
        limit: Maximum number of hosts to return
        offset: Number of hosts to skip
        db: Database session

    Returns:
        List of hosts providing the service
    """
    hosts = db.query(Host).join(Service).filter(
        Service.service_name == service_name
    ).distinct().order_by(
        Host.last_seen.desc()
    ).limit(limit).offset(offset).all()

    return hosts


@router.get("/{host_id}", response_model=HostDetailResponse)
def get_host_detail(host_id: int, db: Session = Depends(get_db)):
    """
    Get detailed information about a specific host.

    Args:
        host_id: Host ID
        db: Database session

    Returns:
        Detailed host information
    """
    host = db.query(Host).filter(Host.id == host_id).first()

    if not host:
        raise HTTPException(status_code=404, detail=f"Host {host_id} not found")

    return host


@router.get("/{host_id}/services", response_model=List[ServiceResponse])
def get_host_services(host_id: int, db: Session = Depends(get_db)):
    """
    Get all services for a specific host.

    Args:
        host_id: Host ID
        db: Database session

    Returns:
        List of services
    """
    host = db.query(Host).filter(Host.id == host_id).first()

    if not host:
        raise HTTPException(status_code=404, detail=f"Host {host_id} not found")

    return host.services


@router.delete("/{host_id}")
def delete_host(host_id: int, db: Session = Depends(get_db)):
    """
    Delete a host from the database.

    Args:
        host_id: Host ID
        db: Database session

    Returns:
        Success message
    """
    host = db.query(Host).filter(Host.id == host_id).first()

    if not host:
        raise HTTPException(status_code=404, detail=f"Host {host_id} not found")

    db.delete(host)
    db.commit()

    logger.info(f"Deleted host {host_id} ({host.ip_address})")

    return {"message": f"Host {host_id} deleted successfully"}


@router.get("/ip/{ip_address}", response_model=HostResponse)
def get_host_by_ip(ip_address: str, db: Session = Depends(get_db)):
    """
    Get host information by IP address.

    Args:
        ip_address: IP address
        db: Database session

    Returns:
        Host information
    """
    host = db.query(Host).filter(Host.ip_address == ip_address).first()

    if not host:
        raise HTTPException(
            status_code=404,
            detail=f"Host with IP {ip_address} not found"
        )

    return host
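
The endpoints above can be exercised with any HTTP client. A short sketch using httpx (an assumed dependency, not added by this commit) against a locally running instance with the router mounted under /api/hosts:

# Illustrative client calls against the host endpoints (httpx and the local
# base URL are assumptions).
import httpx

BASE = "http://localhost:8000/api"

with httpx.Client(base_url=BASE) as client:
    # Paginated, filtered listing (maps to list_hosts query parameters)
    online = client.get("/hosts", params={"status": "online", "limit": 10}).json()

    # Aggregated statistics for a dashboard
    stats = client.get("/hosts/statistics").json()
    print(stats["total_hosts"], stats["online_hosts"])

    # Lookup by IP returns 404 if the host is unknown
    resp = client.get("/hosts/ip/192.168.1.10")
    if resp.status_code == 200:
        print(resp.json()["hostname"])
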
teamleader_test/app/api/endpoints/scans.py (new file, 209 lines)
@@ -0,0 +1,209 @@
"""Scan API endpoints."""

import asyncio
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session

from app.database import get_db
from app.schemas import (
    ScanConfigRequest,
    ScanResponse,
    ScanStatusResponse,
    ScanStartResponse,
    ScanStatus as ScanStatusEnum
)
from app.services.scan_service import ScanService
from app.api.endpoints.websocket import (
    send_scan_progress,
    send_host_discovered,
    send_scan_completed,
    send_scan_failed
)

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post("/start", response_model=ScanStartResponse, status_code=202)
async def start_scan(
    config: ScanConfigRequest,
    db: Session = Depends(get_db)
):
    """
    Start a new network scan.

    Args:
        config: Scan configuration
        background_tasks: Background task handler
        db: Database session

    Returns:
        Scan start response with scan ID
    """
    try:
        scan_service = ScanService(db)

        # Create scan record
        scan = scan_service.create_scan(config)
        scan_id = scan.id

        # Create progress callback for WebSocket updates
        async def progress_callback(update: dict):
            """Send progress updates via WebSocket."""
            update_type = update.get('type')
            update_scan_id = update.get('scan_id', scan_id)

            if update_type == 'scan_progress':
                await send_scan_progress(update_scan_id, update.get('progress', 0), update.get('current_host'))
            elif update_type == 'host_discovered':
                await send_host_discovered(update_scan_id, update.get('host'))
            elif update_type == 'scan_completed':
                await send_scan_completed(update_scan_id, {'hosts_found': update.get('hosts_found', 0)})
            elif update_type == 'scan_failed':
                await send_scan_failed(update_scan_id, update.get('error', 'Unknown error'))

        # Create background task wrapper that uses a new database session
        async def run_scan_task():
            from app.database import SessionLocal
            scan_db = SessionLocal()
            try:
                scan_service_bg = ScanService(scan_db)
                await scan_service_bg.execute_scan(scan_id, config, progress_callback)
            finally:
                scan_db.close()

        # Create and store the task for this scan
        task = asyncio.create_task(run_scan_task())
        scan_service.active_scans[scan_id] = task

        logger.info(f"Started scan {scan_id} for {config.network_range}")

        return ScanStartResponse(
            scan_id=scan_id,
            message=f"Scan started for network {config.network_range}",
            status=ScanStatusEnum.PENDING
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error starting scan: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to start scan")


@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
def get_scan_status(scan_id: int, db: Session = Depends(get_db)):
    """
    Get the status of a specific scan.

    Args:
        scan_id: Scan ID
        db: Database session

    Returns:
        Scan status information
    """
    scan_service = ScanService(db)
    scan = scan_service.get_scan_status(scan_id)

    if not scan:
        raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")

    # Calculate progress
    progress = 0.0
    if scan.status == ScanStatusEnum.COMPLETED.value:
        progress = 1.0
    elif scan.status == ScanStatusEnum.RUNNING.value:
        # Estimate progress based on hosts found
        # This is a rough estimate; real-time progress comes from WebSocket
        if scan.hosts_found > 0:
            progress = 0.5  # Host discovery done

    return ScanStatusResponse(
        id=scan.id,
        started_at=scan.started_at,
        completed_at=scan.completed_at,
        scan_type=scan.scan_type,
        network_range=scan.network_range,
        status=ScanStatusEnum(scan.status),
        hosts_found=scan.hosts_found,
        ports_scanned=scan.ports_scanned,
        error_message=scan.error_message,
        progress=progress,
        current_host=None,
        estimated_completion=None
    )


@router.get("", response_model=List[ScanResponse])
def list_scans(
    limit: int = 50,
    offset: int = 0,
    db: Session = Depends(get_db)
):
    """
    List recent scans.

    Args:
        limit: Maximum number of scans to return
        offset: Number of scans to skip
        db: Database session

    Returns:
        List of scans
    """
    scan_service = ScanService(db)
    scans = scan_service.list_scans(limit=limit, offset=offset)

    return [
        ScanResponse(
            id=scan.id,
            started_at=scan.started_at,
            completed_at=scan.completed_at,
            scan_type=scan.scan_type,
            network_range=scan.network_range,
            status=ScanStatusEnum(scan.status),
            hosts_found=scan.hosts_found,
            ports_scanned=scan.ports_scanned,
            error_message=scan.error_message
        )
        for scan in scans
    ]


@router.delete("/{scan_id}/cancel")
def cancel_scan(scan_id: int, db: Session = Depends(get_db)):
    """
    Cancel a running scan.

    Args:
        scan_id: Scan ID
        db: Database session

    Returns:
        Success message
    """
    scan_service = ScanService(db)

    # Check if scan exists
    scan = scan_service.get_scan_status(scan_id)
    if not scan:
        raise HTTPException(status_code=404, detail=f"Scan {scan_id} not found")

    # Check if scan is running
    if scan.status not in [ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value]:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot cancel scan in status: {scan.status}"
        )

    # Attempt to cancel
    success = scan_service.cancel_scan(scan_id)

    if success:
        return {"message": f"Scan {scan_id} cancelled successfully"}
    else:
        raise HTTPException(status_code=500, detail="Failed to cancel scan")
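
A sketch of the intended start-then-poll flow, assuming the API runs locally under /api, that ScanConfigRequest accepts network_range and scan_type fields, and that the terminal status strings are completed, failed, and cancelled (the schema file is truncated in this section, so those names are assumptions; httpx is likewise assumed):

# Illustrative start/poll flow for the scan endpoints.
import time
import httpx

BASE = "http://localhost:8000/api"

with httpx.Client(base_url=BASE) as client:
    started = client.post("/scans/start", json={
        "network_range": "192.168.1.0/24",  # assumed field name
        "scan_type": "quick",               # assumed field name
    }).json()
    scan_id = started["scan_id"]

    # Coarse polling; fine-grained progress is pushed over the /ws WebSocket
    while True:
        status = client.get(f"/scans/{scan_id}/status").json()
        print(status["status"], status["progress"])
        if status["status"] in ("completed", "failed", "cancelled"):
            break
        time.sleep(2)
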
teamleader_test/app/api/endpoints/topology.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Topology API endpoints."""

import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from app.database import get_db
from app.schemas import TopologyResponse
from app.services.topology_service import TopologyService

logger = logging.getLogger(__name__)

router = APIRouter()


@router.get("", response_model=TopologyResponse)
def get_network_topology(
    include_offline: bool = Query(False, description="Include offline hosts"),
    db: Session = Depends(get_db)
):
    """
    Get network topology graph data.

    Args:
        include_offline: Whether to include offline hosts
        db: Database session

    Returns:
        Topology data with nodes and edges
    """
    try:
        topology_service = TopologyService(db)
        topology = topology_service.generate_topology(include_offline=include_offline)

        logger.info(f"Generated topology with {len(topology.nodes)} nodes")

        return topology

    except Exception as e:
        logger.error(f"Error generating topology: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to generate topology")


@router.get("/neighbors/{host_id}")
def get_host_neighbors(host_id: int, db: Session = Depends(get_db)):
    """
    Get neighboring hosts for a specific host.

    Args:
        host_id: Host ID
        db: Database session

    Returns:
        List of neighboring hosts
    """
    topology_service = TopologyService(db)
    neighbors = topology_service.get_host_neighbors(host_id)

    return {
        'host_id': host_id,
        'neighbors': [
            {
                'id': h.id,
                'ip_address': h.ip_address,
                'hostname': h.hostname,
                'status': h.status
            }
            for h in neighbors
        ]
    }
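
A small sketch of consuming the topology endpoint, assuming a local instance under /api and that TopologyResponse serializes to nodes and edges keys as the docstring suggests (httpx assumed):

# Illustrative fetch of the topology graph, e.g. for a frontend renderer.
import httpx

topology = httpx.get(
    "http://localhost:8000/api/topology",
    params={"include_offline": "true"},
).json()

# Nodes and edges as produced by TopologyService.generate_topology()
for node in topology["nodes"]:
    print(node)
print(f"{len(topology['edges'])} edges")
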
teamleader_test/app/api/endpoints/websocket.py (new file, 222 lines)
@@ -0,0 +1,222 @@
"""WebSocket endpoint for real-time updates."""

import asyncio
import json
import logging
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from datetime import datetime

logger = logging.getLogger(__name__)

router = APIRouter()


class ConnectionManager:
    """Manager for WebSocket connections."""

    def __init__(self):
        """Initialize connection manager."""
        self.active_connections: Set[WebSocket] = set()

    async def connect(self, websocket: WebSocket):
        """
        Accept and register a new WebSocket connection.

        Args:
            websocket: WebSocket connection
        """
        await websocket.accept()
        self.active_connections.add(websocket)
        logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """
        Remove a WebSocket connection.

        Args:
            websocket: WebSocket connection
        """
        self.active_connections.discard(websocket)
        logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def send_personal_message(self, message: dict, websocket: WebSocket):
        """
        Send a message to a specific WebSocket.

        Args:
            message: Message to send
            websocket: WebSocket connection
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Error sending message: {e}")
            self.disconnect(websocket)

    async def broadcast(self, message: dict):
        """
        Broadcast a message to all connected WebSockets.

        Args:
            message: Message to broadcast
        """
        disconnected = set()

        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error(f"Error broadcasting to connection: {e}")
                disconnected.add(connection)

        # Clean up disconnected clients
        for connection in disconnected:
            self.disconnect(connection)


# Global connection manager instance
manager = ConnectionManager()


@router.websocket("")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time scan updates.

    Args:
        websocket: WebSocket connection
    """
    await manager.connect(websocket)

    try:
        # Send welcome message
        await manager.send_personal_message({
            'type': 'connected',
            'message': 'Connected to network scanner',
            'timestamp': datetime.utcnow().isoformat()
        }, websocket)

        # Keep connection alive and handle incoming messages
        while True:
            try:
                # Receive messages from client
                data = await websocket.receive_text()

                # Parse and handle client messages
                try:
                    message = json.loads(data)
                    await handle_client_message(message, websocket)
                except json.JSONDecodeError:
                    await manager.send_personal_message({
                        'type': 'error',
                        'message': 'Invalid JSON format',
                        'timestamp': datetime.utcnow().isoformat()
                    }, websocket)

            except WebSocketDisconnect:
                break
            except Exception as e:
                logger.error(f"Error in WebSocket loop: {e}")
                break

    finally:
        manager.disconnect(websocket)


async def handle_client_message(message: dict, websocket: WebSocket):
    """
    Handle messages from client.

    Args:
        message: Client message
        websocket: WebSocket connection
    """
    message_type = message.get('type')

    if message_type == 'ping':
        # Respond to ping
        await manager.send_personal_message({
            'type': 'pong',
            'timestamp': datetime.utcnow().isoformat()
        }, websocket)

    elif message_type == 'subscribe':
        # Handle subscription requests
        scan_id = message.get('scan_id')
        if scan_id:
            await manager.send_personal_message({
                'type': 'subscribed',
                'scan_id': scan_id,
                'timestamp': datetime.utcnow().isoformat()
            }, websocket)

    else:
        logger.warning(f"Unknown message type: {message_type}")


async def broadcast_scan_update(scan_id: int, update_type: str, data: dict):
    """
    Broadcast scan update to all connected clients.

    Args:
        scan_id: Scan ID
        update_type: Type of update
        data: Update data
    """
    message = {
        'type': update_type,
        'scan_id': scan_id,
        'data': data,
        'timestamp': datetime.utcnow().isoformat()
    }

    await manager.broadcast(message)


async def send_scan_progress(scan_id: int, progress: float, current_host: str = None):
    """
    Send scan progress update.

    Args:
        scan_id: Scan ID
        progress: Progress value (0.0 to 1.0)
        current_host: Currently scanning host
    """
    await broadcast_scan_update(scan_id, 'scan_progress', {
        'progress': progress,
        'current_host': current_host
    })


async def send_host_discovered(scan_id: int, host_data: dict):
    """
    Send host discovered notification.

    Args:
        scan_id: Scan ID
        host_data: Host information
    """
    await broadcast_scan_update(scan_id, 'host_discovered', host_data)


async def send_scan_completed(scan_id: int, summary: dict):
    """
    Send scan completed notification.

    Args:
        scan_id: Scan ID
        summary: Scan summary
    """
    await broadcast_scan_update(scan_id, 'scan_completed', summary)


async def send_scan_failed(scan_id: int, error: str):
    """
    Send scan failed notification.

    Args:
        scan_id: Scan ID
        error: Error message
    """
    await broadcast_scan_update(scan_id, 'scan_failed', {'error': error})
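
A sketch of a client for this endpoint using the third-party websockets package (an assumption, it is not part of this commit), with the route assumed to be exposed at /api/ws as wired in app/api/__init__.py:

# Illustrative WebSocket client printing scan progress events.
import asyncio
import json
import websockets

async def listen() -> None:
    async with websockets.connect("ws://localhost:8000/api/ws") as ws:
        # The server confirms 'subscribe' requests and answers 'ping' with 'pong'
        await ws.send(json.dumps({"type": "subscribe", "scan_id": 1}))

        async for raw in ws:
            event = json.loads(raw)
            if event["type"] == "scan_progress":
                print(f"scan {event['scan_id']}: {event['data']['progress']:.0%}")
            elif event["type"] == "scan_completed":
                break

asyncio.run(listen())
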
teamleader_test/app/config.py (new file, 43 lines)
@@ -0,0 +1,43 @@
"""Configuration management for the network scanner application."""

from typing import List
from pydantic_settings import BaseSettings
from pydantic import Field


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Application
    app_name: str = Field(default="Network Scanner", alias="APP_NAME")
    app_version: str = Field(default="1.0.0", alias="APP_VERSION")
    debug: bool = Field(default=False, alias="DEBUG")

    # Database
    database_url: str = Field(default="sqlite:///./network_scanner.db", alias="DATABASE_URL")

    # Scanning
    default_scan_timeout: int = Field(default=3, alias="DEFAULT_SCAN_TIMEOUT")
    max_concurrent_scans: int = Field(default=50, alias="MAX_CONCURRENT_SCANS")
    enable_nmap: bool = Field(default=True, alias="ENABLE_NMAP")

    # Network
    default_network_range: str = Field(default="192.168.1.0/24", alias="DEFAULT_NETWORK_RANGE")
    scan_private_networks_only: bool = Field(default=True, alias="SCAN_PRIVATE_NETWORKS_ONLY")

    # API
    api_prefix: str = Field(default="/api", alias="API_PREFIX")
    cors_origins: List[str] = Field(default=["http://localhost:3000"], alias="CORS_ORIGINS")

    # Logging
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")
    log_file: str = Field(default="logs/network_scanner.log", alias="LOG_FILE")

    class Config:
        """Pydantic configuration."""
        env_file = ".env"
        case_sensitive = False


# Global settings instance
settings = Settings()
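
A quick sketch of how these settings resolve, assuming pydantic-settings matches the upper-case aliases against environment variables, with .env as a fallback and case-insensitive lookup as configured:

# Illustrative override of settings via environment variables.
import os

os.environ["DEFAULT_NETWORK_RANGE"] = "10.0.0.0/24"
os.environ["DEBUG"] = "true"

from app.config import Settings

s = Settings()
print(s.default_network_range)  # 10.0.0.0/24
print(s.debug)                  # True
print(s.database_url)           # falls back to the sqlite default
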
teamleader_test/app/database.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Database configuration and session management."""

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator

from app.config import settings


# Create database engine
engine = create_engine(
    settings.database_url,
    connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {},
    echo=settings.debug
)

# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base class for models
Base = declarative_base()


def get_db() -> Generator[Session, None, None]:
    """
    Dependency function to get database session.

    Yields:
        Session: SQLAlchemy database session
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def init_db() -> None:
    """Initialize database tables."""
    Base.metadata.create_all(bind=engine)
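
A minimal sketch of using the session factory outside FastAPI's dependency injection, for example in a one-off maintenance script:

# Standalone use of SessionLocal and init_db().
from app.database import SessionLocal, init_db
from app.models import Host

init_db()  # creates tables if they do not exist yet

db = SessionLocal()
try:
    total = db.query(Host).count()
    print(f"{total} hosts in the database")
finally:
    db.close()
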
teamleader_test/app/models.py (new file, 122 lines)
@@ -0,0 +1,122 @@
"""SQLAlchemy database models."""

from sqlalchemy import Column, Integer, String, DateTime, Float, Text, ForeignKey, Table, JSON
from sqlalchemy.orm import relationship
from datetime import datetime

from app.database import Base


# Association table for many-to-many relationship between scans and hosts
scan_hosts = Table(
    'scan_hosts',
    Base.metadata,
    Column('scan_id', Integer, ForeignKey('scans.id', ondelete='CASCADE'), primary_key=True),
    Column('host_id', Integer, ForeignKey('hosts.id', ondelete='CASCADE'), primary_key=True)
)


class Scan(Base):
    """Model for scan operations."""

    __tablename__ = 'scans'

    id = Column(Integer, primary_key=True, index=True)
    started_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)
    scan_type = Column(String(50), nullable=False, default='quick')
    network_range = Column(String(100), nullable=False)
    status = Column(String(20), nullable=False, default='pending')
    hosts_found = Column(Integer, default=0)
    ports_scanned = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)

    # Relationships
    hosts = relationship('Host', secondary=scan_hosts, back_populates='scans')

    def __repr__(self) -> str:
        return f"<Scan(id={self.id}, network={self.network_range}, status={self.status})>"


class Host(Base):
    """Model for discovered network hosts."""

    __tablename__ = 'hosts'

    id = Column(Integer, primary_key=True, index=True)
    ip_address = Column(String(45), nullable=False, unique=True, index=True)
    hostname = Column(String(255), nullable=True)
    mac_address = Column(String(17), nullable=True)
    first_seen = Column(DateTime, nullable=False, default=datetime.utcnow)
    last_seen = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow)
    status = Column(String(20), nullable=False, default='online', index=True)
    os_guess = Column(String(255), nullable=True)
    device_type = Column(String(50), nullable=True)
    vendor = Column(String(255), nullable=True)
    notes = Column(Text, nullable=True)

    # Relationships
    services = relationship('Service', back_populates='host', cascade='all, delete-orphan')
    scans = relationship('Scan', secondary=scan_hosts, back_populates='hosts')
    outgoing_connections = relationship(
        'Connection',
        foreign_keys='Connection.source_host_id',
        back_populates='source_host',
        cascade='all, delete-orphan'
    )
    incoming_connections = relationship(
        'Connection',
        foreign_keys='Connection.target_host_id',
        back_populates='target_host',
        cascade='all, delete-orphan'
    )

    def __repr__(self) -> str:
        return f"<Host(id={self.id}, ip={self.ip_address}, hostname={self.hostname})>"


class Service(Base):
    """Model for services running on hosts (open ports)."""

    __tablename__ = 'services'

    id = Column(Integer, primary_key=True, index=True)
    host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False)
    port = Column(Integer, nullable=False)
    protocol = Column(String(10), nullable=False, default='tcp')
    state = Column(String(20), nullable=False, default='open')
    service_name = Column(String(100), nullable=True)
    service_version = Column(String(255), nullable=True)
    banner = Column(Text, nullable=True)
    first_seen = Column(DateTime, nullable=False, default=datetime.utcnow)
    last_seen = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    host = relationship('Host', back_populates='services')

    def __repr__(self) -> str:
        return f"<Service(host_id={self.host_id}, port={self.port}, service={self.service_name})>"


class Connection(Base):
    """Model for detected connections between hosts."""

    __tablename__ = 'connections'

    id = Column(Integer, primary_key=True, index=True)
    source_host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False, index=True)
    target_host_id = Column(Integer, ForeignKey('hosts.id', ondelete='CASCADE'), nullable=False, index=True)
    connection_type = Column(String(50), nullable=False)
    protocol = Column(String(10), nullable=True)
    port = Column(Integer, nullable=True)
    confidence = Column(Float, nullable=False, default=1.0)
    detected_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    last_verified = Column(DateTime, nullable=True)
    extra_data = Column(JSON, nullable=True)

    # Relationships
    source_host = relationship('Host', foreign_keys=[source_host_id], back_populates='outgoing_connections')
    target_host = relationship('Host', foreign_keys=[target_host_id], back_populates='incoming_connections')

    def __repr__(self) -> str:
        return f"<Connection(source={self.source_host_id}, target={self.target_host_id}, type={self.connection_type})>"
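
A sketch of how the models relate in practice: one Host owning a Service and a Connection edge to another Host, persisted through a plain session (assumes the tables exist, e.g. after init_db()):

# Illustrative creation of related Host, Service, and Connection rows.
from app.database import SessionLocal, init_db
from app.models import Host, Service, Connection

init_db()
db = SessionLocal()

gateway = Host(ip_address="192.168.1.1", hostname="gateway.local")
server = Host(ip_address="192.168.1.10", hostname="files.local")
server.services.append(Service(port=445, service_name="smb"))

db.add_all([gateway, server])
db.flush()  # assigns primary keys so the Connection can reference them

db.add(Connection(
    source_host_id=server.id,
    target_host_id=gateway.id,
    connection_type="gateway",
    confidence=0.9,
))
db.commit()
db.close()
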
teamleader_test/app/scanner/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""Network scanner module."""

from app.scanner.network_scanner import NetworkScanner
from app.scanner.port_scanner import PortScanner
from app.scanner.service_detector import ServiceDetector

__all__ = ['NetworkScanner', 'PortScanner', 'ServiceDetector']
teamleader_test/app/scanner/network_scanner.py (new file, 242 lines)
@@ -0,0 +1,242 @@
"""Network scanner implementation for host discovery."""

import socket
import ipaddress
import asyncio
from typing import List, Set, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
import logging

from app.config import settings

logger = logging.getLogger(__name__)


class NetworkScanner:
    """Scanner for discovering active hosts on a network."""

    # Common ports for host discovery
    DISCOVERY_PORTS = [21, 22, 23, 25, 80, 443, 445, 3389, 8080, 8443]

    def __init__(
        self,
        timeout: int = None,
        max_workers: int = None,
        progress_callback: Optional[Callable[[str, float], None]] = None
    ):
        """
        Initialize network scanner.

        Args:
            timeout: Socket connection timeout in seconds
            max_workers: Maximum number of concurrent workers
            progress_callback: Optional callback for progress updates
        """
        self.timeout = timeout or settings.default_scan_timeout
        self.max_workers = max_workers or settings.max_concurrent_scans
        self.progress_callback = progress_callback

    async def scan_network(self, network_range: str) -> List[str]:
        """
        Scan a network range for active hosts.

        Args:
            network_range: Network in CIDR notation (e.g., '192.168.1.0/24')

        Returns:
            List of active IP addresses
        """
        logger.info(f"Starting network scan of {network_range}")

        try:
            network = ipaddress.ip_network(network_range, strict=False)

            # Validate private network if restriction enabled
            if settings.scan_private_networks_only and not network.is_private:
                raise ValueError(f"Network {network_range} is not a private network")

            # Generate list of hosts to scan
            hosts = [str(ip) for ip in network.hosts()]
            total_hosts = len(hosts)

            if total_hosts == 0:
                # Single host network
                hosts = [str(network.network_address)]
                total_hosts = 1

            logger.info(f"Scanning {total_hosts} hosts in {network_range}")

            # Scan hosts concurrently
            active_hosts = await self._scan_hosts_async(hosts)

            logger.info(f"Scan completed. Found {len(active_hosts)} active hosts")
            return active_hosts

        except ValueError as e:
            logger.error(f"Invalid network range: {e}")
            raise
        except Exception as e:
            logger.error(f"Error during network scan: {e}")
            raise

    async def _scan_hosts_async(self, hosts: List[str]) -> List[str]:
        """
        Scan multiple hosts asynchronously.

        Args:
            hosts: List of IP addresses to scan

        Returns:
            List of active hosts
        """
        active_hosts: Set[str] = set()
        total = len(hosts)
        completed = 0

        # Use ThreadPoolExecutor for socket operations
        loop = asyncio.get_event_loop()

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []

            for host in hosts:
                future = loop.run_in_executor(executor, self._check_host, host)
                futures.append((host, future))

            # Process results as they complete
            for host, future in futures:
                try:
                    is_active = await future
                    if is_active:
                        active_hosts.add(host)
                        logger.debug(f"Host {host} is active")
                except Exception as e:
                    logger.debug(f"Error checking host {host}: {e}")
                finally:
                    completed += 1
                    if self.progress_callback:
                        progress = completed / total
                        self.progress_callback(host, progress)

        return sorted(list(active_hosts), key=lambda ip: ipaddress.ip_address(ip))

    def _check_host(self, ip: str) -> bool:
        """
        Check if a host is active by attempting TCP connections.

        Args:
            ip: IP address to check

        Returns:
            True if host responds on any discovery port
        """
        for port in self.DISCOVERY_PORTS:
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(self.timeout)
                result = sock.connect_ex((ip, port))
                sock.close()

                if result == 0:
                    return True
            except socket.error:
                continue
            except Exception as e:
                logger.debug(f"Error checking {ip}:{port}: {e}")
                continue

        return False

    def get_local_network_range(self) -> Optional[str]:
        """
        Detect local network range.

        Returns:
            Network range in CIDR notation or None
        """
        try:
            import netifaces

            # Get default gateway interface
            gateways = netifaces.gateways()
            if 'default' not in gateways or netifaces.AF_INET not in gateways['default']:
                return None

            default_interface = gateways['default'][netifaces.AF_INET][1]

            # Get interface addresses
            addrs = netifaces.ifaddresses(default_interface)
            if netifaces.AF_INET not in addrs:
                return None

            # Get IP and netmask
            inet_info = addrs[netifaces.AF_INET][0]
            ip = inet_info.get('addr')
            netmask = inet_info.get('netmask')

            if not ip or not netmask:
                return None

            # Calculate network address
            network = ipaddress.ip_network(f"{ip}/{netmask}", strict=False)
            return str(network)

        except ImportError:
            logger.warning("netifaces not available, cannot detect local network")
            return None
        except Exception as e:
            logger.error(f"Error detecting local network: {e}")
            return None

    def resolve_hostname(self, ip: str) -> Optional[str]:
        """
        Resolve IP address to hostname.

        Args:
            ip: IP address

        Returns:
            Hostname or None
        """
        try:
            hostname = socket.gethostbyaddr(ip)[0]
            return hostname
        except socket.herror:
            return None
        except Exception as e:
            logger.debug(f"Error resolving {ip}: {e}")
            return None

    def get_mac_address(self, ip: str) -> Optional[str]:
        """
        Get MAC address for an IP (requires ARP access).

        Args:
            ip: IP address

        Returns:
            MAC address or None
        """
        try:
            # Try to get MAC from ARP cache
            import subprocess
            import re

            # Platform-specific ARP command
            import platform
            if platform.system() == 'Windows':
                arp_output = subprocess.check_output(['arp', '-a', ip]).decode()
                mac_pattern = r'([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})'
            else:
                arp_output = subprocess.check_output(['arp', '-n', ip]).decode()
                mac_pattern = r'([0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}'

            match = re.search(mac_pattern, arp_output)
            if match:
                return match.group(0).upper()

            return None

        except Exception as e:
            logger.debug(f"Error getting MAC for {ip}: {e}")
            return None
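
A sketch of driving NetworkScanner directly from a script, with a simple progress printout; note that scan_network() enforces the private-network restriction from Settings:

# Illustrative standalone host discovery run.
import asyncio
from app.scanner import NetworkScanner

def on_progress(host: str, progress: float) -> None:
    print(f"\rchecked {host:<15} {progress:.0%}", end="")

async def main() -> None:
    scanner = NetworkScanner(timeout=1, progress_callback=on_progress)
    active = await scanner.scan_network("192.168.1.0/24")
    print()
    for ip in active:
        name = scanner.resolve_hostname(ip) or "-"
        print(f"{ip:<15} {name}")

asyncio.run(main())
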
teamleader_test/app/scanner/nmap_scanner.py (new file, 260 lines)
@@ -0,0 +1,260 @@
"""Nmap integration for advanced scanning capabilities."""

import logging
from typing import Optional, Dict, Any, List
import asyncio

logger = logging.getLogger(__name__)


class NmapScanner:
    """Wrapper for python-nmap with safe execution."""

    def __init__(self):
        """Initialize nmap scanner."""
        self.nmap_available = self._check_nmap_available()
        if not self.nmap_available:
            logger.warning("nmap is not available on this system")

    def _check_nmap_available(self) -> bool:
        """
        Check if nmap is available on the system.

        Returns:
            True if nmap is available
        """
        try:
            import nmap
            nm = nmap.PortScanner()
            nm.nmap_version()
            return True
        except Exception as e:
            logger.debug(f"nmap not available: {e}")
            return False

    async def scan_host(
        self,
        host: str,
        arguments: str = '-sT -T4'
    ) -> Optional[Dict[str, Any]]:
        """
        Scan a host using nmap.

        Args:
            host: IP address or hostname
            arguments: Nmap arguments (default: TCP connect scan, aggressive timing)

        Returns:
            Scan results dictionary or None
        """
        if not self.nmap_available:
            logger.warning("Attempted to use nmap but it's not available")
            return None

        try:
            import nmap

            # Run nmap scan in thread pool
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None,
                self._run_nmap_scan,
                host,
                arguments
            )

            return result

        except Exception as e:
            logger.error(f"Error running nmap scan on {host}: {e}")
            return None

    def _run_nmap_scan(self, host: str, arguments: str) -> Optional[Dict[str, Any]]:
        """
        Run nmap scan synchronously.

        Args:
            host: Host to scan
            arguments: Nmap arguments

        Returns:
            Scan results
        """
        try:
            import nmap

            nm = nmap.PortScanner()

            # Sanitize host input
            if not self._validate_host(host):
                logger.error(f"Invalid host: {host}")
                return None

            # Execute scan
            logger.info(f"Running nmap scan: nmap {arguments} {host}")
            nm.scan(hosts=host, arguments=arguments)

            # Parse results
            if host not in nm.all_hosts():
                logger.debug(f"No results for {host}")
                return None

            host_info = nm[host]

            # Extract relevant information
            result = {
                'hostname': host_info.hostname(),
                'state': host_info.state(),
                'protocols': list(host_info.all_protocols()),
                'ports': []
            }

            # Extract port information
            for proto in host_info.all_protocols():
                ports = host_info[proto].keys()
                for port in ports:
                    port_info = host_info[proto][port]
                    result['ports'].append({
                        'port': port,
                        'protocol': proto,
                        'state': port_info['state'],
                        'service_name': port_info.get('name'),
                        'service_version': port_info.get('version'),
                        'service_product': port_info.get('product'),
                        'extrainfo': port_info.get('extrainfo')
                    })

            # OS detection if available
            if 'osmatch' in host_info:
                result['os_matches'] = [
                    {
                        'name': os['name'],
                        'accuracy': os['accuracy']
                    }
                    for os in host_info['osmatch']
                ]

            return result

        except Exception as e:
            logger.error(f"Error in _run_nmap_scan for {host}: {e}")
            return None

    def _validate_host(self, host: str) -> bool:
        """
        Validate host input to prevent command injection.

        Args:
            host: Host string to validate

        Returns:
            True if valid
        """
        import ipaddress
        import re

        # Try as IP address
        try:
            ipaddress.ip_address(host)
            return True
        except ValueError:
            pass

        # Try as network range
        try:
            ipaddress.ip_network(host, strict=False)
            return True
        except ValueError:
            pass

        # Try as hostname (alphanumeric, dots, hyphens only)
        if re.match(r'^[a-zA-Z0-9.-]+$', host):
            return True

        return False

    def get_scan_arguments(
        self,
        scan_type: str,
        service_detection: bool = True,
        os_detection: bool = False,
        port_range: Optional[str] = None
    ) -> str:
        """
        Generate nmap arguments based on scan configuration.

        Args:
            scan_type: Type of scan ('quick', 'standard', 'deep')
            service_detection: Enable service/version detection
            os_detection: Enable OS detection (requires root)
            port_range: Custom port range (e.g., '1-1000' or '80,443,8080')

        Returns:
            Nmap argument string
        """
        args = []

        # Use TCP connect scan (no root required)
        args.append('-sT')

        # Port specification
        if port_range:
            args.append(f'-p {port_range}')
        elif scan_type == 'quick':
            args.append('--top-ports 100')
        elif scan_type == 'standard':
            args.append('--top-ports 1000')
        elif scan_type == 'deep':
            args.append('-p-')  # All ports

        # Only show open ports
        args.append('--open')

        # Timing
        if scan_type == 'quick':
            args.append('-T5')  # Insane
        elif scan_type == 'deep':
            args.append('-T3')  # Normal
        else:
            args.append('-T4')  # Aggressive

        # Service detection
        if service_detection:
            args.append('-sV')

        # OS detection (requires root)
        if os_detection:
            args.append('-O')
            logger.warning("OS detection requires root privileges")

        return ' '.join(args)

    async def scan_network_with_nmap(
        self,
        network: str,
        scan_type: str = 'quick'
    ) -> List[Dict[str, Any]]:
        """
        Scan entire network using nmap.

        Args:
            network: Network in CIDR notation
            scan_type: Type of scan

        Returns:
            List of host results
        """
        if not self.nmap_available:
            return []

        try:
            arguments = self.get_scan_arguments(scan_type)
            result = await self.scan_host(network, arguments)

            if result:
                return [result]
            return []

        except Exception as e:
            logger.error(f"Error scanning network {network}: {e}")
            return []
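
A sketch of a single-host scan through this wrapper; it assumes both the nmap binary and the python-nmap package are installed, neither of which is provided by this commit:

# Illustrative nmap-backed scan of one host.
import asyncio
from app.scanner.nmap_scanner import NmapScanner

async def main() -> None:
    scanner = NmapScanner()
    if not scanner.nmap_available:
        print("nmap not available, falling back to the socket-based scanners")
        return

    args = scanner.get_scan_arguments("standard", service_detection=True)
    result = await scanner.scan_host("192.168.1.10", args)
    if result:
        for port in result["ports"]:
            print(port["port"], port["state"], port.get("service_name"))

asyncio.run(main())
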
teamleader_test/app/scanner/port_scanner.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""Port scanner implementation."""

import socket
import asyncio
from typing import List, Dict, Set, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
import logging

from app.config import settings

logger = logging.getLogger(__name__)


class PortScanner:
    """Scanner for detecting open ports on hosts."""

    # Predefined port ranges for different scan types
    PORT_RANGES = {
        'quick': [21, 22, 23, 25, 53, 80, 110, 143, 443, 445, 3306, 3389, 5432, 8080, 8443],
        'standard': list(range(1, 1001)),
        'deep': list(range(1, 65536)),
    }

    def __init__(
        self,
        timeout: int = None,
        max_workers: int = None,
        progress_callback: Optional[Callable[[str, int, float], None]] = None
    ):
        """
        Initialize port scanner.

        Args:
            timeout: Socket connection timeout in seconds
            max_workers: Maximum number of concurrent workers
            progress_callback: Optional callback for progress updates (host, port, progress)
        """
        self.timeout = timeout or settings.default_scan_timeout
        self.max_workers = max_workers or settings.max_concurrent_scans
        self.progress_callback = progress_callback

    async def scan_host_ports(
        self,
        host: str,
        scan_type: str = 'quick',
        custom_ports: Optional[List[int]] = None
    ) -> List[Dict[str, any]]:
        """
        Scan ports on a single host.

        Args:
            host: IP address or hostname
            scan_type: Type of scan ('quick', 'standard', 'deep', or 'custom')
            custom_ports: Custom port list (required if scan_type is 'custom')

        Returns:
            List of dictionaries with port information
        """
        logger.info(f"Starting port scan on {host} (type: {scan_type})")

        # Determine ports to scan
        if scan_type == 'custom' and custom_ports:
            ports = custom_ports
        elif scan_type in self.PORT_RANGES:
            ports = self.PORT_RANGES[scan_type]
        else:
            ports = self.PORT_RANGES['quick']

        # Scan ports
        open_ports = await self._scan_ports_async(host, ports)

        logger.info(f"Scan completed on {host}. Found {len(open_ports)} open ports")
        return open_ports

    async def _scan_ports_async(self, host: str, ports: List[int]) -> List[Dict[str, any]]:
        """
        Scan multiple ports asynchronously.

        Args:
            host: Host to scan
            ports: List of ports to scan

        Returns:
            List of open port information
        """
        open_ports = []
        total = len(ports)
        completed = 0

        loop = asyncio.get_event_loop()

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []

            for port in ports:
                future = loop.run_in_executor(executor, self._check_port, host, port)
                futures.append((port, future))

            # Process results
            for port, future in futures:
                try:
                    result = await future
                    if result:
                        open_ports.append(result)
                        logger.debug(f"Found open port {port} on {host}")
                except Exception as e:
                    logger.debug(f"Error checking port {port} on {host}: {e}")
                finally:
                    completed += 1
                    if self.progress_callback:
                        progress = completed / total
                        self.progress_callback(host, port, progress)

        return sorted(open_ports, key=lambda x: x['port'])

    def _check_port(self, host: str, port: int) -> Optional[Dict[str, any]]:
        """
        Check if a port is open on a host.

        Args:
            host: Host to check
            port: Port number

        Returns:
            Dictionary with port info if open, None otherwise
        """
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(self.timeout)
            result = sock.connect_ex((host, port))
            sock.close()

            if result == 0:
                return {
                    'port': port,
                    'protocol': 'tcp',
                    'state': 'open',
                    'service_name': self._guess_service_name(port)
                }

            return None

        except socket.error as e:
            logger.debug(f"Socket error checking {host}:{port}: {e}")
            return None
        except Exception as e:
            logger.debug(f"Error checking {host}:{port}: {e}")
            return None

    def _guess_service_name(self, port: int) -> Optional[str]:
        """
        Guess service name based on well-known ports.

        Args:
            port: Port number

        Returns:
            Service name or None
        """
        common_services = {
            20: 'ftp-data',
            21: 'ftp',
            22: 'ssh',
            23: 'telnet',
            25: 'smtp',
            53: 'dns',
            80: 'http',
            110: 'pop3',
            143: 'imap',
            443: 'https',
            445: 'smb',
            3306: 'mysql',
            3389: 'rdp',
            5432: 'postgresql',
            5900: 'vnc',
            8080: 'http-alt',
            8443: 'https-alt',
        }

        return common_services.get(port)

    def parse_port_range(self, port_range: str) -> List[int]:
        """
        Parse port range string to list of ports.

        Args:
            port_range: String like "80,443,8000-8100"

        Returns:
            List of port numbers
        """
        ports = set()

        try:
            for part in port_range.split(','):
                part = part.strip()

                if '-' in part:
                    # Range like "8000-8100"
                    start, end = map(int, part.split('-'))
                    if 1 <= start <= end <= 65535:
                        ports.update(range(start, end + 1))
                else:
                    # Single port
                    port = int(part)
                    if 1 <= port <= 65535:
                        ports.add(port)

            return sorted(list(ports))

        except ValueError as e:
            logger.error(f"Error parsing port range '{port_range}': {e}")
            return []
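
A sketch combining parse_port_range() with a custom scan_host_ports() run; the progress callback signature here is (host, port, progress):

# Illustrative custom port scan of a single host.
import asyncio
from app.scanner import PortScanner

async def main() -> None:
    scanner = PortScanner(timeout=1)
    ports = scanner.parse_port_range("22,80,443,8000-8010")

    open_ports = await scanner.scan_host_ports(
        "192.168.1.10",
        scan_type="custom",
        custom_ports=ports,
    )
    for entry in open_ports:
        print(entry["port"], entry.get("service_name") or "unknown")

asyncio.run(main())
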
250
teamleader_test/app/scanner/service_detector.py
Normal file
250
teamleader_test/app/scanner/service_detector.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Service detection and banner grabbing implementation."""
|
||||
|
||||
import socket
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceDetector:
|
||||
"""Detector for identifying services running on open ports."""
|
||||
|
||||
def __init__(self, timeout: int = 3):
|
||||
"""
|
||||
Initialize service detector.
|
||||
|
||||
Args:
|
||||
timeout: Socket timeout in seconds
|
||||
"""
|
||||
self.timeout = timeout
|
||||
|
||||
def detect_service(self, host: str, port: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect service on a specific port.
|
||||
|
||||
Args:
|
||||
host: Host IP or hostname
|
||||
port: Port number
|
||||
|
||||
Returns:
|
||||
Dictionary with service information
|
||||
"""
|
||||
service_info = {
|
||||
'port': port,
|
||||
'protocol': 'tcp',
|
||||
'service_name': None,
|
||||
'service_version': None,
|
||||
'banner': None
|
||||
}
|
||||
|
||||
# Try banner grabbing
|
||||
banner = self.grab_banner(host, port)
|
||||
if banner:
|
||||
service_info['banner'] = banner
|
||||
|
||||
# Try to identify service from banner
|
||||
service_name, version = self._identify_from_banner(banner, port)
|
||||
if service_name:
|
||||
service_info['service_name'] = service_name
|
||||
if version:
|
||||
service_info['service_version'] = version
|
||||
|
||||
# If no banner, use port-based guess
|
||||
if not service_info['service_name']:
|
||||
service_info['service_name'] = self._guess_service_from_port(port)
|
||||
|
||||
return service_info
|
||||
|
||||
def grab_banner(self, host: str, port: int) -> Optional[str]:
|
||||
"""
|
||||
Attempt to grab service banner.
|
||||
|
||||
Args:
|
||||
host: Host IP or hostname
|
||||
port: Port number
|
||||
|
||||
Returns:
|
||||
Banner string or None
|
||||
"""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect((host, port))
|
||||
|
||||
# Try to receive banner
|
||||
try:
|
||||
banner = sock.recv(1024)
|
||||
banner_str = banner.decode('utf-8', errors='ignore').strip()
|
||||
sock.close()
|
||||
|
||||
if banner_str:
|
||||
logger.debug(f"Got banner from {host}:{port}: {banner_str[:100]}")
|
||||
return banner_str
|
||||
except socket.timeout:
|
||||
# Try sending a probe for services that need it
|
||||
banner_str = self._probe_service(sock, port)
|
||||
sock.close()
|
||||
return banner_str
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error grabbing banner from {host}:{port}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _probe_service(self, sock: socket.socket, port: int) -> Optional[str]:
|
||||
"""
|
||||
Send service-specific probe to elicit response.
|
||||
|
||||
Args:
|
||||
sock: Connected socket
|
||||
port: Port number
|
||||
|
||||
Returns:
|
||||
Response string or None
|
||||
"""
|
||||
probes = {
|
||||
80: b"GET / HTTP/1.0\r\n\r\n",
|
||||
443: b"GET / HTTP/1.0\r\n\r\n",
|
||||
8080: b"GET / HTTP/1.0\r\n\r\n",
|
||||
8443: b"GET / HTTP/1.0\r\n\r\n",
|
||||
25: b"EHLO test\r\n",
|
||||
110: b"USER test\r\n",
|
||||
143: b"A001 CAPABILITY\r\n",
|
||||
}
|
||||
|
||||
probe = probes.get(port, b"\r\n")
|
||||
|
||||
try:
|
||||
sock.send(probe)
|
||||
response = sock.recv(1024)
|
||||
return response.decode('utf-8', errors='ignore').strip()
|
||||
except:
|
||||
return None
|
||||
|
||||
    def _identify_from_banner(self, banner: str, port: int) -> tuple[Optional[str], Optional[str]]:
        """
        Identify service and version from banner.

        Args:
            banner: Banner string
            port: Port number

        Returns:
            Tuple of (service_name, version)
        """
        banner_lower = banner.lower()

        # HTTP servers
        if 'http' in banner_lower or port in [80, 443, 8080, 8443]:
            if 'apache' in banner_lower:
                return self._extract_apache_version(banner)
            elif 'nginx' in banner_lower:
                return self._extract_nginx_version(banner)
            elif 'iis' in banner_lower or 'microsoft' in banner_lower:
                return 'IIS', None
            else:
                return 'HTTP', None

        # SSH
        if 'ssh' in banner_lower or port == 22:
            if 'openssh' in banner_lower:
                return self._extract_openssh_version(banner)
            return 'SSH', None

        # FTP
        if 'ftp' in banner_lower or port in [20, 21]:
            if 'filezilla' in banner_lower:
                return 'FileZilla FTP', None
            elif 'proftpd' in banner_lower:
                return 'ProFTPD', None
            return 'FTP', None

        # SMTP
        if 'smtp' in banner_lower or 'mail' in banner_lower or port == 25:
            if 'postfix' in banner_lower:
                return 'Postfix', None
            elif 'exim' in banner_lower:
                return 'Exim', None
            return 'SMTP', None

        # MySQL
        if 'mysql' in banner_lower or port == 3306:
            return 'MySQL', None

        # PostgreSQL
        if 'postgresql' in banner_lower or port == 5432:
            return 'PostgreSQL', None

        # Generic identification
        if port == 22:
            return 'SSH', None
        elif port in [80, 8080]:
            return 'HTTP', None
        elif port in [443, 8443]:
            return 'HTTPS', None

        return None, None

    def _extract_apache_version(self, banner: str) -> tuple[str, Optional[str]]:
        """Extract Apache version from banner."""
        import re
        match = re.search(r'Apache/?([\d.]+)?', banner, re.IGNORECASE)
        if match:
            version = match.group(1)
            return 'Apache', version
        return 'Apache', None

    def _extract_nginx_version(self, banner: str) -> tuple[str, Optional[str]]:
        """Extract nginx version from banner."""
        import re
        match = re.search(r'nginx/?([\d.]+)?', banner, re.IGNORECASE)
        if match:
            version = match.group(1)
            return 'nginx', version
        return 'nginx', None

    def _extract_openssh_version(self, banner: str) -> tuple[str, Optional[str]]:
        """Extract OpenSSH version from banner."""
        import re
        match = re.search(r'OpenSSH[_/]?([\d.]+\w*)?', banner, re.IGNORECASE)
        if match:
            version = match.group(1)
            return 'OpenSSH', version
        return 'OpenSSH', None

    def _guess_service_from_port(self, port: int) -> Optional[str]:
        """
        Guess service name from well-known port number.

        Args:
            port: Port number

        Returns:
            Service name or None
        """
        common_services = {
            20: 'ftp-data',
            21: 'ftp',
            22: 'ssh',
            23: 'telnet',
            25: 'smtp',
            53: 'dns',
            80: 'http',
            110: 'pop3',
            143: 'imap',
            443: 'https',
            445: 'smb',
            993: 'imaps',
            995: 'pop3s',
            3306: 'mysql',
            3389: 'rdp',
            5432: 'postgresql',
            5900: 'vnc',
            6379: 'redis',
            8080: 'http-alt',
            8443: 'https-alt',
            27017: 'mongodb',
        }

        return common_services.get(port)
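For orientation, a minimal usage sketch of ServiceDetector; the address and the port list are illustrative assumptions, not values from this repository:

    # Illustrative only: enrich previously discovered open ports with service info.
    detector = ServiceDetector(timeout=2)
    open_ports = [22, 80, 3306]  # assumed output of a prior port scan
    for port in open_ports:
        info = detector.detect_service("192.168.1.10", port)
        print(info['port'], info['service_name'], info['service_version'], (info['banner'] or '')[:40])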
256
teamleader_test/app/schemas.py
Normal file
@@ -0,0 +1,256 @@
"""Pydantic schemas for API request/response validation."""

from pydantic import BaseModel, Field, IPvAnyAddress, field_validator
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum


class ScanType(str, Enum):
    """Scan type enumeration."""
    QUICK = "quick"
    STANDARD = "standard"
    DEEP = "deep"
    CUSTOM = "custom"


class ScanStatus(str, Enum):
    """Scan status enumeration."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class HostStatus(str, Enum):
    """Host status enumeration."""
    ONLINE = "online"
    OFFLINE = "offline"
    SCANNING = "scanning"


class ConnectionType(str, Enum):
    """Connection type enumeration."""
    GATEWAY = "gateway"
    SAME_SUBNET = "same_subnet"
    SERVICE = "service"
    INFERRED = "inferred"


# Service schemas
class ServiceBase(BaseModel):
    """Base service schema."""
    port: int = Field(..., ge=1, le=65535)
    protocol: str = Field(default="tcp", pattern="^(tcp|udp)$")
    state: str = Field(default="open")
    service_name: Optional[str] = None
    service_version: Optional[str] = None
    banner: Optional[str] = None


class ServiceCreate(ServiceBase):
    """Schema for creating a service."""
    host_id: int


class ServiceResponse(ServiceBase):
    """Schema for service response."""
    id: int
    host_id: int
    first_seen: datetime
    last_seen: datetime

    class Config:
        from_attributes = True


# Host schemas
class HostBase(BaseModel):
    """Base host schema."""
    ip_address: str
    hostname: Optional[str] = None
    mac_address: Optional[str] = None

    @field_validator('ip_address')
    @classmethod
    def validate_ip(cls, v: str) -> str:
        """Validate IP address format."""
        import ipaddress
        try:
            ipaddress.ip_address(v)
            return v
        except ValueError:
            raise ValueError(f"Invalid IP address: {v}")


class HostCreate(HostBase):
    """Schema for creating a host."""
    device_type: Optional[str] = None
    os_guess: Optional[str] = None
    vendor: Optional[str] = None


class HostResponse(HostBase):
    """Schema for host response."""
    id: int
    first_seen: datetime
    last_seen: datetime
    status: HostStatus
    device_type: Optional[str] = None
    os_guess: Optional[str] = None
    vendor: Optional[str] = None
    notes: Optional[str] = None
    services: List[ServiceResponse] = []

    class Config:
        from_attributes = True


class HostDetailResponse(HostResponse):
    """Detailed host response with connection info."""
    outgoing_connections: List['ConnectionResponse'] = []
    incoming_connections: List['ConnectionResponse'] = []


# Connection schemas
class ConnectionBase(BaseModel):
    """Base connection schema."""
    source_host_id: int
    target_host_id: int
    connection_type: ConnectionType
    protocol: Optional[str] = None
    port: Optional[int] = None
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)


class ConnectionCreate(ConnectionBase):
    """Schema for creating a connection."""
    metadata: Optional[Dict[str, Any]] = None


class ConnectionResponse(ConnectionBase):
    """Schema for connection response."""
    id: int
    detected_at: datetime
    last_verified: Optional[datetime] = None
    metadata: Optional[Dict[str, Any]] = None

    class Config:
        from_attributes = True


# Scan schemas
class ScanConfigRequest(BaseModel):
    """Schema for scan configuration request."""
    network_range: str
    scan_type: ScanType = Field(default=ScanType.QUICK)
    port_range: Optional[str] = None
    include_service_detection: bool = True
    use_nmap: bool = True

    @field_validator('network_range')
    @classmethod
    def validate_network(cls, v: str) -> str:
        """Validate network range format."""
        import ipaddress
        try:
            network = ipaddress.ip_network(v, strict=False)
            # Check if it's a private network
            if not network.is_private:
                raise ValueError("Only private network ranges are allowed")
            return v
        except ValueError as e:
            raise ValueError(f"Invalid network range: {e}")
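As a quick illustration of how this validator behaves (the addresses below are examples, not values taken from the project):

    # Sketch: private ranges pass, public ranges are rejected with a ValidationError.
    from pydantic import ValidationError

    ScanConfigRequest(network_range="192.168.1.0/24")   # accepted
    try:
        ScanConfigRequest(network_range="8.8.8.0/24")   # public range
    except ValidationError as exc:
        print(exc.errors()[0]["msg"])                   # mentions "Only private network ranges are allowed"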
class ScanResponse(BaseModel):
    """Schema for scan response."""
    id: int
    started_at: datetime
    completed_at: Optional[datetime] = None
    scan_type: str
    network_range: str
    status: ScanStatus
    hosts_found: int = 0
    ports_scanned: int = 0
    error_message: Optional[str] = None

    class Config:
        from_attributes = True


class ScanStatusResponse(ScanResponse):
    """Schema for detailed scan status response."""
    progress: float = Field(default=0.0, ge=0.0, le=1.0)
    current_host: Optional[str] = None
    estimated_completion: Optional[datetime] = None


class ScanStartResponse(BaseModel):
    """Schema for scan start response."""
    scan_id: int
    message: str
    status: ScanStatus


# Topology schemas
class TopologyNode(BaseModel):
    """Schema for topology graph node."""
    id: str
    ip: str
    hostname: Optional[str]
    type: str
    status: str
    service_count: int
    connections: int = 0


class TopologyEdge(BaseModel):
    """Schema for topology graph edge."""
    source: str
    target: str
    type: str = "default"
    confidence: float = 0.5


class TopologyResponse(BaseModel):
    """Schema for topology graph response."""
    nodes: List[TopologyNode]
    edges: List[TopologyEdge]
    statistics: Dict[str, Any] = Field(default_factory=dict)


# WebSocket message schemas
class WSMessageType(str, Enum):
    """WebSocket message type enumeration."""
    SCAN_STARTED = "scan_started"
    SCAN_PROGRESS = "scan_progress"
    HOST_DISCOVERED = "host_discovered"
    SERVICE_DISCOVERED = "service_discovered"
    SCAN_COMPLETED = "scan_completed"
    SCAN_FAILED = "scan_failed"
    ERROR = "error"


class WSMessage(BaseModel):
    """Schema for WebSocket messages."""
    type: WSMessageType
    data: Dict[str, Any]
    timestamp: datetime = Field(default_factory=datetime.utcnow)


# Statistics schemas
class NetworkStatistics(BaseModel):
    """Schema for network statistics."""
    total_hosts: int
    online_hosts: int
    offline_hosts: int
    total_services: int
    total_scans: int
    last_scan: Optional[datetime] = None
    most_common_services: List[Dict[str, Any]] = []


# Rebuild models to resolve forward references
HostDetailResponse.model_rebuild()
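Since the WebSocket endpoint ultimately ships WSMessage instances as text, a typical pattern with Pydantic v2 looks roughly like this; the helper name and the `websocket` object are assumptions, not code from this repository:

    # Hypothetical broadcast helper; `websocket` is assumed to be a connected fastapi.WebSocket.
    async def notify_host_discovered(websocket) -> None:
        message = WSMessage(
            type=WSMessageType.HOST_DISCOVERED,
            data={"ip_address": "192.168.1.10", "status": "online"},
        )
        await websocket.send_text(message.model_dump_json())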
6
teamleader_test/app/services/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""Business logic services."""

from app.services.scan_service import ScanService
from app.services.topology_service import TopologyService

__all__ = ['ScanService', 'TopologyService']
553
teamleader_test/app/services/scan_service.py
Normal file
@@ -0,0 +1,553 @@
"""Scan service for orchestrating network scanning operations."""

import asyncio
import logging
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session

from app.models import Scan, Host, Service, Connection
from app.schemas import ScanConfigRequest, ScanStatus as ScanStatusEnum
from app.scanner.network_scanner import NetworkScanner
from app.scanner.port_scanner import PortScanner
from app.scanner.service_detector import ServiceDetector
from app.scanner.nmap_scanner import NmapScanner
from app.config import settings

logger = logging.getLogger(__name__)


class ScanService:
    """Service for managing network scans."""

    def __init__(self, db: Session):
        """
        Initialize scan service.

        Args:
            db: Database session
        """
        self.db = db
        self.active_scans: Dict[int, asyncio.Task] = {}
        self.cancel_requested: Dict[int, bool] = {}

    def create_scan(self, config: ScanConfigRequest) -> Scan:
        """
        Create a new scan record.

        Args:
            config: Scan configuration

        Returns:
            Created scan object
        """
        scan = Scan(
            scan_type=config.scan_type.value,
            network_range=config.network_range,
            status=ScanStatusEnum.PENDING.value,
            started_at=datetime.utcnow()
        )

        self.db.add(scan)
        self.db.commit()
        self.db.refresh(scan)

        logger.info(f"Created scan {scan.id} for {config.network_range}")
        return scan

    def cancel_scan(self, scan_id: int) -> bool:
        """
        Cancel a running scan.

        Args:
            scan_id: Scan ID to cancel

        Returns:
            True if scan was cancelled, False if not found or not running
        """
        try:
            scan = self.db.query(Scan).filter(Scan.id == scan_id).first()

            if not scan:
                logger.warning(f"Scan {scan_id} not found")
                return False

            if scan.status not in [ScanStatusEnum.PENDING.value, ScanStatusEnum.RUNNING.value]:
                logger.warning(f"Scan {scan_id} is not running (status: {scan.status})")
                return False

            # Mark for cancellation
            self.cancel_requested[scan_id] = True

            # Cancel the task if it exists
            if scan_id in self.active_scans:
                task = self.active_scans[scan_id]
                task.cancel()
                del self.active_scans[scan_id]

            # Update scan status
            scan.status = ScanStatusEnum.CANCELLED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            logger.info(f"Cancelled scan {scan_id}")
            return True
        except Exception as e:
            logger.error(f"Error cancelling scan {scan_id}: {e}")
            self.db.rollback()
            return False

    async def execute_scan(
        self,
        scan_id: int,
        config: ScanConfigRequest,
        progress_callback: Optional[callable] = None
    ) -> None:
        """
        Execute a network scan.

        Args:
            scan_id: Scan ID
            config: Scan configuration
            progress_callback: Optional callback for progress updates
        """
        scan = self.db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            logger.error(f"Scan {scan_id} not found")
            return

        try:
            # Initialize cancellation flag
            self.cancel_requested[scan_id] = False

            # Update scan status
            scan.status = ScanStatusEnum.RUNNING.value
            self.db.commit()

            logger.info(f"Starting scan {scan_id}")

            # Check for cancellation
            if self.cancel_requested.get(scan_id):
                raise asyncio.CancelledError("Scan cancelled by user")

            # Initialize scanners
            network_scanner = NetworkScanner(
                progress_callback=lambda host, progress: self._on_host_progress(
                    scan_id, host, progress, progress_callback
                )
            )

            # Phase 1: Host Discovery
            logger.info(f"Phase 1: Discovering hosts in {config.network_range}")
            active_hosts = await network_scanner.scan_network(config.network_range)

            scan.hosts_found = len(active_hosts)
            self.db.commit()

            logger.info(f"Found {len(active_hosts)} active hosts")

            # Check for cancellation
            if self.cancel_requested.get(scan_id):
                raise asyncio.CancelledError("Scan cancelled by user")

            # Send progress update
            if progress_callback:
                await progress_callback({
                    'type': 'scan_progress',
                    'scan_id': scan_id,
                    'progress': 0.3,
                    'current_host': f"Found {len(active_hosts)} hosts"
                })

            # Phase 2: Port Scanning and Service Detection
            if config.use_nmap and settings.enable_nmap:
                await self._scan_with_nmap(scan, active_hosts, config, progress_callback)
            else:
                await self._scan_with_socket(scan, active_hosts, config, progress_callback)

            # Phase 3: Detect Connections
            await self._detect_connections(scan, network_scanner)

            # Mark scan as completed
            scan.status = ScanStatusEnum.COMPLETED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            logger.info(f"Scan {scan_id} completed successfully")

            if progress_callback:
                await progress_callback({
                    'type': 'scan_completed',
                    'scan_id': scan_id,
                    'hosts_found': scan.hosts_found
                })

        except asyncio.CancelledError:
            logger.info(f"Scan {scan_id} was cancelled")

            scan.status = ScanStatusEnum.CANCELLED.value
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            if progress_callback:
                await progress_callback({
                    'type': 'scan_completed',
                    'scan_id': scan_id,
                    'hosts_found': scan.hosts_found
                })

        except Exception as e:
            logger.error(f"Error executing scan {scan_id}: {e}", exc_info=True)

            scan.status = ScanStatusEnum.FAILED.value
            scan.error_message = str(e)
            scan.completed_at = datetime.utcnow()
            self.db.commit()

            if progress_callback:
                await progress_callback({
                    'type': 'scan_failed',
                    'scan_id': scan_id,
                    'error': str(e)
                })

        finally:
            # Cleanup
            self.cancel_requested.pop(scan_id, None)
            self.active_scans.pop(scan_id, None)
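The scans endpoint (not part of this excerpt) is expected to drive execute_scan as a background task and to register that task so cancel_scan can find it; a minimal sketch of that wiring, with the function name and session handling assumed rather than taken from the repository:

    # Illustrative only: launch a scan in the background and register the task.
    async def start_scan(config: ScanConfigRequest, db: Session) -> int:
        service = ScanService(db)
        scan = service.create_scan(config)
        task = asyncio.create_task(service.execute_scan(scan.id, config))
        service.active_scans[scan.id] = task
        return scan.id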
    async def _scan_with_socket(
        self,
        scan: Scan,
        hosts: list,
        config: ScanConfigRequest,
        progress_callback: Optional[callable]
    ) -> None:
        """Scan hosts using socket-based scanning."""
        port_scanner = PortScanner(
            progress_callback=lambda host, port, progress: self._on_port_progress(
                scan.id, host, port, progress, progress_callback
            )
        )
        service_detector = ServiceDetector()

        for idx, ip in enumerate(hosts, 1):
            try:
                # Check for cancellation
                if self.cancel_requested.get(scan.id):
                    logger.info(f"Scan {scan.id} cancelled during port scanning")
                    raise asyncio.CancelledError("Scan cancelled by user")

                logger.info(f"Scanning host {idx}/{len(hosts)}: {ip}")

                # Get or create host
                host = self._get_or_create_host(ip)
                self.db.commit()  # Commit to ensure host.id is set
                self.db.refresh(host)

                # Send host discovered notification
                if progress_callback:
                    await progress_callback({
                        'type': 'host_discovered',
                        'scan_id': scan.id,
                        'host': {
                            'ip_address': ip,
                            'status': 'online'
                        }
                    })

                # Scan ports
                custom_ports = None
                if config.port_range:
                    custom_ports = port_scanner.parse_port_range(config.port_range)

                open_ports = await port_scanner.scan_host_ports(
                    ip,
                    scan_type=config.scan_type.value,
                    custom_ports=custom_ports
                )

                scan.ports_scanned += len(open_ports)

                # Detect services
                if config.include_service_detection:
                    for port_info in open_ports:
                        service_info = service_detector.detect_service(ip, port_info['port'])
                        port_info.update(service_info)

                # Store services
                self._store_services(host, open_ports)

                # Associate host with scan
                if host not in scan.hosts:
                    scan.hosts.append(host)

                self.db.commit()

                # Send progress update
                if progress_callback:
                    progress = 0.3 + (0.6 * (idx / len(hosts)))  # 30-90% for port scanning
                    await progress_callback({
                        'type': 'scan_progress',
                        'scan_id': scan.id,
                        'progress': progress,
                        'current_host': f"Scanning {ip} ({idx}/{len(hosts)})"
                    })

            except Exception as e:
                logger.error(f"Error scanning host {ip}: {e}")
                continue

    async def _scan_with_nmap(
        self,
        scan: Scan,
        hosts: list,
        config: ScanConfigRequest,
        progress_callback: Optional[callable]
    ) -> None:
        """Scan hosts using nmap."""
        nmap_scanner = NmapScanner()

        if not nmap_scanner.nmap_available:
            logger.warning("Nmap not available, falling back to socket scanning")
            await self._scan_with_socket(scan, hosts, config, progress_callback)
            return

        # Scan each host with nmap
        for idx, ip in enumerate(hosts, 1):
            try:
                logger.info(f"Scanning host {idx}/{len(hosts)} with nmap: {ip}")

                # Get or create host
                host = self._get_or_create_host(ip)
                self.db.commit()  # Commit to ensure host.id is set
                self.db.refresh(host)

                # Build nmap arguments
                port_range = config.port_range if config.port_range else None
                arguments = nmap_scanner.get_scan_arguments(
                    scan_type=config.scan_type.value,
                    service_detection=config.include_service_detection,
                    port_range=port_range
                )

                # Execute nmap scan
                result = await nmap_scanner.scan_host(ip, arguments)

                if result:
                    # Update hostname if available
                    if result.get('hostname'):
                        host.hostname = result['hostname']

                    # Store services
                    if result.get('ports'):
                        self._store_services(host, result['ports'])
                        scan.ports_scanned += len(result['ports'])

                    # Store OS information
                    if result.get('os_matches'):
                        best_match = max(result['os_matches'], key=lambda x: float(x['accuracy']))
                        host.os_guess = best_match['name']

                # Associate host with scan
                if host not in scan.hosts:
                    scan.hosts.append(host)

                self.db.commit()

            except Exception as e:
                logger.error(f"Error scanning host {ip} with nmap: {e}")
                continue

    def _get_or_create_host(self, ip: str) -> Host:
        """Get existing host or create new one."""
        host = self.db.query(Host).filter(Host.ip_address == ip).first()

        if host:
            host.last_seen = datetime.utcnow()
            host.status = 'online'
        else:
            host = Host(
                ip_address=ip,
                status='online',
                first_seen=datetime.utcnow(),
                last_seen=datetime.utcnow()
            )
            self.db.add(host)

        return host

    def _store_services(self, host: Host, services_data: list) -> None:
        """Store or update services for a host."""
        for service_info in services_data:
            # Check if service already exists
            service = self.db.query(Service).filter(
                Service.host_id == host.id,
                Service.port == service_info['port'],
                Service.protocol == service_info.get('protocol', 'tcp')
            ).first()

            if service:
                # Update existing service
                service.last_seen = datetime.utcnow()
                service.state = service_info.get('state', 'open')
                if service_info.get('service_name'):
                    service.service_name = service_info['service_name']
                if service_info.get('service_version'):
                    service.service_version = service_info['service_version']
                if service_info.get('banner'):
                    service.banner = service_info['banner']
            else:
                # Create new service
                service = Service(
                    host_id=host.id,
                    port=service_info['port'],
                    protocol=service_info.get('protocol', 'tcp'),
                    state=service_info.get('state', 'open'),
                    service_name=service_info.get('service_name'),
                    service_version=service_info.get('service_version'),
                    banner=service_info.get('banner'),
                    first_seen=datetime.utcnow(),
                    last_seen=datetime.utcnow()
                )
                self.db.add(service)
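For reference, `services_data` is a list of plain dicts shaped like the scanner output consumed above; one illustrative element (the values are made up, only the keys follow the code):

    # Example element of services_data:
    {
        'port': 22,
        'protocol': 'tcp',
        'state': 'open',
        'service_name': 'OpenSSH',
        'service_version': '9.6',
        'banner': 'SSH-2.0-OpenSSH_9.6',
    }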
    async def _detect_connections(self, scan: Scan, network_scanner: NetworkScanner) -> None:
        """Detect connections between hosts."""
        try:
            # Infer the gateway address from the local network range
            local_range = network_scanner.get_local_network_range()
            if local_range:
                gateway_ip = local_range.split('/')[0].rsplit('.', 1)[0] + '.1'

                # Find or create gateway host
                gateway_host = self.db.query(Host).filter(
                    Host.ip_address == gateway_ip
                ).first()

                if gateway_host:
                    # Connect all hosts to gateway
                    for host in scan.hosts:
                        if host.id != gateway_host.id:
                            self._create_connection(
                                host.id,
                                gateway_host.id,
                                'gateway',
                                confidence=0.9
                            )

            # Create connections based on services
            for host in scan.hosts:
                for service in host.services:
                    # If host has client-type services, it might connect to servers
                    if service.service_name in ['http', 'https', 'ssh']:
                        # Find potential servers on the network
                        for other_host in scan.hosts:
                            if other_host.id != host.id:
                                for other_service in other_host.services:
                                    if (other_service.port == service.port and
                                            other_service.service_name in ['http', 'https', 'ssh']):
                                        self._create_connection(
                                            host.id,
                                            other_host.id,
                                            'service',
                                            protocol='tcp',
                                            port=service.port,
                                            confidence=0.5
                                        )

            self.db.commit()

        except Exception as e:
            logger.error(f"Error detecting connections: {e}")
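The '.1' string manipulation assumes a /24-style range whose gateway is the first host. A more general way to derive that candidate address, shown only as an alternative sketch using the standard-library ipaddress module (the sample range is illustrative):

    # Sketch: first usable address of the detected range as the gateway candidate.
    import ipaddress
    local_range = "192.168.1.0/24"  # as returned by get_local_network_range()
    network = ipaddress.ip_network(local_range, strict=False)
    gateway_candidate = str(next(network.hosts()))  # '192.168.1.1'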
    def _create_connection(
        self,
        source_id: int,
        target_id: int,
        conn_type: str,
        protocol: Optional[str] = None,
        port: Optional[int] = None,
        confidence: float = 1.0
    ) -> None:
        """Create a connection if it doesn't exist."""
        existing = self.db.query(Connection).filter(
            Connection.source_host_id == source_id,
            Connection.target_host_id == target_id,
            Connection.connection_type == conn_type
        ).first()

        if not existing:
            connection = Connection(
                source_host_id=source_id,
                target_host_id=target_id,
                connection_type=conn_type,
                protocol=protocol,
                port=port,
                confidence=confidence,
                detected_at=datetime.utcnow()
            )
            self.db.add(connection)

    def _on_host_progress(
        self,
        scan_id: int,
        host: str,
        progress: float,
        callback: Optional[callable]
    ) -> None:
        """Handle host discovery progress."""
        if callback:
            asyncio.create_task(callback({
                'type': 'scan_progress',
                'scan_id': scan_id,
                'current_host': host,
                'progress': progress * 0.5  # Host discovery is first 50%
            }))

    def _on_port_progress(
        self,
        scan_id: int,
        host: str,
        port: int,
        progress: float,
        callback: Optional[callable]
    ) -> None:
        """Handle port scanning progress."""
        if callback:
            asyncio.create_task(callback({
                'type': 'scan_progress',
                'scan_id': scan_id,
                'current_host': host,
                'current_port': port,
                'progress': 0.5 + (progress * 0.5)  # Port scanning is second 50%
            }))

    def get_scan_status(self, scan_id: int) -> Optional[Scan]:
        """Get scan status by ID."""
        return self.db.query(Scan).filter(Scan.id == scan_id).first()

    def list_scans(self, limit: int = 50, offset: int = 0) -> list:
        """List recent scans."""
        return self.db.query(Scan)\
            .order_by(Scan.started_at.desc())\
            .limit(limit)\
            .offset(offset)\
            .all()
256
teamleader_test/app/services/topology_service.py
Normal file
@@ -0,0 +1,256 @@
"""Topology service for network graph generation."""

import logging
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func

from app.models import Host, Service, Connection
from app.schemas import TopologyNode, TopologyEdge, TopologyResponse

logger = logging.getLogger(__name__)


class TopologyService:
    """Service for generating network topology graphs."""

    # Node type colors
    NODE_COLORS = {
        'gateway': '#FF6B6B',
        'server': '#4ECDC4',
        'workstation': '#45B7D1',
        'device': '#96CEB4',
        'unknown': '#95A5A6'
    }

    def __init__(self, db: Session):
        """
        Initialize topology service.

        Args:
            db: Database session
        """
        self.db = db

    def generate_topology(self, include_offline: bool = False) -> TopologyResponse:
        """
        Generate network topology graph.

        Args:
            include_offline: Include offline hosts

        Returns:
            Topology response with nodes and edges
        """
        logger.info("Generating network topology")

        # Get hosts
        query = self.db.query(Host)
        if not include_offline:
            query = query.filter(Host.status == 'online')

        hosts = query.all()

        # Generate nodes
        nodes = []
        for host in hosts:
            node = self._create_node(host)
            nodes.append(node)

        # Generate edges from connections
        edges = []
        connections = self.db.query(Connection).all()

        for conn in connections:
            # Only include edges if both hosts are in the topology
            source_in_topology = any(n.id == str(conn.source_host_id) for n in nodes)
            target_in_topology = any(n.id == str(conn.target_host_id) for n in nodes)

            if source_in_topology and target_in_topology:
                edge = self._create_edge(conn)
                edges.append(edge)

        # Generate statistics
        statistics = self._generate_statistics(hosts, connections)

        logger.info(f"Generated topology with {len(nodes)} nodes and {len(edges)} edges")

        return TopologyResponse(
            nodes=nodes,
            edges=edges,
            statistics=statistics
        )
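A front end usually wants this graph as plain JSON; one illustrative way to hand it over, assuming a session object `db` and a vis-network/D3-style consumer (neither is specified by this file):

    # Sketch: serialize the topology for a JavaScript graph library.
    topology = TopologyService(db).generate_topology(include_offline=False)
    payload = {
        "nodes": [n.model_dump() for n in topology.nodes],
        "edges": [e.model_dump() for e in topology.edges],
        "stats": topology.statistics,
    }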
    def _create_node(self, host: Host) -> TopologyNode:
        """
        Create a topology node from a host.

        Args:
            host: Host model

        Returns:
            TopologyNode
        """
        # Determine device type
        device_type = self._determine_device_type(host)

        # Count connections
        connections = self.db.query(Connection).filter(
            (Connection.source_host_id == host.id) |
            (Connection.target_host_id == host.id)
        ).count()

        return TopologyNode(
            id=str(host.id),
            ip=host.ip_address,
            hostname=host.hostname,
            type=device_type,
            status=host.status,
            service_count=len(host.services),
            connections=connections
        )

    def _determine_device_type(self, host: Host) -> str:
        """
        Determine device type based on host information.

        Args:
            host: Host model

        Returns:
            Device type string
        """
        # Check if explicitly set
        if host.device_type:
            return host.device_type

        # Infer from services
        service_names = [s.service_name for s in host.services if s.service_name]

        # Check for gateway indicators
        if any(s.port == 53 for s in host.services):  # DNS server
            return 'gateway'

        # Check for server indicators
        server_services = ['http', 'https', 'ssh', 'smtp', 'mysql', 'postgresql', 'ftp']
        if any(svc in service_names for svc in server_services):
            if len(host.services) > 5:
                return 'server'

        # Check for workstation indicators
        if any(s.port == 3389 for s in host.services):  # RDP
            return 'workstation'

        # Default to device
        if len(host.services) > 0:
            return 'device'

        return 'unknown'
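To make the order of these checks concrete, a few worked classifications; the port sets are illustrative inputs, not scan data:

    # {53, 80, 22}                    -> 'gateway'      (DNS check runs before the server check)
    # {22, 80, 443, 3306, 8080, 8443} -> 'server'       (server service present and more than 5 services)
    # {3389}                          -> 'workstation'  (RDP)
    # {8081}                          -> 'device'       (has services, none matched a rule)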
    def _create_edge(self, connection: Connection) -> TopologyEdge:
        """
        Create a topology edge from a connection.

        Args:
            connection: Connection model

        Returns:
            TopologyEdge
        """
        return TopologyEdge(
            source=str(connection.source_host_id),
            target=str(connection.target_host_id),
            type=connection.connection_type or 'default',
            confidence=connection.confidence
        )

    def _generate_statistics(
        self,
        hosts: List[Host],
        connections: List[Connection]
    ) -> Dict[str, Any]:
        """
        Generate statistics about the topology.

        Args:
            hosts: List of hosts
            connections: List of connections

        Returns:
            Statistics dictionary
        """
        # Count isolated nodes (no connections)
        isolated = 0
        for host in hosts:
            conn_count = self.db.query(Connection).filter(
                (Connection.source_host_id == host.id) |
                (Connection.target_host_id == host.id)
            ).count()
            if conn_count == 0:
                isolated += 1

        # Calculate average connections
        avg_connections = len(connections) / max(len(hosts), 1) if hosts else 0

        return {
            'total_nodes': len(hosts),
            'total_edges': len(connections),
            'isolated_nodes': isolated,
            'avg_connections': round(avg_connections, 2)
        }

    def get_host_neighbors(self, host_id: int) -> List[Host]:
        """
        Get all hosts connected to a specific host.

        Args:
            host_id: Host ID

        Returns:
            List of connected hosts
        """
        # Get outgoing connections
        outgoing = self.db.query(Connection).filter(
            Connection.source_host_id == host_id
        ).all()

        # Get incoming connections
        incoming = self.db.query(Connection).filter(
            Connection.target_host_id == host_id
        ).all()

        # Collect unique neighbor IDs
        neighbor_ids = set()
        for conn in outgoing:
            neighbor_ids.add(conn.target_host_id)
        for conn in incoming:
            neighbor_ids.add(conn.source_host_id)

        # Get host objects
        neighbors = self.db.query(Host).filter(
            Host.id.in_(neighbor_ids)
        ).all()

        return neighbors

    def get_network_statistics(self) -> Dict[str, Any]:
        """
        Get network statistics.

        Returns:
            Statistics dictionary
        """
        total_hosts = self.db.query(func.count(Host.id)).scalar()
        online_hosts = self.db.query(func.count(Host.id)).filter(
            Host.status == 'online'
        ).scalar()
        total_services = self.db.query(func.count(Service.id)).scalar()

        return {
            'total_hosts': total_hosts,
            'online_hosts': online_hosts,
            'offline_hosts': total_hosts - online_hosts,
            'total_services': total_services,
            'total_connections': self.db.query(func.count(Connection.id)).scalar()
        }