From f9d077690b44a0eeccacb61358d5675a723cc8ca Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 23 Aug 2024 22:36:32 -0500 Subject: [PATCH] Upload Eduard's code, with updated comments --- Context.go | 80 +++++++++++++++++ HTTP.go | 41 +++++++++ Handler.go | 3 + Header.go | 6 ++ Request.go | 82 +++++++++++++++++ Response.go | 77 ++++++++++++++++ Server.go | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 ++ go.sum | 2 + 9 files changed, 544 insertions(+) create mode 100644 Context.go create mode 100644 HTTP.go create mode 100644 Handler.go create mode 100644 Header.go create mode 100644 Request.go create mode 100644 Response.go create mode 100644 Server.go create mode 100644 go.mod create mode 100644 go.sum diff --git a/Context.go b/Context.go new file mode 100644 index 0000000..85a0a9a --- /dev/null +++ b/Context.go @@ -0,0 +1,80 @@ +package web + +import "errors" + +// Interface for a request and its response. +type Context interface { + Bytes([]byte) error + Error(...any) error + Next() error + Redirect(int, string) error + Request() Request + Response() Response + Status(int) Context + String(string) error +} + +// Contains the request and response data. +type context struct { + request + response + server *server + handlerCount uint8 +} + +// Adds a raw byte slice to the response body. +func (ctx *context) Bytes(body []byte) error { + ctx.response.body = append(ctx.response.body, body...) + return nil +} + +// Provides a convenient way to wrap multiple errors. +func (ctx *context) Error(messages ...any) error { + var combined []error + + for _, msg := range messages { + switch err := msg.(type) { + case error: + combined = append(combined, err) + case string: + combined = append(combined, errors.New(err)) + } + } + + return errors.Join(combined...) +} + +// Executes the next handler in the middleware chain. +func (ctx *context) Next() error { + ctx.handlerCount++ + return ctx.server.handlers[ctx.handlerCount](ctx) +} + +// Redirects the client to a different location with the specified status code. +func (ctx *context) Redirect(status int, location string) error { + ctx.response.SetStatus(status) + ctx.response.SetHeader("Location", location) + return nil +} + +// Returns the HTTP request. +func (ctx *context) Request() Request { + return &ctx.request +} + +// Returns the HTTP response. +func (ctx *context) Response() Response { + return &ctx.response +} + +// Sets the HTTP status of the response and returns the context for method chaining. +func (ctx *context) Status(status int) Context { + ctx.response.SetStatus(status) + return ctx +} + +// Adds the given string to the response body. +func (ctx *context) String(body string) error { + ctx.response.body = append(ctx.response.body, body...) + return nil +} diff --git a/HTTP.go b/HTTP.go new file mode 100644 index 0000000..b27b311 --- /dev/null +++ b/HTTP.go @@ -0,0 +1,41 @@ +package web + +import "strings" + +// Returns true if the given string is a valid HTTP request method. +func isRequestMethod(method string) bool { + switch method { + case "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH": + return true + default: + return false + } +} + +// Parses a URL and returns the scheme, host, path and query. +func parseURL(url string) (scheme string, host string, path string, query string) { + schemePos := strings.Index(url, "://") + + if schemePos != -1 { + scheme = url[:schemePos] + url = url[schemePos+len("://"):] + } + + pathPos := strings.IndexByte(url, '/') + + if pathPos != -1 { + host = url[:pathPos] + url = url[pathPos:] + } + + queryPos := strings.IndexByte(url, '?') + + if queryPos != -1 { + path = url[:queryPos] + query = url[queryPos+1:] + return + } + + path = url + return +} diff --git a/Handler.go b/Handler.go new file mode 100644 index 0000000..3b3d5ab --- /dev/null +++ b/Handler.go @@ -0,0 +1,3 @@ +package web + +type Handler func(Context) error diff --git a/Header.go b/Header.go new file mode 100644 index 0000000..ae86680 --- /dev/null +++ b/Header.go @@ -0,0 +1,6 @@ +package web + +type Header struct { + Key string + Value string +} diff --git a/Request.go b/Request.go new file mode 100644 index 0000000..f8e65a6 --- /dev/null +++ b/Request.go @@ -0,0 +1,82 @@ +package web + +import ( + "bufio" + + router "git.sharkk.net/Go/Router" +) + +// Interface for HTTP requests. +type Request interface { + Header(string) string + Host() string + Method() string + Path() string + Scheme() string + Param(string) string +} + +// Represents the HTTP request used in the given context. +type request struct { + reader *bufio.Reader + scheme string + host string + method string + path string + query string + headers []Header + body []byte + params []router.Parameter +} + +// Returns the header value for the given key. +func (req *request) Header(key string) string { + for _, header := range req.headers { + if header.Key == key { + return header.Value + } + } + + return "" +} + +// Returns the requested host. +func (req *request) Host() string { + return req.host +} + +// Returns the request method. +func (req *request) Method() string { + return req.method +} + +// Retrieves a parameter. +func (req *request) Param(name string) string { + for i := range len(req.params) { + p := req.params[i] + + if p.Key == name { + return p.Value + } + } + + return "" +} + +// Returns the requested path. +func (req *request) Path() string { + return req.path +} + +// Returns either 'http', 'https', or an empty string. +func (req request) Scheme() string { + return req.scheme +} + +// Adds a new parameter to the request. +func (req *request) addParameter(key string, value string) { + req.params = append(req.params, router.Parameter{ + Key: key, + Value: value, + }) +} diff --git a/Response.go b/Response.go new file mode 100644 index 0000000..c1ae09d --- /dev/null +++ b/Response.go @@ -0,0 +1,77 @@ +package web + +import "io" + +// Interface for an HTTP response. +type Response interface { + io.Writer + io.StringWriter + Body() []byte + Header(string) string + SetHeader(key string, value string) + SetBody([]byte) + SetStatus(int) + Status() int +} + +// Represents the HTTP response used in the given context. +type response struct { + body []byte + headers []Header + status uint16 +} + +// Returns the response body. +func (res *response) Body() []byte { + return res.body +} + +// Returns the header value for the given key. +func (res *response) Header(key string) string { + for _, header := range res.headers { + if header.Key == key { + return header.Value + } + } + + return "" +} + +// Sets the header value for the given key. +func (res *response) SetHeader(key string, value string) { + for i, header := range res.headers { + if header.Key == key { + res.headers[i].Value = value + return + } + } + + res.headers = append(res.headers, Header{Key: key, Value: value}) +} + +// Replaces the response body with the new contents. +func (res *response) SetBody(body []byte) { + res.body = body +} + +// Sets the HTTP status code. +func (res *response) SetStatus(status int) { + res.status = uint16(status) +} + +// Returns the HTTP status code. +func (res *response) Status() int { + return int(res.status) +} + +// Implements the io.Writer interface for the body. +func (res *response) Write(body []byte) (int, error) { + res.body = append(res.body, body...) + return len(body), nil +} + +// Implements the io.StringWriter interface for the body. +func (res *response) WriteString(body string) (int, error) { + res.body = append(res.body, body...) + return len(body), nil +} diff --git a/Server.go b/Server.go new file mode 100644 index 0000000..28f532c --- /dev/null +++ b/Server.go @@ -0,0 +1,248 @@ +package web + +import ( + "bufio" + "bytes" + "io" + "log" + "net" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + + router "git.sharkk.net/Go/Router" +) + +// Interface for an HTTP server. +type Server interface { + Get(path string, handler Handler) + Request(method string, path string, headers []Header, body io.Reader) Response + Router() *router.Router[Handler] + Run(address string) error + Use(handlers ...Handler) +} + +// HTTP server. +type server struct { + handlers []Handler + contextPool sync.Pool + router *router.Router[Handler] + errorHandler func(Context, error) +} + +// Creates a new HTTP server. +func NewServer() Server { + r := &router.Router[Handler]{} + s := &server{ + router: r, + handlers: []Handler{ + func(c Context) error { + ctx := c.(*context) + handler := r.LookupNoAlloc(ctx.request.method, ctx.request.path, ctx.request.addParameter) + + if handler == nil { + ctx.SetStatus(404) + return nil + } + + return handler(c) + }, + }, + errorHandler: func(ctx Context, err error) { + log.Println(ctx.Request().Path(), err) + }, + } + + s.contextPool.New = func() any { return s.newContext() } + return s +} + +// Registers a handler to be called when the given GET path has been requested. +func (s *server) Get(path string, handler Handler) { + s.Router().Add("GET", path, handler) +} + +// Performs a synthetic request and returns the response. +// This function keeps the response in memory so it's slightly slower than a real request. +// However it is very useful inside tests where you don't want to spin up a real web server. +func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response { + ctx := s.newContext() + ctx.request.headers = headers + s.handleRequest(ctx, method, url, io.Discard) + return ctx.Response() +} + +// Starts the server on the given address. +func (s *server) Run(address string) error { + listener, err := net.Listen("tcp", address) + + if err != nil { + return err + } + + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + + if err != nil { + continue + } + + go s.handleConnection(conn) + } + }() + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + <-stop + return nil +} + +// Returns the router used by the server. +func (s *server) Router() *router.Router[Handler] { + return s.router +} + +// Adds handlers to your handlers chain. +func (s *server) Use(handlers ...Handler) { + last := s.handlers[len(s.handlers)-1] + s.handlers = append(s.handlers[:len(s.handlers)-1], handlers...) + s.handlers = append(s.handlers, last) +} + +// Handles an accepted connection. +func (s *server) handleConnection(conn net.Conn) { + var ( + ctx = s.contextPool.Get().(*context) + method string + url string + ) + + ctx.reader.Reset(conn) + + defer conn.Close() + defer s.contextPool.Put(ctx) + + for { + // Read the HTTP request line + message, err := ctx.reader.ReadString('\n') + + if err != nil { + return + } + + space := strings.IndexByte(message, ' ') + + if space <= 0 { + io.WriteString(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + method = message[:space] + + if !isRequestMethod(method) { + io.WriteString(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + lastSpace := strings.LastIndexByte(message, ' ') + + if lastSpace == space { + lastSpace = len(message) - len("\r\n") + } + + url = message[space+1 : lastSpace] + + // Add headers until we meet an empty line + for { + message, err = ctx.reader.ReadString('\n') + + if err != nil { + return + } + + if message == "\r\n" { + break + } + + colon := strings.IndexByte(message, ':') + + if colon <= 0 { + continue + } + + key := message[:colon] + value := message[colon+2 : len(message)-2] + + ctx.request.headers = append(ctx.request.headers, Header{ + Key: key, + Value: value, + }) + } + + // Handle the request + s.handleRequest(ctx, method, url, conn) + + // Clean up the context + ctx.request.headers = ctx.request.headers[:0] + ctx.request.body = ctx.request.body[:0] + ctx.response.headers = ctx.response.headers[:0] + ctx.response.body = ctx.response.body[:0] + ctx.params = ctx.params[:0] + ctx.handlerCount = 0 + ctx.status = 200 + } +} + +// Handles the given request. +func (s *server) handleRequest(ctx *context, method string, url string, writer io.Writer) { + ctx.method = method + ctx.scheme, ctx.host, ctx.path, ctx.query = parseURL(url) + + err := s.handlers[0](ctx) + + if err != nil { + s.errorHandler(ctx, err) + } + + tmp := bytes.Buffer{} + tmp.WriteString("HTTP/1.1 ") + tmp.WriteString(strconv.Itoa(int(ctx.status))) + tmp.WriteString("\r\nContent-Length: ") + tmp.WriteString(strconv.Itoa(len(ctx.response.body))) + tmp.WriteString("\r\n") + + for _, header := range ctx.response.headers { + tmp.WriteString(header.Key) + tmp.WriteString(": ") + tmp.WriteString(header.Value) + tmp.WriteString("\r\n") + } + + tmp.WriteString("\r\n") + tmp.Write(ctx.response.body) + writer.Write(tmp.Bytes()) +} + +// Allocates a new context with the default state. +func (s *server) newContext() *context { + return &context{ + server: s, + request: request{ + reader: bufio.NewReader(nil), + body: make([]byte, 0), + headers: make([]Header, 0, 8), + params: make([]router.Parameter, 0, 8), + }, + response: response{ + body: make([]byte, 0, 1024), + headers: make([]Header, 0, 8), + status: 200, + }, + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..20cfb33 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.sharkk.net/Go/Web + +go 1.23.0 + +require git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c55c912 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b h1:RaKQ/5Uu4oD9HNIJfMLwT8JRY8S+0MSv7OhWtX99OE0= +git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b/go.mod h1:xUexHGjhY7bDBD3RsurM0SoUseaKxDov3jNiZRT67fA=