This commit is contained in:
Sky Johnson 2025-03-05 12:16:23 -06:00
parent 85f04865b5
commit 3bc0920f09
5 changed files with 863 additions and 94 deletions

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2024 Go
Copyright (c) 2024 Sharkk
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

526
bench/real_bench_test.go Normal file
View File

@ -0,0 +1,526 @@
package bench
import (
"bytes"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"
web "git.sharkk.net/Go/Web"
)
// startTestServer starts a server for benchmarking on a random port
func startTestServer(b *testing.B) (string, web.Server) {
s := web.NewServer()
// Setup routes
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
s.Get("/json", func(ctx web.Context) error {
ctx.Response().SetHeader("Content-Type", "application/json")
return ctx.String(`{"message":"Hello, World!","code":200,"success":true}`)
})
s.Post("/echo", func(ctx web.Context) error {
body := ctx.Request().Body()
return ctx.Bytes(body)
})
s.Get("/users/:id/posts/:postId", func(ctx web.Context) error {
userId := ctx.Request().Param("id")
postId := ctx.Request().Param("postId")
return ctx.String(userId + ":" + postId)
})
s.Get("/middleware-test", func(ctx web.Context) error {
return ctx.String("OK")
})
s.Get("/headers", func(ctx web.Context) error {
ctx.Response().SetHeader("X-Test-1", "Value1")
ctx.Response().SetHeader("X-Test-2", "Value2")
ctx.Response().SetHeader("X-Test-3", "Value3")
ctx.Response().SetHeader("Content-Type", "text/plain")
return ctx.String("Headers set")
})
s.Post("/submit", func(ctx web.Context) error {
return ctx.String("Received " + string(ctx.Request().Body()))
})
// Add middleware for middleware test
for i := 0; i < 5; i++ {
s.Use(func(ctx web.Context) error {
return ctx.Next()
})
}
// Find a free port
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
b.Fatalf("Failed to resolve TCP address: %v", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
b.Fatalf("Failed to listen on TCP: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
serverAddr := "localhost:" + strconv.Itoa(port)
// Start the server in a goroutine
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
s.Run(serverAddr)
}()
// Wait for server to start
wg.Wait()
time.Sleep(100 * time.Millisecond)
return serverAddr, s
}
// BenchmarkRealStaticGetRequest measures performance of real HTTP GET request
func BenchmarkRealStaticGetRequest(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealJSONResponse measures performance of real HTTP JSON response
func BenchmarkRealJSONResponse(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/json"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid status: %d", resp.StatusCode)
}
}
}
// BenchmarkRealPostWithBody measures performance of real HTTP POST request with body
func BenchmarkRealPostWithBody(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/echo"
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Post(url, "text/plain", strings.NewReader(requestBody))
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealRouteParams measures performance of real HTTP route parameter handling
func BenchmarkRealRouteParams(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/users/123/posts/456"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealHeaders measures real HTTP header handling performance
func BenchmarkRealHeaders(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/headers"
client := &http.Client{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
req, _ := http.NewRequest("GET", url, nil)
req.Header.Add("User-Agent", "Benchmark")
req.Header.Add("Accept", "*/*")
req.Header.Add("Authorization", "Bearer token12345")
resp, err := client.Do(req)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || resp.Header.Get("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
}
}
// BenchmarkRealParallelRequests measures real HTTP performance under concurrent load
func BenchmarkRealParallelRequests(b *testing.B) {
serverAddr, _ := startTestServer(b)
baseURL := "http://" + serverAddr
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
client := &http.Client{}
counter := 0
for pb.Next() {
var resp *http.Response
var err error
switch counter % 3 {
case 0:
resp, err = client.Get(baseURL + "/")
case 1:
id := strconv.Itoa(counter & 0xff)
resp, err = client.Get(baseURL + "/users/" + id + "/posts/456")
case 2:
data := "data" + strconv.Itoa(counter&0xff)
resp, err = client.Post(baseURL+"/submit", "text/plain", bytes.NewBufferString(data))
}
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, _ = io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
counter++
}
})
}
// startNetHTTPTestServer starts a standard net/http server for benchmarking
func startNetHTTPTestServer(b *testing.B) (string, *http.Server) {
mux := http.NewServeMux()
// Setup routes
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Write([]byte("Hello, World!"))
})
mux.HandleFunc("/json", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"message":"Hello, World!","code":200,"success":true}`))
})
mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(r.Body)
w.Write(body)
})
mux.HandleFunc("/users/", func(w http.ResponseWriter, r *http.Request) {
parts := strings.Split(r.URL.Path, "/")
if len(parts) != 5 || parts[1] != "users" || parts[3] != "posts" || parts[0] != "" {
http.NotFound(w, r)
return
}
userId := parts[2]
postId := parts[4]
w.Write([]byte(userId + ":" + postId))
})
mux.HandleFunc("/middleware-test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
})
mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test-1", "Value1")
w.Header().Set("X-Test-2", "Value2")
w.Header().Set("X-Test-3", "Value3")
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte("Headers set"))
})
mux.HandleFunc("/submit", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(r.Body)
w.Write([]byte("Received " + string(body)))
})
// Find a free port
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
b.Fatalf("Failed to resolve TCP address: %v", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
b.Fatalf("Failed to listen on TCP: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
serverAddr := "localhost:" + strconv.Itoa(port)
// Start the server in a goroutine
server := &http.Server{
Addr: serverAddr,
Handler: mux,
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
server.Serve(listener)
}()
// Wait for server to start
wg.Wait()
time.Sleep(100 * time.Millisecond)
return serverAddr, server
}
// BenchmarkNetHTTPStaticGetRequest measures performance of net/http static GET request
func BenchmarkNetHTTPStaticGetRequest(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPJSONResponse measures performance of net/http JSON response
func BenchmarkNetHTTPJSONResponse(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/json"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid status: %d", resp.StatusCode)
}
}
}
// BenchmarkNetHTTPPostWithBody measures performance of net/http POST request with body
func BenchmarkNetHTTPPostWithBody(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/echo"
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Post(url, "text/plain", strings.NewReader(requestBody))
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPRouteParams measures performance of net/http route parameter handling
func BenchmarkNetHTTPRouteParams(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/users/123/posts/456"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPHeaders measures net/http header handling performance
func BenchmarkNetHTTPHeaders(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/headers"
client := &http.Client{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
req, _ := http.NewRequest("GET", url, nil)
req.Header.Add("User-Agent", "Benchmark")
req.Header.Add("Accept", "*/*")
req.Header.Add("Authorization", "Bearer token12345")
resp, err := client.Do(req)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || resp.Header.Get("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
}
}
// BenchmarkNetHTTPParallelRequests measures net/http performance under concurrent load
func BenchmarkNetHTTPParallelRequests(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
baseURL := "http://" + serverAddr
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
client := &http.Client{}
counter := 0
for pb.Next() {
var resp *http.Response
var err error
switch counter % 3 {
case 0:
resp, err = client.Get(baseURL + "/")
case 1:
id := strconv.Itoa(counter & 0xff)
resp, err = client.Get(baseURL + "/users/" + id + "/posts/456")
case 2:
data := "data" + strconv.Itoa(counter&0xff)
resp, err = client.Post(baseURL+"/submit", "text/plain", bytes.NewBufferString(data))
}
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, _ = io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
counter++
}
})
}

View File

@ -0,0 +1,165 @@
package bench
import (
"io"
"strconv"
"strings"
"testing"
web "git.sharkk.net/Go/Web"
)
// BenchmarkSyntheticStaticGetRequest measures performance of simple GET request with static response
func BenchmarkSyntheticStaticGetRequest(b *testing.B) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/", nil, nil)
if response.Status() != 200 || string(response.Body()) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticJSONResponse measures performance of JSON serialization and response
func BenchmarkSyntheticJSONResponse(b *testing.B) {
s := web.NewServer()
s.Get("/json", func(ctx web.Context) error {
ctx.Response().SetHeader("Content-Type", "application/json")
return ctx.String(`{"message":"Hello, World!","code":200,"success":true}`)
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/json", nil, nil)
if response.Status() != 200 {
b.Fatalf("Invalid status: %d", response.Status())
}
}
}
// BenchmarkSyntheticPostWithBody measures performance of POST request with body processing
func BenchmarkSyntheticPostWithBody(b *testing.B) {
s := web.NewServer()
s.Post("/echo", func(ctx web.Context) error {
body := ctx.Request().Body()
return ctx.Bytes(body)
})
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("POST", "/echo", nil, io.NopCloser(strings.NewReader(requestBody)))
if response.Status() != 200 || string(response.Body()) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticRouteParams measures performance of route parameter extraction
func BenchmarkSyntheticRouteParams(b *testing.B) {
s := web.NewServer()
s.Get("/users/:id/posts/:postId", func(ctx web.Context) error {
userId := ctx.Request().Param("id")
postId := ctx.Request().Param("postId")
return ctx.String(userId + ":" + postId)
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/users/123/posts/456", nil, nil)
if response.Status() != 200 || string(response.Body()) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticMiddleware measures the impact of middleware chain
func BenchmarkSyntheticMiddleware(b *testing.B) {
s := web.NewServer()
// Add 5 middleware functions
for i := 0; i < 5; i++ {
s.Use(func(ctx web.Context) error {
return ctx.Next()
})
}
s.Get("/middleware-test", func(ctx web.Context) error {
return ctx.String("OK")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/middleware-test", nil, nil)
if response.Status() != 200 || string(response.Body()) != "OK" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticHeaders measures header handling performance
func BenchmarkSyntheticHeaders(b *testing.B) {
s := web.NewServer()
s.Get("/headers", func(ctx web.Context) error {
ctx.Response().SetHeader("X-Test-1", "Value1")
ctx.Response().SetHeader("X-Test-2", "Value2")
ctx.Response().SetHeader("X-Test-3", "Value3")
ctx.Response().SetHeader("Content-Type", "text/plain")
return ctx.String("Headers set")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/headers", []web.Header{
{"User-Agent", "Benchmark"},
{"Accept", "*/*"},
{"Authorization", "Bearer token12345"},
}, nil)
if response.Status() != 200 || response.Header("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", response.Status())
}
}
}
// BenchmarkSyntheticParallelRequests measures performance under concurrent load
func BenchmarkSyntheticParallelRequests(b *testing.B) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
s.Get("/users/:id", func(ctx web.Context) error {
return ctx.String("User " + ctx.Request().Param("id"))
})
s.Post("/submit", func(ctx web.Context) error {
return ctx.String("Received " + string(ctx.Request().Body()))
})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
counter := 0
for pb.Next() {
var response web.Response
switch counter % 3 {
case 0:
response = s.Request("GET", "/", nil, nil)
case 1:
id := strconv.Itoa(counter & 0xff)
response = s.Request("GET", "/users/"+id, nil, nil)
case 2:
data := "data" + strconv.Itoa(counter&0xff)
response = s.Request("POST", "/submit", nil,
io.NopCloser(strings.NewReader(data)))
}
if response.Status() != 200 {
b.Fatalf("Invalid response: status=%d", response.Status())
}
counter++
}
})
}

View File

@ -19,6 +19,10 @@ import (
// Interface for an HTTP server.
type Server interface {
Get(path string, handler Handler)
Post(path string, handler Handler)
Put(path string, handler Handler)
Delete(path string, handler Handler)
Patch(path string, handler Handler)
Request(method string, path string, headers []Header, body io.Reader) Response
Router() *router.Router[Handler]
Run(address string) error

View File

@ -1,78 +1,121 @@
package web_tests
import (
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"syscall"
"testing"
"time"
web "git.sharkk.net/Go/Web"
)
const port = ":8888"
// testServer starts a server on a random port, runs the given test function,
// and ensures proper server shutdown afterward.
func testServer(t *testing.T, s web.Server, test func(port string)) {
// Get a free port by letting the OS assign one
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Failed to find available port: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
portStr := fmt.Sprintf(":%d", port)
listener.Close()
// Start the server in a goroutine
ready := make(chan bool)
done := make(chan bool)
go func() {
// Signal that we're about to start the server
ready <- true
// Start the server (this blocks until interrupted)
s.Run(portStr)
// Signal that the server has shut down
done <- true
}()
// Wait for server to begin starting
<-ready
// Give it a moment to actually start listening
time.Sleep(50 * time.Millisecond)
// Run the test function with the assigned port
test(portStr)
// Send interrupt to stop the server
p, _ := os.FindProcess(os.Getpid())
p.Signal(os.Interrupt)
// Wait for server to confirm shutdown with timeout
select {
case <-done:
// Server shutdown properly
case <-time.After(100 * time.Millisecond):
t.Log("Warning: Server didn't shut down within timeout")
}
}
func TestRun(t *testing.T) {
s := web.NewServer()
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
_, err := http.Get("http://127.0.0.1" + port + "/")
if err != nil {
t.Errorf("Error: %s", err)
}
}()
s.Run(port)
}
func TestPanic(t *testing.T) {
s := web.NewServer()
s.Get("/panic", func(ctx web.Context) error {
panic("Something unbelievable happened")
// Add a handler for the root path
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello")
})
defer func() {
r := recover()
if r == nil {
t.Error("Didn't panic")
testServer(t, s, func(port string) {
resp, err := http.Get("http://127.0.0.1" + port + "/")
if err != nil {
t.Fatalf("Error making request: %v", err)
}
}()
defer resp.Body.Close()
s.Request("GET", "/panic", nil, nil)
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Error reading response body: %v", err)
}
if string(body) != "Hello" {
t.Errorf("Expected body 'Hello', got '%s'", string(body))
}
})
}
func TestBadRequest(t *testing.T) {
s := web.NewServer()
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
conn, err := net.Dial("tcp", port)
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error connecting: %v", err)
}
defer conn.Close()
_, err = io.WriteString(conn, "BadRequest\r\n\r\n")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error writing: %v", err)
}
response, err := io.ReadAll(conn)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error reading: %v", err)
}
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Error: %s", string(response))
}
}()
s.Run(port)
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Expected 400 response, got: %s", string(response))
}
})
}
func TestBadRequestHeader(t *testing.T) {
@ -82,60 +125,54 @@ func TestBadRequestHeader(t *testing.T) {
return ctx.String("Hello")
})
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
conn, err := net.Dial("tcp", port)
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error connecting: %v", err)
}
defer conn.Close()
_, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error writing: %v", err)
}
buffer := make([]byte, len("HTTP/1.1 200"))
_, err = conn.Read(buffer)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error reading: %v", err)
}
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Error: %s", string(buffer))
}
}()
s.Run(port)
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Expected '200' response, got: %s", string(buffer))
}
})
}
func TestBadRequestMethod(t *testing.T) {
s := web.NewServer()
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
conn, err := net.Dial("tcp", port)
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error connecting: %v", err)
}
defer conn.Close()
_, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error writing: %v", err)
}
response, err := io.ReadAll(conn)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error reading: %v", err)
}
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Error: %s", string(response))
}
}()
s.Run(port)
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Expected 400 response, got: %s", string(response))
}
})
}
func TestBadRequestProtocol(t *testing.T) {
@ -145,67 +182,85 @@ func TestBadRequestProtocol(t *testing.T) {
return ctx.String("Hello")
})
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
conn, err := net.Dial("tcp", port)
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error connecting: %v", err)
}
defer conn.Close()
_, err = io.WriteString(conn, "GET /\r\n\r\n")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error writing: %v", err)
}
buffer := make([]byte, len("HTTP/1.1 200"))
_, err = conn.Read(buffer)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error reading: %v", err)
}
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Error: %s", string(buffer))
}
}()
s.Run(port)
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Expected '200' response, got: %s", string(buffer))
}
})
}
func TestEarlyClose(t *testing.T) {
s := web.NewServer()
go func() {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
conn, err := net.Dial("tcp", port)
testServer(t, s, func(port string) {
conn, err := net.Dial("tcp", "127.0.0.1"+port)
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error connecting: %v", err)
}
_, err = io.WriteString(conn, "GET /\r\n")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error writing: %v", err)
}
// Close the connection early
err = conn.Close()
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Error closing connection: %v", err)
}
}()
s.Run(port)
// No assertion needed - if the server doesn't crash, the test passes
})
}
func TestUnavailablePort(t *testing.T) {
listener, err := net.Listen("tcp", port)
// Get a port and keep it occupied
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Errorf("Error: %s", err)
t.Fatalf("Failed to find available port: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
portStr := fmt.Sprintf(":%d", port)
defer listener.Close()
// Try to run the server on the occupied port
s := web.NewServer()
s.Run(port)
// Run in a goroutine with a timeout
errChan := make(chan error, 1)
go func() {
errChan <- s.Run(portStr)
}()
// Check if we get an error within a reasonable time
select {
case err := <-errChan:
if err == nil {
t.Error("Expected error when binding to unavailable port")
}
case <-time.After(100 * time.Millisecond):
t.Error("Server.Run() didn't return with error on unavailable port")
// Try to stop the server anyway
p, _ := os.FindProcess(os.Getpid())
p.Signal(os.Interrupt)
}
}
func TestBodyContent(t *testing.T) {
@ -216,12 +271,31 @@ func TestBodyContent(t *testing.T) {
return ctx.String(string(body))
})
response := s.Request("POST", "/", nil, io.NopCloser(strings.NewReader("Hello")))
if response.Status() != 200 {
t.Errorf("Error: %s", response.Body())
}
testServer(t, s, func(port string) {
// Make a POST request with a body
client := &http.Client{}
req, err := http.NewRequest("POST", "http://127.0.0.1"+port+"/", strings.NewReader("Hello"))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
if string(response.Body()) != "Hello" {
t.Errorf("Error: %s", response.Body())
}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Error reading response body: %v", err)
}
if string(body) != "Hello" {
t.Errorf("Expected body 'Hello', got '%s'", string(body))
}
})
}