diff --git a/send/Send.go b/send/Send.go new file mode 100644 index 0000000..4bea84b --- /dev/null +++ b/send/Send.go @@ -0,0 +1,49 @@ +package send + +import ( + "encoding/json" + + web "git.sharkk.net/Go/Web" +) + +// Sends the body with the content type set to text/css +func CSS(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/css") + return ctx.String(body) +} + +// Sends the body with the content type set to text/csv +func CSV(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/csv") + return ctx.String(body) +} + +// Sends the body with the content type set to text/html +func HTML(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/html") + return ctx.String(body) +} + +// Sends the body with the content type set to text/javascript +func JS(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/javascript") + return ctx.String(body) +} + +// Encodes the object in JSON format and sends it with the content type set to application/json +func JSON(ctx web.Context, object any) error { + ctx.Response().SetHeader("Content-Type", "application/json") + return json.NewEncoder(ctx.Response()).Encode(object) +} + +// Sends the body with the content type set to text/plain +func Text(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + return ctx.String(body) +} + +// Sends the body with the content type set to text/xml +func XML(ctx web.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/xml") + return ctx.String(body) +} diff --git a/tests/context_test.go b/tests/context_test.go new file mode 100644 index 0000000..0c766c6 --- /dev/null +++ b/tests/context_test.go @@ -0,0 +1,88 @@ +package web_tests + +import ( + "errors" + "testing" + + web "git.sharkk.net/Go/Web" +) + +func TestBytes(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Bytes([]byte("Hello")) + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + if string(response.Body()) != "Hello" { + t.Error(string(response.Body())) + } +} + +func TestString(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + if string(response.Body()) != "Hello" { + t.Error(string(response.Body())) + } +} + +func TestError(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Status(401).Error("Not logged in") + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 401 { + t.Error(response.Status()) + } + if string(response.Body()) != "" { + t.Error(string(response.Body())) + } +} + +func TestErrorMultiple(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Status(401).Error("Not logged in", errors.New("Missing auth token")) + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 401 { + t.Error(response.Status()) + } + if string(response.Body()) != "" { + t.Error(string(response.Body())) + } +} + +func TestRedirect(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Redirect(301, "/target") + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 301 { + t.Error(response.Status()) + } + if response.Header("Location") != "/target" { + t.Error(response.Header("Location")) + } +} diff --git a/tests/request_test.go b/tests/request_test.go new file mode 100644 index 0000000..420b28d --- /dev/null +++ b/tests/request_test.go @@ -0,0 +1,70 @@ +package web_tests + +import ( + "fmt" + "testing" + + web "git.sharkk.net/Go/Web" +) + +func TestRequest(t *testing.T) { + s := web.NewServer() + + s.Get("/request", func(ctx web.Context) error { + req := ctx.Request() + method := req.Method() + scheme := req.Scheme() + host := req.Host() + path := req.Path() + return ctx.String(fmt.Sprintf("%s %s %s %s", method, scheme, host, path)) + }) + + response := s.Request("GET", "http://example.com/request?x=1", []web.Header{{"Accept", "*/*"}}, nil) + if response.Status() != 200 { + t.Errorf("Error: %s", response.Body()) + } + + if string(response.Body()) != "GET http example.com /request" { + t.Errorf("Error: %s", response.Body()) + } +} + +func TestRequestHeader(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + accept := ctx.Request().Header("Accept") + empty := ctx.Request().Header("") + return ctx.String(accept + empty) + }) + + response := s.Request("GET", "/", []web.Header{{"Accept", "*/*"}}, nil) + + if response.Status() != 200 { + t.Errorf("Error: %s", response.Body()) + } + + if string(response.Body()) != "*/*" { + t.Errorf("Error: %s", response.Body()) + } +} + +func TestRequestParam(t *testing.T) { + s := web.NewServer() + + s.Get("/blog/:article", func(ctx web.Context) error { + article := ctx.Request().Param("article") + empty := ctx.Request().Param("") + return ctx.String(article + empty) + }) + + response := s.Request("GET", "/blog/my-article", nil, nil) + + if response.Status() != 200 { + t.Errorf("Error: %s", response.Body()) + } + + if string(response.Body()) != "my-article" { + t.Errorf("Error: %s", response.Body()) + } +} diff --git a/tests/response_test.go b/tests/response_test.go new file mode 100644 index 0000000..8d231ad --- /dev/null +++ b/tests/response_test.go @@ -0,0 +1,137 @@ +package web_tests + +import ( + "bytes" + "compress/gzip" + "io" + "reflect" + "testing" + + web "git.sharkk.net/Go/Web" +) + +func TestWrite(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + _, err := ctx.Response().Write([]byte("Hello")) + return err + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + if string(response.Body()) != "Hello" { + t.Error(string(response.Body())) + } +} + +func TestWriteString(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + _, err := io.WriteString(ctx.Response(), "Hello") + return err + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + if string(response.Body()) != "Hello" { + t.Error(string(response.Body())) + } +} + +func TestResponseCompression(t *testing.T) { + s := web.NewServer() + uncompressed := bytes.Repeat([]byte("This text should be compressed to a size smaller than the original."), 5) + + s.Use(func(ctx web.Context) error { + defer func() { + body := ctx.Response().Body() + ctx.Response().SetBody(nil) + zip := gzip.NewWriter(ctx.Response()) + zip.Write(body) + zip.Close() + }() + + return ctx.Next() + }) + + s.Get("/", func(ctx web.Context) error { + return ctx.Bytes(uncompressed) + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + if len(response.Body()) >= len(uncompressed) { + t.Error("Response is larger than original") + } + + reader, err := gzip.NewReader(bytes.NewReader(response.Body())) + if err != nil { + t.Error(err) + } + + decompressed, err := io.ReadAll(reader) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(decompressed, uncompressed) { + t.Error(string(decompressed)) + } +} + +func TestResponseHeader(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + contentType := ctx.Response().Header("Content-Type") + return ctx.String(contentType) + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + + if response.Header("Content-Type") != "text/plain" { + t.Error(response.Header("Content-Type")) + } + + if response.Header("Non existent header") != "" { + t.Error(response.Header("Non existent header")) + } + + if string(response.Body()) != "text/plain" { + t.Error(string(response.Body())) + } +} + +func TestResponseHeaderOverwrite(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + ctx.Response().SetHeader("Content-Type", "text/html") + return nil + }) + + response := s.Request("GET", "/", nil, nil) + if response.Status() != 200 { + t.Error(response.Status()) + } + + if response.Header("Content-Type") != "text/html" { + t.Error(response.Header("Content-Type")) + } + + if string(response.Body()) != "" { + t.Error(string(response.Body())) + } +} diff --git a/tests/send_test.go b/tests/send_test.go new file mode 100644 index 0000000..06fcd69 --- /dev/null +++ b/tests/send_test.go @@ -0,0 +1,75 @@ +package web_tests + +import ( + "testing" + + web "git.sharkk.net/Go/Web" + "git.sharkk.net/Go/Web/send" +) + +func TestContentTypes(t *testing.T) { + s := web.NewServer() + + s.Get("/css", func(ctx web.Context) error { + return send.CSS(ctx, "body{}") + }) + + s.Get("/csv", func(ctx web.Context) error { + return send.CSV(ctx, "ID;Name\n") + }) + + s.Get("/html", func(ctx web.Context) error { + return send.HTML(ctx, "") + }) + + s.Get("/js", func(ctx web.Context) error { + return send.JS(ctx, "console.log(42)") + }) + + s.Get("/json", func(ctx web.Context) error { + return send.JSON(ctx, struct{ Name string }{Name: "User 1"}) + }) + + s.Get("/text", func(ctx web.Context) error { + return send.Text(ctx, "Hello") + }) + + s.Get("/xml", func(ctx web.Context) error { + return send.XML(ctx, "") + }) + + tests := []struct { + Method string + URL string + Body string + Status int + Response string + ContentType string + }{ + {Method: "GET", URL: "/css", Status: 200, Response: "body{}", ContentType: "text/css"}, + {Method: "GET", URL: "/csv", Status: 200, Response: "ID;Name\n", ContentType: "text/csv"}, + {Method: "GET", URL: "/html", Status: 200, Response: "", ContentType: "text/html"}, + {Method: "GET", URL: "/js", Status: 200, Response: "console.log(42)", ContentType: "text/javascript"}, + {Method: "GET", URL: "/json", Status: 200, Response: "{\"Name\":\"User 1\"}\n", ContentType: "application/json"}, + {Method: "GET", URL: "/text", Status: 200, Response: "Hello", ContentType: "text/plain"}, + {Method: "GET", URL: "/xml", Status: 200, Response: "", ContentType: "text/xml"}, + } + + for _, test := range tests { + t.Run(test.URL, func(t *testing.T) { + response := s.Request(test.Method, "http://example.com"+test.URL, nil, nil) + + if response.Status() != test.Status { + t.Errorf("Expected status %d, got %d", test.Status, response.Status()) + } + + if response.Header("Content-Type") != test.ContentType { + t.Errorf("Expected content type %s, got %s", test.ContentType, response.Header("Content-Type")) + } + + if string(response.Body()) != test.Response { + t.Errorf("Expected response %s, got %s", test.Response, string(response.Body())) + } + }) + } +} diff --git a/tests/server_test.go b/tests/server_test.go new file mode 100644 index 0000000..fac0a9d --- /dev/null +++ b/tests/server_test.go @@ -0,0 +1,208 @@ +package web_tests + +import ( + "io" + "net" + "net/http" + "syscall" + "testing" + + web "git.sharkk.net/Go/Web" +) + +const port = ":8888" + +func TestRun(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + _, err := http.Get("http://127.0.0.1" + port + "/") + if err != nil { + t.Errorf("Error: %s", err) + } + }() + + s.Run(port) +} + +func TestPanic(t *testing.T) { + s := web.NewServer() + + s.Get("/panic", func(ctx web.Context) error { + panic("Something unbelievable happened") + }) + + defer func() { + r := recover() + + if r == nil { + t.Error("Didn't panic") + } + }() + + s.Request("GET", "/panic", nil, nil) +} + +func TestBadRequest(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + defer conn.Close() + + _, err = io.WriteString(conn, "BadRequest\r\n\r\n") + if err != nil { + t.Errorf("Error: %s", err) + } + + response, err := io.ReadAll(conn) + if err != nil { + t.Errorf("Error: %s", err) + } + if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" { + t.Errorf("Error: %s", string(response)) + } + }() + + s.Run(port) +} + +func TestBadRequestHeader(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + defer conn.Close() + + _, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n") + if err != nil { + t.Errorf("Error: %s", err) + } + + buffer := make([]byte, len("HTTP/1.1 200")) + _, err = conn.Read(buffer) + if err != nil { + t.Errorf("Error: %s", err) + } + if string(buffer) != "HTTP/1.1 200" { + t.Errorf("Error: %s", string(buffer)) + } + }() + + s.Run(port) +} + +func TestBadRequestMethod(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + defer conn.Close() + + _, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n") + if err != nil { + t.Errorf("Error: %s", err) + } + + response, err := io.ReadAll(conn) + if err != nil { + t.Errorf("Error: %s", err) + } + if string(response) != "HTTP/1.1 400 Bad Request\r\n\r\n" { + t.Errorf("Error: %s", string(response)) + } + }() + + s.Run(port) +} + +func TestBadRequestProtocol(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + defer conn.Close() + + _, err = io.WriteString(conn, "GET /\r\n\r\n") + if err != nil { + t.Errorf("Error: %s", err) + } + + buffer := make([]byte, len("HTTP/1.1 200")) + _, err = conn.Read(buffer) + if err != nil { + t.Errorf("Error: %s", err) + } + if string(buffer) != "HTTP/1.1 200" { + t.Errorf("Error: %s", string(buffer)) + } + }() + + s.Run(port) +} + +func TestEarlyClose(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + + _, err = io.WriteString(conn, "GET /\r\n") + if err != nil { + t.Errorf("Error: %s", err) + } + + err = conn.Close() + if err != nil { + t.Errorf("Error: %s", err) + } + }() + + s.Run(port) +} + +func TestUnavailablePort(t *testing.T) { + listener, err := net.Listen("tcp", port) + if err != nil { + t.Errorf("Error: %s", err) + } + defer listener.Close() + + s := web.NewServer() + s.Run(port) +}