Files
docker_dev/dockmon/stats-service/main.go

570 lines
16 KiB
Go

package main
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// Configuration with environment variable support
var config = struct {
TokenFilePath string
Port string
AggregationInterval time.Duration
EventCacheSize int
CleanupInterval time.Duration
MaxRequestBodySize int64
AllowedOrigins string
}{
TokenFilePath: getEnv("TOKEN_FILE_PATH", "/tmp/stats-service-token"),
Port: getEnv("STATS_SERVICE_PORT", "8081"),
AggregationInterval: getEnvDuration("AGGREGATION_INTERVAL", "1s"),
EventCacheSize: getEnvInt("EVENT_CACHE_SIZE", 100),
CleanupInterval: getEnvDuration("CLEANUP_INTERVAL", "60s"),
MaxRequestBodySize: getEnvInt64("MAX_REQUEST_BODY_SIZE", 1048576), // 1MB default
AllowedOrigins: getEnv("ALLOWED_ORIGINS", "http://localhost:8080,http://127.0.0.1:8080,http://localhost,http://127.0.0.1"),
}
// getEnv gets environment variable with fallback
func getEnv(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}
// getEnvInt gets integer environment variable with fallback
func getEnvInt(key string, fallback int) int {
if value := os.Getenv(key); value != "" {
if intVal, err := strconv.Atoi(value); err == nil {
return intVal
}
}
return fallback
}
// getEnvInt64 gets int64 environment variable with fallback
func getEnvInt64(key string, fallback int64) int64 {
if value := os.Getenv(key); value != "" {
if intVal, err := strconv.ParseInt(value, 10, 64); err == nil {
return intVal
}
}
return fallback
}
// getEnvDuration gets duration environment variable with fallback
func getEnvDuration(key string, fallback string) time.Duration {
value := getEnv(key, fallback)
if duration, err := time.ParseDuration(value); err == nil {
return duration
}
// Fallback parsing
if duration, err := time.ParseDuration(fallback); err == nil {
return duration
}
return 1 * time.Second // Ultimate fallback
}
// generateToken creates a cryptographically secure random token
func generateToken() (string, error) {
bytes := make([]byte, 32) // 256-bit token
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// limitRequestBody limits the size of request bodies
func limitRequestBody(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, config.MaxRequestBodySize)
next(w, r)
}
}
// jsonResponse writes a JSON response and handles encoding errors
func jsonResponse(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Error encoding JSON response: %v", err)
// Can't send error status if response already partially sent
}
}
// authMiddleware validates the Bearer token using constant-time comparison
func authMiddleware(token string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get Authorization header
authHeader := r.Header.Get("Authorization")
expectedAuth := "Bearer " + token
// Use constant-time comparison to prevent timing attacks
// Check length first (still constant-time for the comparison itself)
if len(authHeader) != len(expectedAuth) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
log.Printf("Unauthorized request from %s to %s (length mismatch)", r.RemoteAddr, r.URL.Path)
return
}
// Constant-time comparison of the full auth header
if subtle.ConstantTimeCompare([]byte(authHeader), []byte(expectedAuth)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
log.Printf("Unauthorized request from %s to %s", r.RemoteAddr, r.URL.Path)
return
}
next(w, r)
}
}
func main() {
log.Println("Starting DockMon Stats Service...")
// Generate random token
token, err := generateToken()
if err != nil {
log.Fatalf("Failed to generate token: %v", err)
}
// Write token to file for Python backend
if err := os.WriteFile(config.TokenFilePath, []byte(token), 0600); err != nil {
log.Fatalf("Failed to write token file: %v", err)
}
log.Printf("Generated temporary auth token for stats service")
log.Printf("Configuration: port=%s, aggregation=%v, cache_size=%d, cleanup=%v",
config.Port, config.AggregationInterval, config.EventCacheSize, config.CleanupInterval)
// Create stats cache
cache := NewStatsCache()
// Create stream manager
streamManager := NewStreamManager(cache)
// Create aggregator with configured interval
aggregator := NewAggregator(cache, streamManager, config.AggregationInterval)
// Create event management components with configured cache size
eventCache := NewEventCache(config.EventCacheSize)
eventBroadcaster := NewEventBroadcaster()
eventManager := NewEventManager(eventBroadcaster, eventCache)
// Create context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start aggregator
go aggregator.Start(ctx)
// Start cleanup routine (remove stale stats every 60 seconds)
go func() {
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
cache.CleanStaleStats(60 * time.Second)
}
}
}()
// Create HTTP server
mux := http.NewServeMux()
// Health check endpoint
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
_, totalEvents := eventCache.GetStats()
jsonResponse(w, map[string]interface{}{
"status": "ok",
"service": "dockmon-stats",
"stats_streams": streamManager.GetStreamCount(),
"event_hosts": eventManager.GetActiveHosts(),
"event_connections": eventBroadcaster.GetConnectionCount(),
"cached_events": totalEvents,
})
})
// Get all host stats (main endpoint for Python backend) - PROTECTED
mux.HandleFunc("/api/stats/hosts", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
hostStats := cache.GetAllHostStats()
json.NewEncoder(w).Encode(hostStats)
}))
// Get stats for a specific host - PROTECTED
mux.HandleFunc("/api/stats/host/", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
hostID := r.URL.Path[len("/api/stats/host/"):]
if hostID == "" {
http.Error(w, "host_id required", http.StatusBadRequest)
return
}
stats, ok := cache.GetHostStats(hostID)
if !ok {
http.NotFound(w, r)
return
}
jsonResponse(w,stats)
}))
// Get all container stats (for debugging) - PROTECTED
mux.HandleFunc("/api/stats/containers", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
containerStats := cache.GetAllContainerStats()
json.NewEncoder(w).Encode(containerStats)
}))
// Start stream for a container (called by Python backend) - PROTECTED
mux.HandleFunc("/api/streams/start", authMiddleware(token, limitRequestBody(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
ContainerID string `json:"container_id"`
ContainerName string `json:"container_name"`
HostID string `json:"host_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.ContainerID == "" || req.HostID == "" {
http.Error(w, "container_id and host_id are required", http.StatusBadRequest)
return
}
if err := streamManager.StartStream(ctx, req.ContainerID, req.ContainerName, req.HostID); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "started"})
})))
// Stop stream for a container - PROTECTED
mux.HandleFunc("/api/streams/stop", authMiddleware(token, limitRequestBody(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
ContainerID string `json:"container_id"`
HostID string `json:"host_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.ContainerID == "" || req.HostID == "" {
http.Error(w, "container_id and host_id are required", http.StatusBadRequest)
return
}
streamManager.StopStream(req.ContainerID, req.HostID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
})))
// Add Docker host - PROTECTED
mux.HandleFunc("/api/hosts/add", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
HostID string `json:"host_id"`
HostAddress string `json:"host_address"`
TLSCACert string `json:"tls_ca_cert,omitempty"`
TLSCert string `json:"tls_cert,omitempty"`
TLSKey string `json:"tls_key,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.HostID == "" || req.HostAddress == "" {
http.Error(w, "host_id and host_address are required", http.StatusBadRequest)
return
}
if err := streamManager.AddDockerHost(req.HostID, req.HostAddress, req.TLSCACert, req.TLSCert, req.TLSKey); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "added"})
}))
// Remove Docker host - PROTECTED
mux.HandleFunc("/api/hosts/remove", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
HostID string `json:"host_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.HostID == "" {
http.Error(w, "host_id is required", http.StatusBadRequest)
return
}
streamManager.RemoveDockerHost(req.HostID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "removed"})
}))
// Debug endpoint - PROTECTED
mux.HandleFunc("/debug/stats", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
containerCount, hostCount := cache.GetStats()
jsonResponse(w, map[string]interface{}{
"streams": streamManager.GetStreamCount(),
"containers": containerCount,
"hosts": hostCount,
})
}))
// === Event Monitoring Endpoints ===
// Start monitoring events for a host - PROTECTED
mux.HandleFunc("/api/events/hosts/add", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
HostID string `json:"host_id"`
HostAddress string `json:"host_address"`
TLSCACert string `json:"tls_ca_cert,omitempty"`
TLSCert string `json:"tls_cert,omitempty"`
TLSKey string `json:"tls_key,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.HostID == "" || req.HostAddress == "" {
http.Error(w, "host_id and host_address are required", http.StatusBadRequest)
return
}
if err := eventManager.AddHost(req.HostID, req.HostAddress, req.TLSCACert, req.TLSCert, req.TLSKey); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "started"})
}))
// Stop monitoring events for a host - PROTECTED
mux.HandleFunc("/api/events/hosts/remove", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
HostID string `json:"host_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Validate required fields
if req.HostID == "" {
http.Error(w, "host_id is required", http.StatusBadRequest)
return
}
eventManager.RemoveHost(req.HostID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "stopped"})
}))
// Get recent events - PROTECTED
mux.HandleFunc("/api/events/recent", authMiddleware(token, func(w http.ResponseWriter, r *http.Request) {
hostID := r.URL.Query().Get("host_id")
var events interface{}
if hostID != "" {
// Get events for specific host
events = eventCache.GetRecentEvents(hostID, 50)
} else {
// Get events for all hosts
events = eventCache.GetAllRecentEvents(50)
}
jsonResponse(w,events)
}))
// WebSocket endpoint for event streaming - PROTECTED
mux.HandleFunc("/ws/events", func(w http.ResponseWriter, r *http.Request) {
// Validate token from query parameter or header
tokenParam := r.URL.Query().Get("token")
authHeader := r.Header.Get("Authorization")
validToken := false
if tokenParam == token {
validToken = true
} else if authHeader == "Bearer "+token {
validToken = true
}
if !validToken {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
log.Printf("Unauthorized WebSocket connection attempt from %s", r.RemoteAddr)
return
}
// Upgrade to WebSocket
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return true // Allow same-origin requests
}
// Check against configured allowed origins
allowedOrigins := strings.Split(config.AllowedOrigins, ",")
for _, allowed := range allowedOrigins {
if strings.TrimSpace(allowed) == origin {
return true
}
}
return false
},
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("WebSocket upgrade failed: %v", err)
return
}
// Register connection
if err := eventBroadcaster.AddConnection(conn); err != nil {
log.Printf("Failed to register connection: %v", err)
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Connection limit reached"))
conn.Close()
return
}
// Handle connection (read loop to detect disconnect)
go func() {
defer func() {
eventBroadcaster.RemoveConnection(conn)
conn.Close()
}()
for {
// Read messages (just to detect disconnect, we don't expect any)
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}()
})
// Create server with configured port
srv := &http.Server{
Addr: ":" + config.Port,
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 60 * time.Second,
}
// Start server in goroutine
go func() {
log.Printf("Stats service listening on %s", srv.Addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err)
}
}()
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Println("Shutting down stats service...")
// Graceful shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
// Stop all stats streams
streamManager.StopAllStreams()
// Stop all event monitoring
eventManager.StopAll()
// Close all event WebSocket connections
eventBroadcaster.CloseAll()
// Stop HTTP server
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Server shutdown error: %v", err)
}
// Cancel context to stop aggregator
cancel()
// Clean up token file
if err := os.Remove(config.TokenFilePath); err != nil {
log.Printf("Warning: Failed to remove token file: %v", err)
} else {
log.Println("Removed token file")
}
log.Println("Stats service stopped")
}