# Source: docker_dev/dockmon/backend/websocket/rate_limiter.py (Python, 108 lines, 4.4 KiB)

"""
WebSocket Rate Limiting for DockMon
Provides generous rate limiting to prevent DoS while supporting large deployments
"""
import time
from collections import defaultdict, deque
from typing import Dict, Deque, Tuple
import logging
logger = logging.getLogger(__name__)


class WebSocketRateLimiter:
    """
    Sliding-window rate limiter for WebSocket connections.

    Each connection is tracked by the timestamps of its recent messages and
    checked against a per-second, a per-minute and a per-hour limit.
    Designed to be generous for legitimate use while preventing abuse.
    """

    # One entry per enforced window, shortest first:
    # (window length in seconds, key in ``self.limits``,
    #  unit shown in the log line, unit shown in the rejection reason)
    _WINDOWS = (
        (1, "messages_per_second", "sec", "second"),
        (60, "messages_per_minute", "min", "minute"),
        (3600, "messages_per_hour", "hour", "hour"),
    )

    def __init__(self):
        # Rate limits (generous for large deployments)
        self.limits = {
            # Allow 100 messages per second (very generous for real-time monitoring)
            "messages_per_second": 100,
            # Allow 1000 messages per minute (for burst activity)
            "messages_per_minute": 1000,
            # Allow 10000 messages per hour (for sustained monitoring)
            "messages_per_hour": 10000,
        }

        # The history must be able to hold one message more than the largest
        # limit; with a smaller buffer the count for that window could never
        # exceed the limit and the check would silently never fire.
        # The size is fixed here, from the limits configured at construction.
        history_size = max(self.limits.values()) + 1

        # Track message timestamps per connection
        # Format: {connection_id: deque of timestamps}
        self.message_history: Dict[str, Deque[float]] = defaultdict(
            lambda: deque(maxlen=history_size)
        )

        # Track violations for logging
        # Format: {connection_id: consecutive rejected messages}
        self.violations: Dict[str, int] = defaultdict(int)

    @classmethod
    def _count_windows(cls, history, now: float) -> Tuple[int, ...]:
        """Return the number of timestamps in *history* inside each window of ``_WINDOWS``."""
        return tuple(
            sum(1 for stamp in history if now - stamp <= span)
            for span, _limit_key, _log_unit, _reason_unit in cls._WINDOWS
        )

    def check_rate_limit(self, connection_id: str) -> Tuple[bool, str]:
        """
        Record a message for a connection and check it against the rate limits.

        The message is added to the history before the limits are evaluated,
        so rejected messages also count towards subsequent checks.

        Args:
            connection_id: Identifier of the connection sending the message.

        Returns:
            (allowed, reason) - allowed is True if within limits,
            reason is "OK" when allowed and explains the exceeded limit
            when rejected.
        """
        current_time = time.time()
        history = self.message_history[connection_id]

        # Timestamps older than the longest window are in no window at all;
        # dropping them keeps the per-message scan short.
        longest_window = self._WINDOWS[-1][0]
        while history and current_time - history[0] > longest_window:
            history.popleft()

        # Add current request
        history.append(current_time)

        # Count messages in all windows, then report the first exceeded limit
        # (shortest window first).
        counts = self._count_windows(history, current_time)
        for (_span, limit_key, log_unit, reason_unit), count in zip(self._WINDOWS, counts):
            limit = self.limits[limit_key]
            if count > limit:
                self.violations[connection_id] += 1
                logger.warning(
                    "WebSocket rate limit exceeded for %s: %s msgs/%s (limit: %s)",
                    connection_id, count, log_unit, limit,
                )
                return False, f"Rate limit exceeded: {count} messages per {reason_unit}"

        # Reset violations on successful request
        self.violations.pop(connection_id, None)
        return True, "OK"

    def cleanup_connection(self, connection_id: str):
        """Clean up tracking data when a connection closes"""
        self.message_history.pop(connection_id, None)
        self.violations.pop(connection_id, None)

    def get_connection_stats(self, connection_id: str) -> dict:
        """
        Get current rate limiting stats for a connection.

        Reading stats does not record a message and does not create tracking
        state for an unknown connection.

        Returns:
            Dict with the message counts for the last second, minute and hour,
            the current violation count, and the configured limits.
        """
        current_time = time.time()
        # .get() instead of indexing so the defaultdict does not create an entry
        history = self.message_history.get(connection_id, ())
        per_second, per_minute, per_hour = self._count_windows(history, current_time)
        return {
            "messages_last_second": per_second,
            "messages_last_minute": per_minute,
            "messages_last_hour": per_hour,
            "violations": self.violations.get(connection_id, 0),
            "limits": self.limits,
        }
# Global rate limiter instance, created once when this module is imported.
# NOTE(review): presumably shared by all WebSocket handlers - confirm against callers.
ws_rate_limiter = WebSocketRateLimiter()