Web/tests/server_test.go

302 lines
6.7 KiB
Go

package web_test
import (
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
web "git.sharkk.net/Go/Web"
)
// testServer starts s on a free local port, waits until the port accepts TCP
// connections, runs test with the port in ":<port>" form, and finally stops
// the server by sending os.Interrupt to the current (test) process.
//
// Shutdown is deferred, so it also happens when test aborts via t.Fatal or
// panics. If s.Run returns before the port becomes reachable, the test fails
// with Run's error and no interrupt is sent.
//
// NOTE(review): shutdown assumes s.Run reacts to os.Interrupt and has set up
// that handling by the time it accepts connections; an interrupt delivered
// before that would terminate the test binary. Confirm against web.Server.
func testServer(t *testing.T, s web.Server, test func(port string)) {
	t.Helper()

	// Let the OS pick a free port, then release it for the server to claim.
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("Failed to find available port: %v", err)
	}
	port := listener.Addr().(*net.TCPAddr).Port
	portStr := fmt.Sprintf(":%d", port)
	listener.Close()

	// Buffered so the server goroutine can always deliver Run's result and
	// exit, even after we stopped waiting for it (shutdown timeout below).
	done := make(chan error, 1)
	go func() {
		done <- s.Run(portStr)
	}()

	serverExited := false
	defer func() {
		if serverExited {
			return // Run already returned; nothing to stop.
		}
		p, findErr := os.FindProcess(os.Getpid())
		if findErr != nil {
			t.Logf("Warning: could not look up own process: %v", findErr)
			return
		}
		if sigErr := p.Signal(os.Interrupt); sigErr != nil {
			t.Logf("Warning: could not send interrupt: %v", sigErr)
			return
		}
		// Wait for server to confirm shutdown with timeout.
		select {
		case <-done:
			// Server shutdown properly
		case <-time.After(100 * time.Millisecond):
			t.Log("Warning: Server didn't shut down within timeout")
		}
	}()

	// Poll until the server accepts connections instead of sleeping a fixed
	// amount, which is flaky on slow or loaded machines.
	deadline := time.Now().Add(2 * time.Second)
	for {
		select {
		case runErr := <-done:
			serverExited = true
			t.Fatalf("Server.Run(%q) returned before accepting connections: %v", portStr, runErr)
		default:
		}

		conn, dialErr := net.DialTimeout("tcp", "127.0.0.1"+portStr, 100*time.Millisecond)
		if dialErr == nil {
			conn.Close()
			break
		}
		if time.Now().After(deadline) {
			t.Fatalf("Server did not start listening on %s: %v", portStr, dialErr)
		}
		time.Sleep(5 * time.Millisecond)
	}

	// Run the test function with the assigned port.
	test(portStr)
}
// TestRun registers a GET "/" handler returning "Hello" and checks that a real
// HTTP client receives status 200 with exactly that body.
func TestRun(t *testing.T) {
	s := web.NewServer()

	// Add a handler for the root path.
	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		// A bounded client makes a stalled server fail the test instead of
		// hanging it until the global `go test` timeout.
		client := &http.Client{Timeout: 5 * time.Second}

		resp, err := client.Get("http://127.0.0.1" + port + "/")
		if err != nil {
			t.Fatalf("Error making request: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Errorf("Expected status 200, got %d", resp.StatusCode)
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("Error reading response body: %v", err)
		}
		if string(body) != "Hello" {
			t.Errorf("Expected body 'Hello', got %q", string(body))
		}
	})
}
// TestBadRequest sends a request line that is not valid HTTP and expects the
// server to reply with exactly "HTTP/1.1 400 Bad Request\r\n\r\n" and then
// close the connection.
func TestBadRequest(t *testing.T) {
	s := web.NewServer()

	testServer(t, s, func(port string) {
		conn, err := net.Dial("tcp", "127.0.0.1"+port)
		if err != nil {
			t.Fatalf("Error connecting: %v", err)
		}
		defer conn.Close()

		// io.ReadAll returns only at EOF, i.e. once the server closes the
		// connection. The deadline turns a server that never closes into a
		// test failure instead of a hang.
		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			t.Fatalf("Error setting deadline: %v", err)
		}

		if _, err := io.WriteString(conn, "BadRequest\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		response, err := io.ReadAll(conn)
		if err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		const want = "HTTP/1.1 400 Bad Request\r\n\r\n"
		if string(response) != want {
			// %q keeps the CRLFs visible in the failure output.
			t.Errorf("Expected 400 response, got: %q", string(response))
		}
	})
}
// TestBadRequestHeader sends a GET / whose headers include a line without a
// colon ("BadHeader") and expects the server to ignore it and still answer
// with a 200 status line.
func TestBadRequestHeader(t *testing.T) {
	s := web.NewServer()
	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		conn, err := net.Dial("tcp", "127.0.0.1"+port)
		if err != nil {
			t.Fatalf("Error connecting: %v", err)
		}
		defer conn.Close()

		// Bound the read so a silent server fails the test instead of hanging.
		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			t.Fatalf("Error setting deadline: %v", err)
		}

		if _, err := io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// Only the status-line prefix is checked. io.ReadFull is required
		// because a single conn.Read may legally return fewer bytes than asked.
		const wantPrefix = "HTTP/1.1 200"
		buffer := make([]byte, len(wantPrefix))
		if _, err := io.ReadFull(conn, buffer); err != nil {
			t.Fatalf("Error reading: %v", err)
		}
		if string(buffer) != wantPrefix {
			t.Errorf("Expected '200' response, got: %q", string(buffer))
		}
	})
}
// TestBadRequestMethod sends a request with the unknown method "BAD-METHOD"
// and expects the server to reply with exactly
// "HTTP/1.1 400 Bad Request\r\n\r\n" and then close the connection.
func TestBadRequestMethod(t *testing.T) {
	s := web.NewServer()

	testServer(t, s, func(port string) {
		conn, err := net.Dial("tcp", "127.0.0.1"+port)
		if err != nil {
			t.Fatalf("Error connecting: %v", err)
		}
		defer conn.Close()

		// io.ReadAll waits for the server to close the connection; the
		// deadline prevents an indefinite hang if it never does.
		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			t.Fatalf("Error setting deadline: %v", err)
		}

		if _, err := io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		response, err := io.ReadAll(conn)
		if err != nil {
			t.Fatalf("Error reading: %v", err)
		}

		const want = "HTTP/1.1 400 Bad Request\r\n\r\n"
		if string(response) != want {
			// %q keeps the CRLFs visible in the failure output.
			t.Errorf("Expected 400 response, got: %q", string(response))
		}
	})
}
// TestBadRequestProtocol sends a request line with no protocol version
// ("GET /") and expects the server to tolerate it and answer with a 200
// status line.
func TestBadRequestProtocol(t *testing.T) {
	s := web.NewServer()
	s.Get("/", func(ctx web.Context) error {
		return ctx.String("Hello")
	})

	testServer(t, s, func(port string) {
		conn, err := net.Dial("tcp", "127.0.0.1"+port)
		if err != nil {
			t.Fatalf("Error connecting: %v", err)
		}
		defer conn.Close()

		// Bound the read so a silent server fails the test instead of hanging.
		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			t.Fatalf("Error setting deadline: %v", err)
		}

		if _, err := io.WriteString(conn, "GET /\r\n\r\n"); err != nil {
			t.Fatalf("Error writing: %v", err)
		}

		// io.ReadFull guards against short reads: a single conn.Read may
		// return fewer bytes than the prefix being compared.
		const wantPrefix = "HTTP/1.1 200"
		buffer := make([]byte, len(wantPrefix))
		if _, err := io.ReadFull(conn, buffer); err != nil {
			t.Fatalf("Error reading: %v", err)
		}
		if string(buffer) != wantPrefix {
			t.Errorf("Expected '200' response, got: %q", string(buffer))
		}
	})
}
func TestEarlyClose(t *testing.T) {
s := web.NewServer()
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Fatalf("Error connecting: %v", err)
}
_, err = io.WriteString(conn, "GET /\r\n")
if err != nil {
t.Fatalf("Error writing: %v", err)
}
// Close the connection early
err = conn.Close()
if err != nil {
t.Fatalf("Error closing connection: %v", err)
}
// No assertion needed - if the server doesn't crash, the test passes
})
}
// TestUnavailablePort holds a port with a plain listener and checks that
// Server.Run on that same port returns a non-nil error within 100ms.
func TestUnavailablePort(t *testing.T) {
	// Keep this listener open for the whole test so the port stays occupied.
	occupied, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("Failed to find available port: %v", err)
	}
	addr := fmt.Sprintf(":%d", occupied.Addr().(*net.TCPAddr).Port)
	defer occupied.Close()

	s := web.NewServer()

	// Buffered so the goroutine can deliver Run's result and exit even if we
	// have stopped waiting for it.
	result := make(chan error, 1)
	go func() { result <- s.Run(addr) }()

	select {
	case runErr := <-result:
		if runErr == nil {
			t.Error("Expected error when binding to unavailable port")
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("Server.Run() didn't return with error on unavailable port")
		// Run has not returned, so try to stop it. Errors are ignored on
		// purpose: this is best-effort cleanup of an already failed test.
		proc, _ := os.FindProcess(os.Getpid())
		proc.Signal(os.Interrupt)
	}
}
// TestBodyContent registers a POST "/" handler that responds with the request
// body (as returned by ctx.Request().Body()) and checks that a body sent by a
// real HTTP client comes back unchanged with status 200.
func TestBodyContent(t *testing.T) {
	s := web.NewServer()

	// Echo the request body back as the response body.
	s.Post("/", func(ctx web.Context) error {
		body := ctx.Request().Body()
		return ctx.String(string(body))
	})

	testServer(t, s, func(port string) {
		// A bounded client makes a stalled server fail the test instead of
		// hanging it until the global `go test` timeout.
		client := &http.Client{Timeout: 5 * time.Second}

		req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1"+port+"/", strings.NewReader("Hello"))
		if err != nil {
			t.Fatalf("Error creating request: %v", err)
		}

		resp, err := client.Do(req)
		if err != nil {
			t.Fatalf("Error making request: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Errorf("Expected status 200, got %d", resp.StatusCode)
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("Error reading response body: %v", err)
		}
		if string(body) != "Hello" {
			t.Errorf("Expected body 'Hello', got %q", string(body))
		}
	})
}