108 lines
4.4 KiB
Python
108 lines
4.4 KiB
Python
"""
|
|
WebSocket Rate Limiting for DockMon
|
|
Provides generous rate limiting to prevent DoS while supporting large deployments
|
|
"""
|
|
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from typing import Dict, Deque, Tuple
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WebSocketRateLimiter:
|
|
"""
|
|
Rate limiter for WebSocket connections.
|
|
Designed to be generous for legitimate use while preventing abuse.
|
|
|
|
Supports monitoring hundreds of containers without issues.
|
|
"""
|
|
|
|
def __init__(self):
|
|
# Track message timestamps per connection
|
|
# Format: {connection_id: deque of timestamps}
|
|
self.message_history: Dict[str, Deque[float]] = defaultdict(lambda: deque(maxlen=1000))
|
|
|
|
# Rate limits (generous for large deployments)
|
|
self.limits = {
|
|
# Allow 100 messages per second (very generous for real-time monitoring)
|
|
"messages_per_second": 100,
|
|
# Allow 1000 messages per minute (for burst activity)
|
|
"messages_per_minute": 1000,
|
|
# Allow 10000 messages per hour (for sustained monitoring)
|
|
"messages_per_hour": 10000
|
|
}
|
|
|
|
# Track violations for logging
|
|
self.violations: Dict[str, int] = defaultdict(int)
|
|
|
|
def check_rate_limit(self, connection_id: str) -> Tuple[bool, str]:
|
|
"""
|
|
Check if a connection has exceeded rate limits.
|
|
|
|
Returns:
|
|
(allowed, reason) - allowed is True if within limits,
|
|
reason explains why if rejected
|
|
"""
|
|
current_time = time.time()
|
|
history = self.message_history[connection_id]
|
|
|
|
# Add current request
|
|
history.append(current_time)
|
|
|
|
# Count messages in different time windows
|
|
messages_last_second = sum(1 for t in history if current_time - t <= 1)
|
|
messages_last_minute = sum(1 for t in history if current_time - t <= 60)
|
|
messages_last_hour = sum(1 for t in history if current_time - t <= 3600)
|
|
|
|
# Check per-second limit (prevent bursts)
|
|
if messages_last_second > self.limits["messages_per_second"]:
|
|
self.violations[connection_id] += 1
|
|
logger.warning(f"WebSocket rate limit exceeded for {connection_id}: "
|
|
f"{messages_last_second} msgs/sec (limit: {self.limits['messages_per_second']})")
|
|
return False, f"Rate limit exceeded: {messages_last_second} messages per second"
|
|
|
|
# Check per-minute limit
|
|
if messages_last_minute > self.limits["messages_per_minute"]:
|
|
self.violations[connection_id] += 1
|
|
logger.warning(f"WebSocket rate limit exceeded for {connection_id}: "
|
|
f"{messages_last_minute} msgs/min (limit: {self.limits['messages_per_minute']})")
|
|
return False, f"Rate limit exceeded: {messages_last_minute} messages per minute"
|
|
|
|
# Check per-hour limit
|
|
if messages_last_hour > self.limits["messages_per_hour"]:
|
|
self.violations[connection_id] += 1
|
|
logger.warning(f"WebSocket rate limit exceeded for {connection_id}: "
|
|
f"{messages_last_hour} msgs/hour (limit: {self.limits['messages_per_hour']})")
|
|
return False, f"Rate limit exceeded: {messages_last_hour} messages per hour"
|
|
|
|
# Reset violations on successful request
|
|
if connection_id in self.violations:
|
|
del self.violations[connection_id]
|
|
|
|
return True, "OK"
|
|
|
|
def cleanup_connection(self, connection_id: str):
|
|
"""Clean up tracking data when a connection closes"""
|
|
if connection_id in self.message_history:
|
|
del self.message_history[connection_id]
|
|
if connection_id in self.violations:
|
|
del self.violations[connection_id]
|
|
|
|
def get_connection_stats(self, connection_id: str) -> dict:
|
|
"""Get current rate limiting stats for a connection"""
|
|
current_time = time.time()
|
|
history = self.message_history.get(connection_id, deque())
|
|
|
|
return {
|
|
"messages_last_second": sum(1 for t in history if current_time - t <= 1),
|
|
"messages_last_minute": sum(1 for t in history if current_time - t <= 60),
|
|
"messages_last_hour": sum(1 for t in history if current_time - t <= 3600),
|
|
"violations": self.violations.get(connection_id, 0),
|
|
"limits": self.limits
|
|
}
|
|
|
|
|
|
# Global rate limiter instance
|
|
ws_rate_limiter = WebSocketRateLimiter() |