302 lines
6.7 KiB
Go
302 lines
6.7 KiB
Go
package web_tests
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
web "git.sharkk.net/Go/Web"
|
|
)
|
|
|
|
// testServer starts a server on a random port, runs the given test function,
|
|
// and ensures proper server shutdown afterward.
|
|
func testServer(t *testing.T, s web.Server, test func(port string)) {
|
|
// Get a free port by letting the OS assign one
|
|
listener, err := net.Listen("tcp", ":0")
|
|
if err != nil {
|
|
t.Fatalf("Failed to find available port: %v", err)
|
|
}
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
portStr := fmt.Sprintf(":%d", port)
|
|
listener.Close()
|
|
|
|
// Start the server in a goroutine
|
|
ready := make(chan bool)
|
|
done := make(chan bool)
|
|
|
|
go func() {
|
|
// Signal that we're about to start the server
|
|
ready <- true
|
|
|
|
// Start the server (this blocks until interrupted)
|
|
s.Run(portStr)
|
|
|
|
// Signal that the server has shut down
|
|
done <- true
|
|
}()
|
|
|
|
// Wait for server to begin starting
|
|
<-ready
|
|
|
|
// Give it a moment to actually start listening
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Run the test function with the assigned port
|
|
test(portStr)
|
|
|
|
// Send interrupt to stop the server
|
|
p, _ := os.FindProcess(os.Getpid())
|
|
p.Signal(os.Interrupt)
|
|
|
|
// Wait for server to confirm shutdown with timeout
|
|
select {
|
|
case <-done:
|
|
// Server shutdown properly
|
|
case <-time.After(100 * time.Millisecond):
|
|
t.Log("Warning: Server didn't shut down within timeout")
|
|
}
|
|
}
|
|
|
|
func TestRun(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
// Add a handler for the root path
|
|
s.Get("/", func(ctx web.Context) error {
|
|
return ctx.String("Hello")
|
|
})
|
|
|
|
testServer(t, s, func(port string) {
|
|
resp, err := http.Get("http://127.0.0.1" + port + "/")
|
|
if err != nil {
|
|
t.Fatalf("Error making request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
if string(body) != "Hello" {
|
|
t.Errorf("Expected body 'Hello', got '%s'", string(body))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBadRequest(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
testServer(t, s, func(port string) {
|
|
conn, err := net.Dial("tcp", "127.0.0.1"+port)
|
|
if err != nil {
|
|
t.Fatalf("Error connecting: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, err = io.WriteString(conn, "BadRequest\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("Error writing: %v", err)
|
|
}
|
|
|
|
response, err := io.ReadAll(conn)
|
|
if err != nil {
|
|
t.Fatalf("Error reading: %v", err)
|
|
}
|
|
|
|
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
|
|
t.Errorf("Expected 400 response, got: %s", string(response))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBadRequestHeader(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
s.Get("/", func(ctx web.Context) error {
|
|
return ctx.String("Hello")
|
|
})
|
|
|
|
testServer(t, s, func(port string) {
|
|
conn, err := net.Dial("tcp", "127.0.0.1"+port)
|
|
if err != nil {
|
|
t.Fatalf("Error connecting: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("Error writing: %v", err)
|
|
}
|
|
|
|
buffer := make([]byte, len("HTTP/1.1 200"))
|
|
_, err = conn.Read(buffer)
|
|
if err != nil {
|
|
t.Fatalf("Error reading: %v", err)
|
|
}
|
|
|
|
if string(buffer) != "HTTP/1.1 200" {
|
|
t.Errorf("Expected '200' response, got: %s", string(buffer))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBadRequestMethod(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
testServer(t, s, func(port string) {
|
|
conn, err := net.Dial("tcp", "127.0.0.1"+port)
|
|
if err != nil {
|
|
t.Fatalf("Error connecting: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("Error writing: %v", err)
|
|
}
|
|
|
|
response, err := io.ReadAll(conn)
|
|
if err != nil {
|
|
t.Fatalf("Error reading: %v", err)
|
|
}
|
|
|
|
if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" {
|
|
t.Errorf("Expected 400 response, got: %s", string(response))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBadRequestProtocol(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
s.Get("/", func(ctx web.Context) error {
|
|
return ctx.String("Hello")
|
|
})
|
|
|
|
testServer(t, s, func(port string) {
|
|
conn, err := net.Dial("tcp", "127.0.0.1"+port)
|
|
if err != nil {
|
|
t.Fatalf("Error connecting: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, err = io.WriteString(conn, "GET /\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("Error writing: %v", err)
|
|
}
|
|
|
|
buffer := make([]byte, len("HTTP/1.1 200"))
|
|
_, err = conn.Read(buffer)
|
|
if err != nil {
|
|
t.Fatalf("Error reading: %v", err)
|
|
}
|
|
|
|
if string(buffer) != "HTTP/1.1 200" {
|
|
t.Errorf("Expected '200' response, got: %s", string(buffer))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestEarlyClose(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
testServer(t, s, func(port string) {
|
|
conn, err := net.Dial("tcp", "127.0.0.1"+port)
|
|
if err != nil {
|
|
t.Fatalf("Error connecting: %v", err)
|
|
}
|
|
|
|
_, err = io.WriteString(conn, "GET /\r\n")
|
|
if err != nil {
|
|
t.Fatalf("Error writing: %v", err)
|
|
}
|
|
|
|
// Close the connection early
|
|
err = conn.Close()
|
|
if err != nil {
|
|
t.Fatalf("Error closing connection: %v", err)
|
|
}
|
|
|
|
// No assertion needed - if the server doesn't crash, the test passes
|
|
})
|
|
}
|
|
|
|
func TestUnavailablePort(t *testing.T) {
|
|
// Get a port and keep it occupied
|
|
listener, err := net.Listen("tcp", ":0")
|
|
if err != nil {
|
|
t.Fatalf("Failed to find available port: %v", err)
|
|
}
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
portStr := fmt.Sprintf(":%d", port)
|
|
defer listener.Close()
|
|
|
|
// Try to run the server on the occupied port
|
|
s := web.NewServer()
|
|
|
|
// Run in a goroutine with a timeout
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
errChan <- s.Run(portStr)
|
|
}()
|
|
|
|
// Check if we get an error within a reasonable time
|
|
select {
|
|
case err := <-errChan:
|
|
if err == nil {
|
|
t.Error("Expected error when binding to unavailable port")
|
|
}
|
|
case <-time.After(100 * time.Millisecond):
|
|
t.Error("Server.Run() didn't return with error on unavailable port")
|
|
// Try to stop the server anyway
|
|
p, _ := os.FindProcess(os.Getpid())
|
|
p.Signal(os.Interrupt)
|
|
}
|
|
}
|
|
|
|
func TestBodyContent(t *testing.T) {
|
|
s := web.NewServer()
|
|
|
|
s.Post("/", func(ctx web.Context) error {
|
|
body := ctx.Request().Body()
|
|
return ctx.String(string(body))
|
|
})
|
|
|
|
testServer(t, s, func(port string) {
|
|
// Make a POST request with a body
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest("POST", "http://127.0.0.1"+port+"/", strings.NewReader("Hello"))
|
|
if err != nil {
|
|
t.Fatalf("Error creating request: %v", err)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Error making request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
if string(body) != "Hello" {
|
|
t.Errorf("Expected body 'Hello', got '%s'", string(body))
|
|
}
|
|
})
|
|
}
|