package web_test

import (
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"strings"
	"testing"
	"time"

	web "git.sharkk.net/Go/Web"
)

const (
	// startupDelay is how long testServer waits after launching the server
	// goroutine before running the test body.
	startupDelay = 50 * time.Millisecond

	// runTimeout bounds how long a test waits for Server.Run to return.
	runTimeout = 100 * time.Millisecond

	// ioTimeout bounds every client-side network operation so that a
	// misbehaving server fails the test instead of hanging it.
	ioTimeout = 5 * time.Second
)

// reservePort asks the OS for an unused TCP port. It returns the still-open
// listener together with the address in ":port" form; the caller decides
// whether to release the port or keep it occupied.
func reservePort(t *testing.T) (net.Listener, string) {
	t.Helper()

	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("Failed to find available port: %v", err)
	}
	port := listener.Addr().(*net.TCPAddr).Port
	return listener, fmt.Sprintf(":%d", port)
}

// interruptSelf sends os.Interrupt to the current process, which is how these
// tests ask a running server to stop. Failures are logged rather than treated
// as test failures, so an undeliverable signal stays best-effort.
func interruptSelf(t *testing.T) {
	t.Helper()

	p, err := os.FindProcess(os.Getpid())
	if err != nil {
		t.Logf("Warning: could not look up own process: %v", err)
		return
	}
	if err := p.Signal(os.Interrupt); err != nil {
		t.Logf("Warning: could not send interrupt: %v", err)
	}
}

// dialServer opens a TCP connection to 127.0.0.1 on the given ":port" address
// and puts an overall deadline on it, so later reads and writes cannot block
// forever. The caller owns the connection and must close it.
func dialServer(t *testing.T, port string) net.Conn {
	t.Helper()

	conn, err := net.DialTimeout("tcp", "127.0.0.1"+port, ioTimeout)
	if err != nil {
		t.Fatalf("Error connecting: %v", err)
	}
	if err := conn.SetDeadline(time.Now().Add(ioTimeout)); err != nil {
		conn.Close()
		t.Fatalf("Error setting deadline: %v", err)
	}
	return conn
}

// testServer starts s on an OS-assigned port, runs test with the ":port"
// address, and then interrupts the process so the server stops.
//
// Shutdown is deferred, so it still happens when test ends early through
// t.Fatalf. If the server does not stop within runTimeout a warning is logged.
func testServer(t *testing.T, s web.Server, test func(port string)) {
	t.Helper()

	// Get a free port by letting the OS assign one, then release it so the
	// server can bind to it.
	listener, portStr := reservePort(t)
	if err := listener.Close(); err != nil {
		t.Fatalf("Failed to release probe listener: %v", err)
	}

	ready := make(chan struct{})
	// Buffered so the goroutine can always deliver Run's result and exit,
	// even if nobody is waiting any more after a shutdown timeout.
	done := make(chan error, 1)
	go func() {
		// Signal that we're about to start the server.
		close(ready)
		// Run blocks until the server is interrupted or fails to start.
		done <- s.Run(portStr)
	}()

	defer func() {
		// If Run has already returned there is nothing left to interrupt,
		// and signalling the process without a running server could
		// terminate the test binary.
		select {
		case err := <-done:
			t.Logf("Warning: Server exited before shutdown was requested: %v", err)
			return
		default:
		}

		interruptSelf(t)

		select {
		case err := <-done:
			if err != nil {
				t.Logf("Server exited with error: %v", err)
			}
		case <-time.After(runTimeout):
			t.Log("Warning: Server didn't shut down within timeout")
		}
	}()

	// Wait for the goroutine to be scheduled, then give the server a moment
	// to actually start listening.
	<-ready
	time.Sleep(startupDelay)

	test(portStr)
}

// TestRun checks that a GET on a registered route returns status 200 and the
// handler's body.
func TestRun(t *testing.T) {
	s := web.NewServer()

	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		client := &http.Client{Timeout: ioTimeout}

		resp, err := client.Get("http://127.0.0.1" + port + "/")
		if err != nil {
			t.Fatalf("Error making request: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != 200 {
			t.Errorf("Expected status 200, got %d", resp.StatusCode)
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("Error reading response body: %v", err)
		}

		if string(body) != "Hello" {
			t.Errorf("Expected body 'Hello', got '%s'", string(body))
		}
	})
}

// TestBadRequest sends a line that is not an HTTP request and expects a bare
// 400 response followed by the server closing the connection.
func TestBadRequest(t *testing.T) {
	s := web.NewServer()

	testServer(t, s, func(port string) {
		conn := dialServer(t, port)
		defer conn.Close()

		if _, err := io.WriteString(conn, "BadRequest\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// ReadAll returns once the server closes the connection, or fails
		// when the deadline set by dialServer expires.
		response, err := io.ReadAll(conn)
		if err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
			t.Errorf("Expected 400 response, got: %s", string(response))
		}
	})
}

// TestBadRequestHeader sends a request containing one malformed header line
// and expects the response to start with a 200 status.
func TestBadRequestHeader(t *testing.T) {
	s := web.NewServer()

	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		conn := dialServer(t, port)
		defer conn.Close()

		_, err := io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n")
		if err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// ReadFull keeps reading until the whole prefix has arrived; a
		// single Read may legally return fewer bytes.
		buffer := make([]byte, len("HTTP/1.1 200"))
		if _, err := io.ReadFull(conn, buffer); err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		if string(buffer) != "HTTP/1.1 200" {
			t.Errorf("Expected '200' response, got: %s", string(buffer))
		}
	})
}

// TestBadRequestMethod sends a request with an unknown method and expects a
// bare 400 response followed by the server closing the connection.
func TestBadRequestMethod(t *testing.T) {
	s := web.NewServer()

	testServer(t, s, func(port string) {
		conn := dialServer(t, port)
		defer conn.Close()

		if _, err := io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// ReadAll returns once the server closes the connection, or fails
		// when the deadline set by dialServer expires.
		response, err := io.ReadAll(conn)
		if err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
			t.Errorf("Expected 400 response, got: %s", string(response))
		}
	})
}

// TestBadRequestProtocol sends a request line without a protocol version and
// expects the response to start with a 200 status.
func TestBadRequestProtocol(t *testing.T) {
	s := web.NewServer()

	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		conn := dialServer(t, port)
		defer conn.Close()

		if _, err := io.WriteString(conn, "GET /\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// ReadFull keeps reading until the whole prefix has arrived; a
		// single Read may legally return fewer bytes.
		buffer := make([]byte, len("HTTP/1.1 200"))
		if _, err := io.ReadFull(conn, buffer); err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		if string(buffer) != "HTTP/1.1 200" {
			t.Errorf("Expected '200' response, got: %s", string(buffer))
		}
	})
}

// TestEarlyClose writes an incomplete request and closes the connection
// before the header section is terminated.
func TestEarlyClose(t *testing.T) {
	s := web.NewServer()

	testServer(t, s, func(port string) {
		conn := dialServer(t, port)

		if _, err := io.WriteString(conn, "GET /\r\n"); err != nil {
			conn.Close()
			t.Fatalf("Error writing: %v", err)
		}

		// Close the connection early.
		if err := conn.Close(); err != nil {
			t.Fatalf("Error closing connection: %v", err)
		}

		// No assertion needed - if the server doesn't crash, the test passes.
	})
}

// TestUnavailablePort keeps a port occupied and expects Run to return a
// non-nil error promptly when asked to listen on it.
func TestUnavailablePort(t *testing.T) {
	// Get a port and keep it occupied for the duration of the test.
	listener, portStr := reservePort(t)
	defer listener.Close()

	s := web.NewServer()

	// Buffered so the goroutine can finish even if the test has timed out.
	errChan := make(chan error, 1)
	go func() {
		errChan <- s.Run(portStr)
	}()

	select {
	case err := <-errChan:
		if err == nil {
			t.Error("Expected error when binding to unavailable port")
		}
	case <-time.After(runTimeout):
		t.Error("Server.Run() didn't return with error on unavailable port")
		// Try to stop the server anyway.
		interruptSelf(t)
	}
}

// TestBodyContent posts a body to a handler that echoes it and expects the
// same bytes back with status 200.
func TestBodyContent(t *testing.T) {
	s := web.NewServer()

	s.Post("/", func(ctx web.Context) error {
		body := ctx.Request().Body()
		return ctx.String(string(body))
	})

	testServer(t, s, func(port string) {
		client := &http.Client{Timeout: ioTimeout}

		req, err := http.NewRequest("POST", "http://127.0.0.1"+port+"/", strings.NewReader("Hello"))
		if err != nil {
			t.Fatalf("Error creating request: %v", err)
		}

		resp, err := client.Do(req)
		if err != nil {
			t.Fatalf("Error making request: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != 200 {
			t.Errorf("Expected status 200, got %d", resp.StatusCode)
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("Error reading response body: %v", err)
		}

		if string(body) != "Hello" {
			t.Errorf("Expected body 'Hello', got '%s'", string(body))
		}
	})
}