Compare commits

...

3 Commits

Author SHA1 Message Date
47a1b56619 ref 2025-03-05 12:30:54 -06:00
9295e5445e ref 2 2025-03-05 12:23:12 -06:00
3bc0920f09 ref 1 2025-03-05 12:16:23 -06:00
9 changed files with 1018 additions and 137 deletions

View File

@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2024 Go Copyright (c) 2024 Sharkk
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

526
bench/real_bench_test.go Normal file
View File

@ -0,0 +1,526 @@
package bench
import (
"bytes"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"
web "git.sharkk.net/Go/Web"
)
// startTestServer starts a server for benchmarking on a random port
func startTestServer(b *testing.B) (string, web.Server) {
s := web.NewServer()
// Setup routes
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
s.Get("/json", func(ctx web.Context) error {
ctx.Response().SetHeader("Content-Type", "application/json")
return ctx.String(`{"message":"Hello, World!","code":200,"success":true}`)
})
s.Post("/echo", func(ctx web.Context) error {
body := ctx.Request().Body()
return ctx.Bytes(body)
})
s.Get("/users/:id/posts/:postId", func(ctx web.Context) error {
userId := ctx.Request().Param("id")
postId := ctx.Request().Param("postId")
return ctx.String(userId + ":" + postId)
})
s.Get("/middleware-test", func(ctx web.Context) error {
return ctx.String("OK")
})
s.Get("/headers", func(ctx web.Context) error {
ctx.Response().SetHeader("X-Test-1", "Value1")
ctx.Response().SetHeader("X-Test-2", "Value2")
ctx.Response().SetHeader("X-Test-3", "Value3")
ctx.Response().SetHeader("Content-Type", "text/plain")
return ctx.String("Headers set")
})
s.Post("/submit", func(ctx web.Context) error {
return ctx.String("Received " + string(ctx.Request().Body()))
})
// Add middleware for middleware test
for i := 0; i < 5; i++ {
s.Use(func(ctx web.Context) error {
return ctx.Next()
})
}
// Find a free port
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
b.Fatalf("Failed to resolve TCP address: %v", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
b.Fatalf("Failed to listen on TCP: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
serverAddr := "localhost:" + strconv.Itoa(port)
// Start the server in a goroutine
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
s.Run(serverAddr)
}()
// Wait for server to start
wg.Wait()
time.Sleep(100 * time.Millisecond)
return serverAddr, s
}
// BenchmarkRealStaticGetRequest measures performance of real HTTP GET request
func BenchmarkRealStaticGetRequest(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealJSONResponse measures performance of real HTTP JSON response
func BenchmarkRealJSONResponse(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/json"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid status: %d", resp.StatusCode)
}
}
}
// BenchmarkRealPostWithBody measures performance of real HTTP POST request with body
func BenchmarkRealPostWithBody(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/echo"
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Post(url, "text/plain", strings.NewReader(requestBody))
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealRouteParams measures performance of real HTTP route parameter handling
func BenchmarkRealRouteParams(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/users/123/posts/456"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkRealHeaders measures real HTTP header handling performance
func BenchmarkRealHeaders(b *testing.B) {
serverAddr, _ := startTestServer(b)
url := "http://" + serverAddr + "/headers"
client := &http.Client{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
req, _ := http.NewRequest("GET", url, nil)
req.Header.Add("User-Agent", "Benchmark")
req.Header.Add("Accept", "*/*")
req.Header.Add("Authorization", "Bearer token12345")
resp, err := client.Do(req)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || resp.Header.Get("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
}
}
// BenchmarkRealParallelRequests measures real HTTP performance under concurrent load
func BenchmarkRealParallelRequests(b *testing.B) {
serverAddr, _ := startTestServer(b)
baseURL := "http://" + serverAddr
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
client := &http.Client{}
counter := 0
for pb.Next() {
var resp *http.Response
var err error
switch counter % 3 {
case 0:
resp, err = client.Get(baseURL + "/")
case 1:
id := strconv.Itoa(counter & 0xff)
resp, err = client.Get(baseURL + "/users/" + id + "/posts/456")
case 2:
data := "data" + strconv.Itoa(counter&0xff)
resp, err = client.Post(baseURL+"/submit", "text/plain", bytes.NewBufferString(data))
}
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, _ = io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
counter++
}
})
}
// startNetHTTPTestServer starts a standard net/http server for benchmarking
func startNetHTTPTestServer(b *testing.B) (string, *http.Server) {
mux := http.NewServeMux()
// Setup routes
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Write([]byte("Hello, World!"))
})
mux.HandleFunc("/json", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"message":"Hello, World!","code":200,"success":true}`))
})
mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(r.Body)
w.Write(body)
})
mux.HandleFunc("/users/", func(w http.ResponseWriter, r *http.Request) {
parts := strings.Split(r.URL.Path, "/")
if len(parts) != 5 || parts[1] != "users" || parts[3] != "posts" || parts[0] != "" {
http.NotFound(w, r)
return
}
userId := parts[2]
postId := parts[4]
w.Write([]byte(userId + ":" + postId))
})
mux.HandleFunc("/middleware-test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
})
mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test-1", "Value1")
w.Header().Set("X-Test-2", "Value2")
w.Header().Set("X-Test-3", "Value3")
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte("Headers set"))
})
mux.HandleFunc("/submit", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(r.Body)
w.Write([]byte("Received " + string(body)))
})
// Find a free port
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
b.Fatalf("Failed to resolve TCP address: %v", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
b.Fatalf("Failed to listen on TCP: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
serverAddr := "localhost:" + strconv.Itoa(port)
// Start the server in a goroutine
server := &http.Server{
Addr: serverAddr,
Handler: mux,
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
server.Serve(listener)
}()
// Wait for server to start
wg.Wait()
time.Sleep(100 * time.Millisecond)
return serverAddr, server
}
// BenchmarkNetHTTPStaticGetRequest measures performance of net/http static GET request
func BenchmarkNetHTTPStaticGetRequest(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPJSONResponse measures performance of net/http JSON response
func BenchmarkNetHTTPJSONResponse(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/json"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid status: %d", resp.StatusCode)
}
}
}
// BenchmarkNetHTTPPostWithBody measures performance of net/http POST request with body
func BenchmarkNetHTTPPostWithBody(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/echo"
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Post(url, "text/plain", strings.NewReader(requestBody))
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPRouteParams measures performance of net/http route parameter handling
func BenchmarkNetHTTPRouteParams(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/users/123/posts/456"
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := http.Get(url)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || string(body) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", resp.StatusCode, body)
}
}
}
// BenchmarkNetHTTPHeaders measures net/http header handling performance
func BenchmarkNetHTTPHeaders(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
url := "http://" + serverAddr + "/headers"
client := &http.Client{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
req, _ := http.NewRequest("GET", url, nil)
req.Header.Add("User-Agent", "Benchmark")
req.Header.Add("Accept", "*/*")
req.Header.Add("Authorization", "Bearer token12345")
resp, err := client.Do(req)
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, err = io.ReadAll(resp.Body)
if err != nil {
b.Fatalf("Failed to read response: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 || resp.Header.Get("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
}
}
// BenchmarkNetHTTPParallelRequests measures net/http performance under concurrent load
func BenchmarkNetHTTPParallelRequests(b *testing.B) {
serverAddr, _ := startNetHTTPTestServer(b)
baseURL := "http://" + serverAddr
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
client := &http.Client{}
counter := 0
for pb.Next() {
var resp *http.Response
var err error
switch counter % 3 {
case 0:
resp, err = client.Get(baseURL + "/")
case 1:
id := strconv.Itoa(counter & 0xff)
resp, err = client.Get(baseURL + "/users/" + id + "/posts/456")
case 2:
data := "data" + strconv.Itoa(counter&0xff)
resp, err = client.Post(baseURL+"/submit", "text/plain", bytes.NewBufferString(data))
}
if err != nil {
b.Fatalf("Request failed: %v", err)
}
_, _ = io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
b.Fatalf("Invalid response: status=%d", resp.StatusCode)
}
counter++
}
})
}

View File

@ -0,0 +1,165 @@
package bench
import (
"io"
"strconv"
"strings"
"testing"
web "git.sharkk.net/Go/Web"
)
// BenchmarkSyntheticStaticGetRequest measures performance of simple GET request with static response
func BenchmarkSyntheticStaticGetRequest(b *testing.B) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/", nil, nil)
if response.Status() != 200 || string(response.Body()) != "Hello, World!" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticJSONResponse measures performance of JSON serialization and response
func BenchmarkSyntheticJSONResponse(b *testing.B) {
s := web.NewServer()
s.Get("/json", func(ctx web.Context) error {
ctx.Response().SetHeader("Content-Type", "application/json")
return ctx.String(`{"message":"Hello, World!","code":200,"success":true}`)
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/json", nil, nil)
if response.Status() != 200 {
b.Fatalf("Invalid status: %d", response.Status())
}
}
}
// BenchmarkSyntheticPostWithBody measures performance of POST request with body processing
func BenchmarkSyntheticPostWithBody(b *testing.B) {
s := web.NewServer()
s.Post("/echo", func(ctx web.Context) error {
body := ctx.Request().Body()
return ctx.Bytes(body)
})
requestBody := strings.Repeat("Hello, World! ", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("POST", "/echo", nil, io.NopCloser(strings.NewReader(requestBody)))
if response.Status() != 200 || string(response.Body()) != requestBody {
b.Fatalf("Invalid response\nstatus=%d\nbody=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticRouteParams measures performance of route parameter extraction
func BenchmarkSyntheticRouteParams(b *testing.B) {
s := web.NewServer()
s.Get("/users/:id/posts/:postId", func(ctx web.Context) error {
userId := ctx.Request().Param("id")
postId := ctx.Request().Param("postId")
return ctx.String(userId + ":" + postId)
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/users/123/posts/456", nil, nil)
if response.Status() != 200 || string(response.Body()) != "123:456" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticMiddleware measures the impact of middleware chain
func BenchmarkSyntheticMiddleware(b *testing.B) {
s := web.NewServer()
// Add 5 middleware functions
for i := 0; i < 5; i++ {
s.Use(func(ctx web.Context) error {
return ctx.Next()
})
}
s.Get("/middleware-test", func(ctx web.Context) error {
return ctx.String("OK")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/middleware-test", nil, nil)
if response.Status() != 200 || string(response.Body()) != "OK" {
b.Fatalf("Invalid response: status=%d, body=%s", response.Status(), response.Body())
}
}
}
// BenchmarkSyntheticHeaders measures header handling performance
func BenchmarkSyntheticHeaders(b *testing.B) {
s := web.NewServer()
s.Get("/headers", func(ctx web.Context) error {
ctx.Response().SetHeader("X-Test-1", "Value1")
ctx.Response().SetHeader("X-Test-2", "Value2")
ctx.Response().SetHeader("X-Test-3", "Value3")
ctx.Response().SetHeader("Content-Type", "text/plain")
return ctx.String("Headers set")
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
response := s.Request("GET", "/headers", []web.Header{
{"User-Agent", "Benchmark"},
{"Accept", "*/*"},
{"Authorization", "Bearer token12345"},
}, nil)
if response.Status() != 200 || response.Header("X-Test-1") != "Value1" {
b.Fatalf("Invalid response: status=%d", response.Status())
}
}
}
// BenchmarkSyntheticParallelRequests measures performance under concurrent load
func BenchmarkSyntheticParallelRequests(b *testing.B) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello, World!")
})
s.Get("/users/:id", func(ctx web.Context) error {
return ctx.String("User " + ctx.Request().Param("id"))
})
s.Post("/submit", func(ctx web.Context) error {
return ctx.String("Received " + string(ctx.Request().Body()))
})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
counter := 0
for pb.Next() {
var response web.Response
switch counter % 3 {
case 0:
response = s.Request("GET", "/", nil, nil)
case 1:
id := strconv.Itoa(counter & 0xff)
response = s.Request("GET", "/users/"+id, nil, nil)
case 2:
data := "data" + strconv.Itoa(counter&0xff)
response = s.Request("POST", "/submit", nil,
io.NopCloser(strings.NewReader(data)))
}
if response.Status() != 200 {
b.Fatalf("Invalid response: status=%d", response.Status())
}
counter++
}
})
}

View File

@ -53,7 +53,7 @@ func (ctx *context) Next() error {
// Redirects the client to a different location with the specified status code. // Redirects the client to a different location with the specified status code.
func (ctx *context) Redirect(status int, location string) error { func (ctx *context) Redirect(status int, location string) error {
ctx.response.SetStatus(status) ctx.response.SetStatus(status)
ctx.response.SetHeader("Location", location) ctx.response.SetHeader(HeaderLocation, location)
return nil return nil
} }

View File

@ -1,6 +1,63 @@
package web package web
// Header represents an HTTP header with key and value
type Header struct { type Header struct {
Key string Key string
Value string Value string
} }
// Common HTTP header keys
const (
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
HeaderHost = "Host"
HeaderAccept = "Accept"
HeaderUserAgent = "User-Agent"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAcceptLanguage = "Accept-Language"
HeaderConnection = "Connection"
HeaderCookie = "Cookie"
HeaderSetCookie = "Set-Cookie"
HeaderLocation = "Location"
HeaderAuthorization = "Authorization"
HeaderCacheControl = "Cache-Control"
HeaderOrigin = "Origin"
HeaderReferer = "Referer"
HeaderTransferEncoding = "Transfer-Encoding"
)
// Pre-allocated common headers
var (
// Content type headers
HeaderContentTypeJSON = Header{Key: HeaderContentType, Value: "application/json"}
HeaderContentTypeHTML = Header{Key: HeaderContentType, Value: "text/html"}
HeaderContentTypePlain = Header{Key: HeaderContentType, Value: "text/plain"}
HeaderContentTypeXML = Header{Key: HeaderContentType, Value: "application/xml"}
HeaderContentTypeForm = Header{Key: HeaderContentType, Value: "application/x-www-form-urlencoded"}
HeaderContentTypeMultipart = Header{Key: HeaderContentType, Value: "multipart/form-data"}
// Connection headers
HeaderConnectionClose = Header{Key: HeaderConnection, Value: "close"}
HeaderConnectionKeepAlive = Header{Key: HeaderConnection, Value: "keep-alive"}
)
// FindHeader looks for a header by key in a slice of headers
func FindHeader(headers []Header, key string) (string, bool) {
for _, h := range headers {
if h.Key == key {
return h.Value, true
}
}
return "", false
}
// SetHeader sets a header value in a slice of headers
func SetHeader(headers *[]Header, key string, value string) {
for i, h := range *headers {
if h.Key == key {
(*headers)[i].Value = value
return
}
}
*headers = append(*headers, Header{Key: key, Value: value})
}

View File

@ -32,13 +32,8 @@ type request struct {
// Returns the header value for the given key. // Returns the header value for the given key.
func (req *request) Header(key string) string { func (req *request) Header(key string) string {
for _, header := range req.headers { value, _ := FindHeader(req.headers, key)
if header.Key == key { return value
return header.Value
}
}
return ""
} }
// Returns the requested host. // Returns the requested host.

View File

@ -28,25 +28,13 @@ func (res *response) Body() []byte {
// Returns the header value for the given key. // Returns the header value for the given key.
func (res *response) Header(key string) string { func (res *response) Header(key string) string {
for _, header := range res.headers { value, _ := FindHeader(res.headers, key)
if header.Key == key { return value
return header.Value
}
}
return ""
} }
// Sets the header value for the given key. // Sets the header value for the given key.
func (res *response) SetHeader(key string, value string) { func (res *response) SetHeader(key string, value string) {
for i, header := range res.headers { SetHeader(&res.headers, key, value)
if header.Key == key {
res.headers[i].Value = value
return
}
}
res.headers = append(res.headers, Header{Key: key, Value: value})
} }
// Replaces the response body with the new contents. // Replaces the response body with the new contents.

116
server.go
View File

@ -16,9 +16,56 @@ import (
router "git.sharkk.net/Go/Router" router "git.sharkk.net/Go/Router"
) )
// Pre-allocated response components
var (
// Common strings for response headers
contentLengthHeader = []byte(HeaderContentLength + ": ")
crlf = []byte("\r\n")
colonSpace = []byte(": ")
// Status line map for quick lookups
statusLines = map[int][]byte{
100: []byte("HTTP/1.1 100 Continue\r\n"),
101: []byte("HTTP/1.1 101 Switching Protocols\r\n"),
200: []byte("HTTP/1.1 200 OK\r\n"),
201: []byte("HTTP/1.1 201 Created\r\n"),
202: []byte("HTTP/1.1 202 Accepted\r\n"),
204: []byte("HTTP/1.1 204 No Content\r\n"),
206: []byte("HTTP/1.1 206 Partial Content\r\n"),
300: []byte("HTTP/1.1 300 Multiple Choices\r\n"),
301: []byte("HTTP/1.1 301 Moved Permanently\r\n"),
302: []byte("HTTP/1.1 302 Found\r\n"),
304: []byte("HTTP/1.1 304 Not Modified\r\n"),
307: []byte("HTTP/1.1 307 Temporary Redirect\r\n"),
308: []byte("HTTP/1.1 308 Permanent Redirect\r\n"),
400: []byte("HTTP/1.1 400 Bad Request\r\n"),
401: []byte("HTTP/1.1 401 Unauthorized\r\n"),
403: []byte("HTTP/1.1 403 Forbidden\r\n"),
404: []byte("HTTP/1.1 404 Not Found\r\n"),
405: []byte("HTTP/1.1 405 Method Not Allowed\r\n"),
406: []byte("HTTP/1.1 406 Not Acceptable\r\n"),
409: []byte("HTTP/1.1 409 Conflict\r\n"),
410: []byte("HTTP/1.1 410 Gone\r\n"),
412: []byte("HTTP/1.1 412 Precondition Failed\r\n"),
413: []byte("HTTP/1.1 413 Payload Too Large\r\n"),
415: []byte("HTTP/1.1 415 Unsupported Media Type\r\n"),
416: []byte("HTTP/1.1 416 Range Not Satisfiable\r\n"),
429: []byte("HTTP/1.1 429 Too Many Requests\r\n"),
500: []byte("HTTP/1.1 500 Internal Server Error\r\n"),
501: []byte("HTTP/1.1 501 Not Implemented\r\n"),
502: []byte("HTTP/1.1 502 Bad Gateway\r\n"),
503: []byte("HTTP/1.1 503 Service Unavailable\r\n"),
504: []byte("HTTP/1.1 504 Gateway Timeout\r\n"),
}
)
// Interface for an HTTP server. // Interface for an HTTP server.
type Server interface { type Server interface {
Get(path string, handler Handler) Get(path string, handler Handler)
Post(path string, handler Handler)
Put(path string, handler Handler)
Delete(path string, handler Handler)
Patch(path string, handler Handler)
Request(method string, path string, headers []Header, body io.Reader) Response Request(method string, path string, headers []Header, body io.Reader) Response
Router() *router.Router[Handler] Router() *router.Router[Handler]
Run(address string) error Run(address string) error
@ -29,6 +76,7 @@ type Server interface {
type server struct { type server struct {
handlers []Handler handlers []Handler
contextPool sync.Pool contextPool sync.Pool
bufferPool sync.Pool
router *router.Router[Handler] router *router.Router[Handler]
errorHandler func(Context, error) errorHandler func(Context, error)
} }
@ -57,6 +105,7 @@ func NewServer() Server {
} }
s.contextPool.New = func() any { return s.newContext() } s.contextPool.New = func() any { return s.newContext() }
s.bufferPool.New = func() any { return bytes.NewBuffer(make([]byte, 0, 1024)) }
return s return s
} }
@ -179,6 +228,7 @@ func (s *server) handleConnection(conn net.Conn) {
url = message[space+1 : lastSpace] url = message[space+1 : lastSpace]
// Add headers until we meet an empty line // Add headers until we meet an empty line
ctx.request.headers = ctx.request.headers[:0] // Reset headers without allocation
for { for {
message, err = ctx.reader.ReadString('\n') message, err = ctx.reader.ReadString('\n')
@ -206,17 +256,22 @@ func (s *server) handleConnection(conn net.Conn) {
} }
// Read the body, if any // Read the body, if any
if contentLength := ctx.request.Header("Content-Length"); contentLength != "" { if contentLength := ctx.request.Header(HeaderContentLength); contentLength != "" {
length, _ := strconv.Atoi(contentLength) length, _ := strconv.Atoi(contentLength)
ctx.request.body = make([]byte, length) if cap(ctx.request.body) >= length {
ctx.request.body = ctx.request.body[:length] // Reuse existing slice if possible
} else {
ctx.request.body = make([]byte, length)
}
ctx.reader.Read(ctx.request.body) ctx.reader.Read(ctx.request.body)
} else {
ctx.request.body = ctx.request.body[:0] // Empty the body slice without allocation
} }
// Handle the request // Handle the request
s.handleRequest(ctx, method, url, conn) s.handleRequest(ctx, method, url, conn)
// Clean up the context // Clean up the context - reset slices without allocation
ctx.request.headers = ctx.request.headers[:0]
ctx.request.body = ctx.request.body[:0] ctx.request.body = ctx.request.body[:0]
ctx.response.headers = ctx.response.headers[:0] ctx.response.headers = ctx.response.headers[:0]
ctx.response.body = ctx.response.body[:0] ctx.response.body = ctx.response.body[:0]
@ -226,7 +281,7 @@ func (s *server) handleConnection(conn net.Conn) {
} }
} }
// Handles the given request. // Handles the given request with reduced allocations.
func (s *server) handleRequest(ctx *context, method string, url string, writer io.Writer) { func (s *server) handleRequest(ctx *context, method string, url string, writer io.Writer) {
ctx.method = method ctx.method = method
ctx.scheme, ctx.host, ctx.path, ctx.query = parseURL(url) ctx.scheme, ctx.host, ctx.path, ctx.query = parseURL(url)
@ -237,23 +292,44 @@ func (s *server) handleRequest(ctx *context, method string, url string, writer i
s.errorHandler(ctx, err) s.errorHandler(ctx, err)
} }
tmp := bytes.Buffer{} // Get buffer from pool
tmp.WriteString("HTTP/1.1 ") buf := s.bufferPool.Get().(*bytes.Buffer)
tmp.WriteString(strconv.Itoa(int(ctx.status))) buf.Reset()
tmp.WriteString("\r\nContent-Length: ") defer s.bufferPool.Put(buf)
tmp.WriteString(strconv.Itoa(len(ctx.response.body)))
tmp.WriteString("\r\n")
for _, header := range ctx.response.headers { // Write status line using map lookup for efficiency
tmp.WriteString(header.Key) if statusLine, ok := statusLines[int(ctx.status)]; ok {
tmp.WriteString(": ") buf.Write(statusLine)
tmp.WriteString(header.Value) } else {
tmp.WriteString("\r\n") // For uncommon status codes, format the line dynamically
buf.WriteString("HTTP/1.1 ")
buf.WriteString(strconv.Itoa(int(ctx.status)))
buf.Write(crlf)
} }
tmp.WriteString("\r\n") // Write Content-Length header
tmp.Write(ctx.response.body) buf.Write(contentLengthHeader)
writer.Write(tmp.Bytes()) buf.WriteString(strconv.Itoa(len(ctx.response.body)))
buf.Write(crlf)
// Write all response headers
for _, header := range ctx.response.headers {
buf.WriteString(header.Key)
buf.Write(colonSpace)
buf.WriteString(header.Value)
buf.Write(crlf)
}
// End headers
buf.Write(crlf)
// Write headers
writer.Write(buf.Bytes())
// Write body directly to avoid another copy
if len(ctx.response.body) > 0 {
writer.Write(ctx.response.body)
}
} }
// Allocates a new context with the default state. // Allocates a new context with the default state.
@ -262,7 +338,7 @@ func (s *server) newContext() *context {
server: s, server: s,
request: request{ request: request{
reader: bufio.NewReader(nil), reader: bufio.NewReader(nil),
body: make([]byte, 0), body: make([]byte, 0, 1024),
headers: make([]Header, 0, 8), headers: make([]Header, 0, 8),
params: make([]router.Parameter, 0, 8), params: make([]router.Parameter, 0, 8),
}, },

View File

@ -1,78 +1,121 @@
package web_tests package web_tests
import ( import (
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"syscall"
"testing" "testing"
"time"
web "git.sharkk.net/Go/Web" web "git.sharkk.net/Go/Web"
) )
const port = ":8888" // testServer starts a server on a random port, runs the given test function,
// and ensures proper server shutdown afterward.
func testServer(t *testing.T, s web.Server, test func(port string)) {
// Get a free port by letting the OS assign one
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Failed to find available port: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
portStr := fmt.Sprintf(":%d", port)
listener.Close()
// Start the server in a goroutine
ready := make(chan bool)
done := make(chan bool)
go func() {
// Signal that we're about to start the server
ready <- true
// Start the server (this blocks until interrupted)
s.Run(portStr)
// Signal that the server has shut down
done <- true
}()
// Wait for server to begin starting
<-ready
// Give it a moment to actually start listening
time.Sleep(50 * time.Millisecond)
// Run the test function with the assigned port
test(portStr)
// Send interrupt to stop the server
p, _ := os.FindProcess(os.Getpid())
p.Signal(os.Interrupt)
// Wait for server to confirm shutdown with timeout
select {
case <-done:
// Server shutdown properly
case <-time.After(100 * time.Millisecond):
t.Log("Warning: Server didn't shut down within timeout")
}
}
func TestRun(t *testing.T) { func TestRun(t *testing.T) {
s := web.NewServer() s := web.NewServer()
go func() { // Add a handler for the root path
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) s.Get("/", func(ctx web.Context) error {
return ctx.String("Hello")
_, err := http.Get("http://127.0.0.1" + port + "/")
if err != nil {
t.Errorf("Error: %s", err)
}
}()
s.Run(port)
}
func TestPanic(t *testing.T) {
s := web.NewServer()
s.Get("/panic", func(ctx web.Context) error {
panic("Something unbelievable happened")
}) })
defer func() { testServer(t, s, func(port string) {
r := recover() resp, err := http.Get("http://127.0.0.1" + port + "/")
if err != nil {
if r == nil { t.Fatalf("Error making request: %v", err)
t.Error("Didn't panic")
} }
}() defer resp.Body.Close()
s.Request("GET", "/panic", nil, nil) if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Error reading response body: %v", err)
}
if string(body) != "Hello" {
t.Errorf("Expected body 'Hello', got '%s'", string(body))
}
})
} }
func TestBadRequest(t *testing.T) { func TestBadRequest(t *testing.T) {
s := web.NewServer() s := web.NewServer()
go func() { testServer(t, s, func(port string) {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) conn, err := net.Dial("tcp", "127.0.0.1"+port)
conn, err := net.Dial("tcp", port)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error connecting: %v", err)
} }
defer conn.Close() defer conn.Close()
_, err = io.WriteString(conn, "BadRequest\r\n\r\n") _, err = io.WriteString(conn, "BadRequest\r\n\r\n")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error writing: %v", err)
} }
response, err := io.ReadAll(conn) response, err := io.ReadAll(conn)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error reading: %v", err)
} }
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Error: %s", string(response))
}
}()
s.Run(port) if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Expected 400 response, got: %s", string(response))
}
})
} }
func TestBadRequestHeader(t *testing.T) { func TestBadRequestHeader(t *testing.T) {
@ -82,60 +125,54 @@ func TestBadRequestHeader(t *testing.T) {
return ctx.String("Hello") return ctx.String("Hello")
}) })
go func() { testServer(t, s, func(port string) {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) conn, err := net.Dial("tcp", "127.0.0.1"+port)
conn, err := net.Dial("tcp", port)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error connecting: %v", err)
} }
defer conn.Close() defer conn.Close()
_, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n") _, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error writing: %v", err)
} }
buffer := make([]byte, len("HTTP/1.1 200")) buffer := make([]byte, len("HTTP/1.1 200"))
_, err = conn.Read(buffer) _, err = conn.Read(buffer)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error reading: %v", err)
} }
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Error: %s", string(buffer))
}
}()
s.Run(port) if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Expected '200' response, got: %s", string(buffer))
}
})
} }
func TestBadRequestMethod(t *testing.T) { func TestBadRequestMethod(t *testing.T) {
s := web.NewServer() s := web.NewServer()
go func() { testServer(t, s, func(port string) {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) conn, err := net.Dial("tcp", "127.0.0.1"+port)
conn, err := net.Dial("tcp", port)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error connecting: %v", err)
} }
defer conn.Close() defer conn.Close()
_, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n") _, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error writing: %v", err)
} }
response, err := io.ReadAll(conn) response, err := io.ReadAll(conn)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error reading: %v", err)
} }
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Error: %s", string(response))
}
}()
s.Run(port) if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
t.Errorf("Expected 400 response, got: %s", string(response))
}
})
} }
func TestBadRequestProtocol(t *testing.T) { func TestBadRequestProtocol(t *testing.T) {
@ -145,67 +182,85 @@ func TestBadRequestProtocol(t *testing.T) {
return ctx.String("Hello") return ctx.String("Hello")
}) })
go func() { testServer(t, s, func(port string) {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) conn, err := net.Dial("tcp", "127.0.0.1"+port)
conn, err := net.Dial("tcp", port)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error connecting: %v", err)
} }
defer conn.Close() defer conn.Close()
_, err = io.WriteString(conn, "GET /\r\n\r\n") _, err = io.WriteString(conn, "GET /\r\n\r\n")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error writing: %v", err)
} }
buffer := make([]byte, len("HTTP/1.1 200")) buffer := make([]byte, len("HTTP/1.1 200"))
_, err = conn.Read(buffer) _, err = conn.Read(buffer)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error reading: %v", err)
} }
if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Error: %s", string(buffer))
}
}()
s.Run(port) if string(buffer) != "HTTP/1.1 200" {
t.Errorf("Expected '200' response, got: %s", string(buffer))
}
})
} }
func TestEarlyClose(t *testing.T) { func TestEarlyClose(t *testing.T) {
s := web.NewServer() s := web.NewServer()
go func() { testServer(t, s, func(port string) {
defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) conn, err := net.Dial("tcp", "127.0.0.1"+port)
conn, err := net.Dial("tcp", port)
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error connecting: %v", err)
} }
_, err = io.WriteString(conn, "GET /\r\n") _, err = io.WriteString(conn, "GET /\r\n")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error writing: %v", err)
} }
// Close the connection early
err = conn.Close() err = conn.Close()
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Error closing connection: %v", err)
} }
}()
s.Run(port) // No assertion needed - if the server doesn't crash, the test passes
})
} }
func TestUnavailablePort(t *testing.T) { func TestUnavailablePort(t *testing.T) {
listener, err := net.Listen("tcp", port) // Get a port and keep it occupied
listener, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Errorf("Error: %s", err) t.Fatalf("Failed to find available port: %v", err)
} }
port := listener.Addr().(*net.TCPAddr).Port
portStr := fmt.Sprintf(":%d", port)
defer listener.Close() defer listener.Close()
// Try to run the server on the occupied port
s := web.NewServer() s := web.NewServer()
s.Run(port)
// Run in a goroutine with a timeout
errChan := make(chan error, 1)
go func() {
errChan <- s.Run(portStr)
}()
// Check if we get an error within a reasonable time
select {
case err := <-errChan:
if err == nil {
t.Error("Expected error when binding to unavailable port")
}
case <-time.After(100 * time.Millisecond):
t.Error("Server.Run() didn't return with error on unavailable port")
// Try to stop the server anyway
p, _ := os.FindProcess(os.Getpid())
p.Signal(os.Interrupt)
}
} }
func TestBodyContent(t *testing.T) { func TestBodyContent(t *testing.T) {
@ -216,12 +271,31 @@ func TestBodyContent(t *testing.T) {
return ctx.String(string(body)) return ctx.String(string(body))
}) })
response := s.Request("POST", "/", nil, io.NopCloser(strings.NewReader("Hello"))) testServer(t, s, func(port string) {
if response.Status() != 200 { // Make a POST request with a body
t.Errorf("Error: %s", response.Body()) client := &http.Client{}
} req, err := http.NewRequest("POST", "http://127.0.0.1"+port+"/", strings.NewReader("Hello"))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
if string(response.Body()) != "Hello" { resp, err := client.Do(req)
t.Errorf("Error: %s", response.Body()) if err != nil {
} t.Fatalf("Error making request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Error reading response body: %v", err)
}
if string(body) != "Hello" {
t.Errorf("Expected body 'Hello', got '%s'", string(body))
}
})
} }