Files
docker_dev/dockmon/backend/security/rate_limiting.py

185 lines
7.4 KiB
Python

"""
Rate Limiting System for DockMon
Provides protection against abuse and DoS attacks using a token bucket algorithm
"""
import logging
import os
import time
from collections import defaultdict
from datetime import datetime
from typing import Dict, Tuple
from fastapi import Request, HTTPException, status, Depends
from .audit import security_audit
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RateLimiter:
    """
    In-memory rate limiter using a token bucket algorithm.

    Each client IP owns one bucket (``tokens``, ``last_update``,
    ``violations``). Every request refills the bucket according to the
    per-minute rate of the endpoint being hit, caps it at that endpoint's
    burst limit, and then spends one token. A request that finds the bucket
    empty counts as a violation; once a client's violations reach the
    endpoint's threshold the IP is banned for ``BAN_DURATION_SECONDS``.

    NOTE(review): the bucket is keyed by IP only, so all endpoint categories
    share one token count, and the violation counter is never reset while the
    client entry exists (a client that was banned once is re-banned on its
    next violation). Both behaviours are preserved here; confirm they are
    intended.
    """

    # Minimum number of seconds between two opportunistic cleanups.
    CLEANUP_INTERVAL_SECONDS = 300
    # Clients not seen for this many seconds are dropped during cleanup.
    CLIENT_IDLE_TIMEOUT_SECONDS = 3600
    # Length of a ban, in seconds.
    BAN_DURATION_SECONDS = 60

    # Timestamp of the last cleanup; the class-level default keeps the
    # attribute defined even for instances created without __init__.
    _last_cleanup = 0.0

    def __init__(self):
        # New clients start with 100 tokens; is_allowed() caps this to the
        # burst limit of the endpoint on every request, so the effective
        # starting budget is min(100, burst_limit).
        self.clients = defaultdict(
            lambda: {"tokens": 100, "last_update": time.time(), "violations": 0}
        )

        # Rate limits come from the environment, with production-friendly
        # defaults.
        # endpoint_pattern: (requests_per_minute, burst_limit, violation_threshold)
        self.limits = {
            "default": (
                int(os.getenv('DOCKMON_RATE_LIMIT_DEFAULT', 120)),
                int(os.getenv('DOCKMON_RATE_BURST_DEFAULT', 20)),
                int(os.getenv('DOCKMON_RATE_VIOLATIONS_DEFAULT', 8)),
            ),
            "auth": (
                int(os.getenv('DOCKMON_RATE_LIMIT_AUTH', 10)),  # 10 per minute for auth
                int(os.getenv('DOCKMON_RATE_BURST_AUTH', 5)),  # Lower burst
                int(os.getenv('DOCKMON_RATE_VIOLATIONS_AUTH', 10)),  # More lenient violations
            ),
            "hosts": (
                int(os.getenv('DOCKMON_RATE_LIMIT_HOSTS', 60)),
                int(os.getenv('DOCKMON_RATE_BURST_HOSTS', 15)),
                int(os.getenv('DOCKMON_RATE_VIOLATIONS_HOSTS', 8)),
            ),
            "containers": (
                # Increased for logs polling with multiple containers
                int(os.getenv('DOCKMON_RATE_LIMIT_CONTAINERS', 900)),
                int(os.getenv('DOCKMON_RATE_BURST_CONTAINERS', 180)),
                int(os.getenv('DOCKMON_RATE_VIOLATIONS_CONTAINERS', 25)),
            ),
            "notifications": (
                int(os.getenv('DOCKMON_RATE_LIMIT_NOTIFICATIONS', 30)),
                int(os.getenv('DOCKMON_RATE_BURST_NOTIFICATIONS', 10)),
                int(os.getenv('DOCKMON_RATE_VIOLATIONS_NOTIFICATIONS', 5)),
            ),
        }
        logger.info(f"Rate limiting configured: Default={self.limits['default'][0]}/min, "
                    f"Auth={self.limits['auth'][0]}/min, Containers={self.limits['containers'][0]}/min")

        self.banned_clients = {}  # IP -> ban_until_timestamp
        # Start the cleanup timer now so the first sweep happens one full
        # interval after construction.
        self._last_cleanup = time.time()

    def _get_limit(self, endpoint: str) -> tuple:
        """
        Return ``(requests_per_minute, burst_limit, violation_threshold)``
        for ``endpoint``.

        An endpoint named ``"<type>_<path>"`` (or exactly ``"<type>"``) whose
        type is a configured key resolves to that key, regardless of what the
        path contains. Anything else falls back to the first configured key
        that occurs as a substring, and finally to ``"default"``.
        """
        name = endpoint.lower()

        # Prefer the explicit endpoint-type prefix: a substring scan alone
        # would resolve "containers_/api/hosts/1/containers" to "hosts".
        for pattern, limits in self.limits.items():
            if name == pattern or name.startswith(pattern + "_"):
                return limits

        # Fallback for names without a recognised prefix.
        for pattern, limits in self.limits.items():
            if pattern in name:
                return limits
        return self.limits["default"]

    def _cleanup_old_entries(self):
        """Drop idle clients and expired bans so the tables do not grow forever."""
        current_time = time.time()

        # Remove clients not seen for CLIENT_IDLE_TIMEOUT_SECONDS.
        cutoff_time = current_time - self.CLIENT_IDLE_TIMEOUT_SECONDS
        old_clients = [ip for ip, data in self.clients.items()
                       if data["last_update"] < cutoff_time]
        for ip in old_clients:
            del self.clients[ip]

        # Remove expired bans.
        expired_bans = [ip for ip, ban_time in self.banned_clients.items()
                        if current_time > ban_time]
        for ip in expired_bans:
            del self.banned_clients[ip]

    @staticmethod
    def _retry_after_seconds(tokens: float, requests_per_minute: int) -> int:
        """Return whole seconds (rounded up, at least 1) until one token is available."""
        if requests_per_minute <= 0:
            # The bucket never refills; report the generic one-minute hint.
            return 60
        wait = max(0.0, 1.0 - tokens) * 60.0 / requests_per_minute
        whole = int(wait)
        if whole < wait:
            whole += 1  # round up so the hint never under-reports the wait
        return max(1, whole)

    def is_allowed(self, client_ip: str, endpoint: str) -> Tuple[bool, str]:
        """
        Check whether a request from ``client_ip`` to ``endpoint`` may proceed.

        Returns ``(allowed, reason)``: ``(True, "OK")`` when a token was
        spent, otherwise ``False`` with a human-readable reason (active ban,
        new ban, or rate limit exceeded). Denied requests are reported to
        ``security_audit.log_rate_limit_violation``.
        """
        current_time = time.time()

        # Sweep stale state at most once per CLEANUP_INTERVAL_SECONDS. Using
        # the elapsed time (rather than the wall-clock phase) guarantees the
        # sweep runs on the first request after the interval has passed.
        if current_time - self._last_cleanup >= self.CLEANUP_INTERVAL_SECONDS:
            self._last_cleanup = current_time
            self._cleanup_old_entries()

        # Check if client is banned
        ban_until = self.banned_clients.get(client_ip)
        if ban_until is not None:
            if current_time < ban_until:
                return False, f"IP banned until {datetime.fromtimestamp(ban_until).isoformat()}"
            # Ban expired, remove from banned list
            del self.banned_clients[client_ip]

        requests_per_minute, burst_limit, violation_threshold = self._get_limit(endpoint)
        client_data = self.clients[client_ip]

        # Token bucket refill. time.time() can step backwards (clock
        # adjustment); never let that drain tokens.
        time_passed = max(0.0, current_time - client_data["last_update"])
        tokens_to_add = (time_passed / 60.0) * requests_per_minute
        # Allow bursting up to burst_limit tokens
        client_data["tokens"] = min(burst_limit, client_data["tokens"] + tokens_to_add)
        client_data["last_update"] = current_time

        if client_data["tokens"] >= 1.0:
            client_data["tokens"] -= 1.0
            return True, "OK"

        # Rate limit exceeded
        client_data["violations"] += 1
        banned = client_data["violations"] >= violation_threshold
        if banned:
            self.banned_clients[client_ip] = current_time + self.BAN_DURATION_SECONDS
            logger.warning(
                "IP %s banned for %s seconds due to %s rate limit violations",
                client_ip, self.BAN_DURATION_SECONDS, violation_threshold,
            )

        # Security audit log (records both plain violations and bans)
        security_audit.log_rate_limit_violation(
            client_ip=client_ip,
            endpoint=endpoint,
            violations=client_data["violations"],
            banned=banned
        )

        if banned:
            return False, "IP banned for repeated violations"

        retry_after = self._retry_after_seconds(client_data["tokens"], requests_per_minute)
        return False, f"Rate limit exceeded. Try again in {retry_after} seconds"

    def get_stats(self) -> dict:
        """Return counts of tracked clients, active ban entries and total violations."""
        return {
            "active_clients": len(self.clients),
            "banned_clients": len(self.banned_clients),
            "total_violations": sum(data["violations"] for data in self.clients.values())
        }
# Global rate limiter instance, created at import time and used by every
# dependency built by get_rate_limit_dependency(); its state lives in memory only.
rate_limiter = RateLimiter()
def get_rate_limit_dependency(endpoint_type: str = "default"):
    """
    Create a FastAPI dependency that rate-limits one endpoint type.

    Args:
        endpoint_type: Category prefix combined with the request path to form
            the endpoint name passed to ``rate_limiter.is_allowed``.

    Returns:
        A callable taking the incoming ``Request``. It returns ``True`` when
        the request is allowed and raises ``HTTPException`` (429, with a
        ``Retry-After: 60`` header) when it is not.
    """

    def rate_limit_check(request: Request) -> bool:
        # request.client is None when the server provides no peer address;
        # fall back to a shared "unknown" bucket instead of failing with an
        # AttributeError (which would surface as a 500).
        client = request.client
        client_ip = client.host if client is not None else "unknown"
        endpoint_name = f"{endpoint_type}_{request.url.path}"

        allowed, reason = rate_limiter.is_allowed(client_ip, endpoint_name)
        if not allowed:
            logger.warning(f"Rate limit exceeded for {client_ip} on {endpoint_name}: {reason}")
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded: {reason}",
                headers={"Retry-After": "60"}
            )
        return True

    return rate_limit_check
# Ready-made Depends objects, one per endpoint category, each wrapping a
# checker built by get_rate_limit_dependency() for that category.
rate_limit_auth = Depends(get_rate_limit_dependency(endpoint_type="auth"))
rate_limit_hosts = Depends(get_rate_limit_dependency(endpoint_type="hosts"))
rate_limit_containers = Depends(get_rate_limit_dependency(endpoint_type="containers"))
rate_limit_notifications = Depends(get_rate_limit_dependency(endpoint_type="notifications"))
rate_limit_default = Depends(get_rate_limit_dependency(endpoint_type="default"))