diff --git a/src/async_test.go b/src/async_test.go index bfb77b2..67e0799 100644 --- a/src/async_test.go +++ b/src/async_test.go @@ -35,8 +35,16 @@ func TestAsyncProcessorEnqueueAndProcess(t *testing.T) { t.Fatal("timed out waiting for async processing") } - if got := testutil.ToFloat64(asyncProcessedEventsCounter.WithLabelValues("workflow_run")); got != 1 { - t.Fatalf("expected processed counter to be 1, got %v", got) + deadline := time.Now().Add(2 * time.Second) + for { + if got := testutil.ToFloat64(asyncProcessedEventsCounter.WithLabelValues("workflow_run")); got == 1 { + break + } + if time.Now().After(deadline) { + got := testutil.ToFloat64(asyncProcessedEventsCounter.WithLabelValues("workflow_run")) + t.Fatalf("expected processed counter to be 1, got %v", got) + } + time.Sleep(10 * time.Millisecond) } } diff --git a/src/main.go b/src/main.go index 772a1b9..95d0e56 100644 --- a/src/main.go +++ b/src/main.go @@ -1,11 +1,17 @@ package main import ( + "bufio" + "context" "encoding/json" + "errors" + "net" "net/http" "net/http/pprof" "os" + "os/signal" "strings" + "syscall" "time" "github.com/gorilla/mux" @@ -22,7 +28,13 @@ type HealthCheckResposne struct { type statusRecorder struct { http.ResponseWriter - status int + status int + wroteHeader bool +} + +type serviceMetrics struct { + apiCallsCounter *prometheus.CounterVec + requestDurationHistogram *prometheus.HistogramVec } var ( @@ -32,60 +44,112 @@ var ( enableDebug string // Compile time flag to enable debug mode debug bool - apiCallsCounter = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: "promgithub_api_calls_total", - Help: "Number of API calls", - }, - []string{"status", "method", "path"}, - ) - - requestDurationHistogram = promauto.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "promgithub_request_duration_seconds", - Help: "Request duration in seconds", - Buckets: prometheus.DefBuckets, - }, - []string{"path", "method"}, - ) + defaultServiceMetrics = newServiceMetrics(prometheus.DefaultRegisterer) ) -func apiHandler(logger *zap.Logger) func(http.Handler) http.Handler { +func newServiceMetrics(registerer prometheus.Registerer) *serviceMetrics { + factory := promauto.With(registerer) + + return &serviceMetrics{ + apiCallsCounter: factory.NewCounterVec( + prometheus.CounterOpts{ + Name: "promgithub_api_calls_total", + Help: "Number of API calls", + }, + []string{"status", "method", "path"}, + ), + requestDurationHistogram: factory.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "promgithub_request_duration_seconds", + Help: "Request duration in seconds", + Buckets: prometheus.DefBuckets, + }, + []string{"path", "method"}, + ), + } +} + +func (r *statusRecorder) WriteHeader(status int) { + r.status = status + r.wroteHeader = true + r.ResponseWriter.WriteHeader(status) +} + +func (r *statusRecorder) Write(body []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.ResponseWriter.Write(body) +} + +func (r *statusRecorder) Flush() { + if flusher, ok := r.ResponseWriter.(http.Flusher); ok { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + flusher.Flush() + } +} + +func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("response writer does not support hijacking") + } + return hijacker.Hijack() +} + +func (r *statusRecorder) Push(target string, opts *http.PushOptions) error { + pusher, ok := r.ResponseWriter.(http.Pusher) + if !ok { + return http.ErrNotSupported + } + return pusher.Push(target, opts) +} + +func (r *statusRecorder) Unwrap() http.ResponseWriter { + return r.ResponseWriter +} + +func apiHandler(logger *zap.Logger, metrics *serviceMetrics) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - - rec := statusRecorder{ResponseWriter: w, status: 200} - - logger.Info("Received request", - zap.String("method", r.Method), - zap.String("path", r.URL.Path), - zap.String("remoteAddr", r.RemoteAddr), - zap.String("userAgent", r.UserAgent()), - ) + rec := statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(&rec, r) duration := time.Since(start).Seconds() + statusText := http.StatusText(rec.status) + if statusText == "" { + statusText = "UNKNOWN" + } + + metrics.apiCallsCounter.WithLabelValues(statusText, r.Method, r.URL.Path).Inc() + metrics.requestDurationHistogram.WithLabelValues(r.URL.Path, r.Method).Observe(duration) - apiCallsCounter.With(prometheus.Labels{ - "status": http.StatusText(rec.status), - "method": r.Method, - "path": r.URL.Path, - }).Inc() + fields := []zap.Field{ + zap.String("method", r.Method), + zap.String("path", r.URL.Path), + zap.Int("status", rec.status), + zap.Float64("durationSeconds", duration), + } - requestDurationHistogram.With(prometheus.Labels{ - "path": r.URL.Path, - "method": r.Method, - }).Observe(duration) + switch { + case rec.status >= http.StatusInternalServerError: + logger.Error("Request completed", fields...) + case rec.status >= http.StatusBadRequest: + logger.Warn("Request completed", fields...) + default: + logger.Debug("Request completed", fields...) + } }) } } func healthCheck(w http.ResponseWriter, _ *http.Request) { response := HealthCheckResposne{Status: "ok", Version: Version} - err := json.NewEncoder(w).Encode(response) - if err != nil { + if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -108,22 +172,19 @@ func init() { defer func() { if err := logger.Sync(); err != nil { - // Logger sync errors on program exit are typically not critical - // and often occur when stdout/stderr are closed before sync - _ = err // Explicitly ignore the error + _ = err } }() } -func setupRouter(logger *zap.Logger) *mux.Router { +func setupRouter(logger *zap.Logger, metrics *serviceMetrics, gatherer prometheus.Gatherer) *mux.Router { r := mux.NewRouter() - r.Use(apiHandler(logger)) + r.Use(apiHandler(logger, metrics)) r.HandleFunc("/health", healthCheck).Methods("GET") - r.Handle("/metrics", promhttp.Handler()) + r.Handle("/metrics", promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{})) r.HandleFunc("/webhook", githubEventsHandler).Methods("POST") - // Profiling endpoints if debug { r.HandleFunc("/debug/pprof/", pprof.Index) r.HandleFunc("/debug/pprof/allocs", pprof.Handler("allocs").ServeHTTP) @@ -141,7 +202,37 @@ func setupRouter(logger *zap.Logger) *mux.Router { return r } +func runServer(ctx context.Context, server *http.Server, logger *zap.Logger) error { + errCh := make(chan error, 1) + go func() { + err := server.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + return + } + errCh <- nil + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + logger.Info("Shutting down server") + if err := server.Shutdown(shutdownCtx); err != nil { + return err + } + + return <-errCh + } +} + func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + port := strings.TrimSpace(os.Getenv("PROMGITHUB_SERVICE_PORT")) if port == "" { port = "8080" @@ -182,19 +273,18 @@ func main() { zap.Int("queueSize", asyncConfig.QueueSize), ) - r := setupRouter(logger) - + r := setupRouter(logger, defaultServiceMetrics, prometheus.DefaultGatherer) server := &http.Server{ Addr: ":" + port, Handler: r, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, - MaxHeaderBytes: 1 << 20, // 1 MB + MaxHeaderBytes: 1 << 20, } logger.Info("Starting server", zap.String("port", port)) - if err := server.ListenAndServe(); err != nil { - logger.Fatal("Error starting server", zap.Error(err)) + if err := runServer(ctx, server, logger); err != nil { + logger.Fatal("Server exited with error", zap.Error(err)) } } diff --git a/src/main_test.go b/src/main_test.go index 03e0c88..08d61b1 100644 --- a/src/main_test.go +++ b/src/main_test.go @@ -1,7 +1,9 @@ package main import ( + "context" "encoding/json" + "net" "net/http" "net/http/httptest" "os" @@ -14,39 +16,29 @@ import ( "go.uber.org/zap" ) -var ( - reg *prometheus.Registry -) +var reg *prometheus.Registry func init() { - // Disable logging logger = zap.NewNop() reg = prometheus.NewRegistry() } func TestHealthCheck(t *testing.T) { - // Set the Version variable for the test Version = "1.0.0" - // Create a test HTTP request req, err := http.NewRequest("GET", "/health", nil) if err != nil { t.Fatalf("Failed to create HTTP request: %v", err) } - // Create a test HTTP response recorder rr := httptest.NewRecorder() - - // Call the healthCheck handler handler := http.HandlerFunc(healthCheck) handler.ServeHTTP(rr, req) - // Verify the response status code if status := rr.Code; status != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, status) } - // Verify the response body expectedResponse := HealthCheckResposne{Status: "ok", Version: Version} var actualResponse HealthCheckResposne if err := json.NewDecoder(rr.Body).Decode(&actualResponse); err != nil { @@ -59,21 +51,16 @@ func TestHealthCheck(t *testing.T) { } func TestSetupRouter(t *testing.T) { - // Set environment variables for the test _ = os.Setenv("PROMGITHUB_WEBHOOK_SECRET", "testsecret") defer func() { _ = os.Unsetenv("PROMGITHUB_WEBHOOK_SECRET") }() - // Initialize the logger - logger := zap.NewNop() - - // Set up the router - r := setupRouter(logger) + registry := prometheus.NewRegistry() + metrics := newServiceMetrics(registry) + r := setupRouter(zap.NewNop(), metrics, registry) - // Create a test HTTP server server := httptest.NewServer(r) defer server.Close() - // Test the /health endpoint resp, err := http.Get(server.URL + "/health") if err != nil { t.Fatalf("Failed to send HTTP request: %v", err) @@ -84,7 +71,6 @@ func TestSetupRouter(t *testing.T) { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } - // Test the /metrics endpoint resp, err = http.Get(server.URL + "/metrics") if err != nil { t.Fatalf("Failed to send HTTP request: %v", err) @@ -96,42 +82,61 @@ func TestSetupRouter(t *testing.T) { } } -func TestApiHandler(t *testing.T) { - apiCallsCounter.Reset() - reg.MustRegister(apiCallsCounter) +func TestApiHandlerRecordsExplicitStatusCode(t *testing.T) { + registry := prometheus.NewRegistry() + metrics := newServiceMetrics(registry) - // Create a test HTTP handler testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("OK")); err != nil { - t.Errorf("Failed to write response: %v", err) - } + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("created")) }) - // Wrap the test handler with the apiHandler middleware - handler := apiHandler(logger)(testHandler) - - // Create a test HTTP server + handler := apiHandler(zap.NewNop(), metrics)(testHandler) server := httptest.NewServer(handler) defer server.Close() - // Create a test HTTP client - client := &http.Client{Timeout: 10 * time.Second} + resp, err := (&http.Client{Timeout: 10 * time.Second}).Get(server.URL) + if err != nil { + t.Fatalf("Failed to send HTTP request: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("Expected status code %d, got %d", http.StatusCreated, resp.StatusCode) + } + + if err := testutil.CollectAndCompare(metrics.apiCallsCounter, strings.NewReader(` + # HELP promgithub_api_calls_total Number of API calls + # TYPE promgithub_api_calls_total counter + promgithub_api_calls_total{method="GET",path="/",status="Created"} 1 + `)); err != nil { + t.Errorf("unexpected metrics: %v", err) + } +} - // Send a test HTTP request - resp, err := client.Get(server.URL) +func TestApiHandlerRecordsImplicitStatusCode(t *testing.T) { + registry := prometheus.NewRegistry() + metrics := newServiceMetrics(registry) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + + handler := apiHandler(zap.NewNop(), metrics)(testHandler) + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := (&http.Client{Timeout: 10 * time.Second}).Get(server.URL) if err != nil { t.Fatalf("Failed to send HTTP request: %v", err) } defer func() { _ = resp.Body.Close() }() - // Verify the response status code if resp.StatusCode != http.StatusOK { - t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + t.Fatalf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } - // Verify the Prometheus metrics - if err := testutil.CollectAndCompare(apiCallsCounter, strings.NewReader(` + if err := testutil.CollectAndCompare(metrics.apiCallsCounter, strings.NewReader(` # HELP promgithub_api_calls_total Number of API calls # TYPE promgithub_api_calls_total counter promgithub_api_calls_total{method="GET",path="/",status="OK"} 1 @@ -139,3 +144,67 @@ func TestApiHandler(t *testing.T) { t.Errorf("unexpected metrics: %v", err) } } + +func TestRunServerShutsDownOnContextCancel(t *testing.T) { + registry := prometheus.NewRegistry() + metrics := newServiceMetrics(registry) + router := setupRouter(zap.NewNop(), metrics, registry) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer func() { _ = listener.Close() }() + + server := &http.Server{ + Handler: router, + ReadHeaderTimeout: 2 * time.Second, + } + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error, 1) + go func() { + errCh <- runServerWithListener(ctx, server, listener) + }() + + resp, err := http.Get("http://" + listener.Addr().String() + "/health") + if err != nil { + t.Fatalf("Failed to send HTTP request: %v", err) + } + _ = resp.Body.Close() + + cancel() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Expected graceful shutdown, got error: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("Timed out waiting for server shutdown") + } +} + +func runServerWithListener(ctx context.Context, server *http.Server, listener net.Listener) error { + errCh := make(chan error, 1) + go func() { + err := server.Serve(listener) + if err != nil && err != http.ErrServerClosed { + errCh <- err + return + } + errCh <- nil + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + return err + } + return <-errCh + } +}