diff --git a/Context.go b/Context.go new file mode 100644 index 0000000..85a0a9a --- /dev/null +++ b/Context.go @@ -0,0 +1,80 @@ +package web + +import "errors" + +// Interface for a request and its response. +type Context interface { + Bytes([]byte) error + Error(...any) error + Next() error + Redirect(int, string) error + Request() Request + Response() Response + Status(int) Context + String(string) error +} + +// Contains the request and response data. +type context struct { + request + response + server *server + handlerCount uint8 +} + +// Adds a raw byte slice to the response body. +func (ctx *context) Bytes(body []byte) error { + ctx.response.body = append(ctx.response.body, body...) + return nil +} + +// Provides a convenient way to wrap multiple errors. +func (ctx *context) Error(messages ...any) error { + var combined []error + + for _, msg := range messages { + switch err := msg.(type) { + case error: + combined = append(combined, err) + case string: + combined = append(combined, errors.New(err)) + } + } + + return errors.Join(combined...) +} + +// Executes the next handler in the middleware chain. +func (ctx *context) Next() error { + ctx.handlerCount++ + return ctx.server.handlers[ctx.handlerCount](ctx) +} + +// Redirects the client to a different location with the specified status code. +func (ctx *context) Redirect(status int, location string) error { + ctx.response.SetStatus(status) + ctx.response.SetHeader("Location", location) + return nil +} + +// Returns the HTTP request. +func (ctx *context) Request() Request { + return &ctx.request +} + +// Returns the HTTP response. +func (ctx *context) Response() Response { + return &ctx.response +} + +// Sets the HTTP status of the response and returns the context for method chaining. +func (ctx *context) Status(status int) Context { + ctx.response.SetStatus(status) + return ctx +} + +// Adds the given string to the response body. +func (ctx *context) String(body string) error { + ctx.response.body = append(ctx.response.body, body...) + return nil +} diff --git a/HTTP.go b/HTTP.go new file mode 100644 index 0000000..b27b311 --- /dev/null +++ b/HTTP.go @@ -0,0 +1,41 @@ +package web + +import "strings" + +// Returns true if the given string is a valid HTTP request method. +func isRequestMethod(method string) bool { + switch method { + case "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH": + return true + default: + return false + } +} + +// Parses a URL and returns the scheme, host, path and query. +func parseURL(url string) (scheme string, host string, path string, query string) { + schemePos := strings.Index(url, "://") + + if schemePos != -1 { + scheme = url[:schemePos] + url = url[schemePos+len("://"):] + } + + pathPos := strings.IndexByte(url, '/') + + if pathPos != -1 { + host = url[:pathPos] + url = url[pathPos:] + } + + queryPos := strings.IndexByte(url, '?') + + if queryPos != -1 { + path = url[:queryPos] + query = url[queryPos+1:] + return + } + + path = url + return +} diff --git a/Handler.go b/Handler.go new file mode 100644 index 0000000..3b3d5ab --- /dev/null +++ b/Handler.go @@ -0,0 +1,3 @@ +package web + +type Handler func(Context) error diff --git a/Header.go b/Header.go new file mode 100644 index 0000000..ae86680 --- /dev/null +++ b/Header.go @@ -0,0 +1,6 @@ +package web + +type Header struct { + Key string + Value string +} diff --git a/Request.go b/Request.go new file mode 100644 index 0000000..f8e65a6 --- /dev/null +++ b/Request.go @@ -0,0 +1,82 @@ +package web + +import ( + "bufio" + + router "git.sharkk.net/Go/Router" +) + +// Interface for HTTP requests. +type Request interface { + Header(string) string + Host() string + Method() string + Path() string + Scheme() string + Param(string) string +} + +// Represents the HTTP request used in the given context. +type request struct { + reader *bufio.Reader + scheme string + host string + method string + path string + query string + headers []Header + body []byte + params []router.Parameter +} + +// Returns the header value for the given key. +func (req *request) Header(key string) string { + for _, header := range req.headers { + if header.Key == key { + return header.Value + } + } + + return "" +} + +// Returns the requested host. +func (req *request) Host() string { + return req.host +} + +// Returns the request method. +func (req *request) Method() string { + return req.method +} + +// Retrieves a parameter. +func (req *request) Param(name string) string { + for i := range len(req.params) { + p := req.params[i] + + if p.Key == name { + return p.Value + } + } + + return "" +} + +// Returns the requested path. +func (req *request) Path() string { + return req.path +} + +// Returns either 'http', 'https', or an empty string. +func (req request) Scheme() string { + return req.scheme +} + +// Adds a new parameter to the request. +func (req *request) addParameter(key string, value string) { + req.params = append(req.params, router.Parameter{ + Key: key, + Value: value, + }) +} diff --git a/Response.go b/Response.go new file mode 100644 index 0000000..c1ae09d --- /dev/null +++ b/Response.go @@ -0,0 +1,77 @@ +package web + +import "io" + +// Interface for an HTTP response. +type Response interface { + io.Writer + io.StringWriter + Body() []byte + Header(string) string + SetHeader(key string, value string) + SetBody([]byte) + SetStatus(int) + Status() int +} + +// Represents the HTTP response used in the given context. +type response struct { + body []byte + headers []Header + status uint16 +} + +// Returns the response body. +func (res *response) Body() []byte { + return res.body +} + +// Returns the header value for the given key. +func (res *response) Header(key string) string { + for _, header := range res.headers { + if header.Key == key { + return header.Value + } + } + + return "" +} + +// Sets the header value for the given key. +func (res *response) SetHeader(key string, value string) { + for i, header := range res.headers { + if header.Key == key { + res.headers[i].Value = value + return + } + } + + res.headers = append(res.headers, Header{Key: key, Value: value}) +} + +// Replaces the response body with the new contents. +func (res *response) SetBody(body []byte) { + res.body = body +} + +// Sets the HTTP status code. +func (res *response) SetStatus(status int) { + res.status = uint16(status) +} + +// Returns the HTTP status code. +func (res *response) Status() int { + return int(res.status) +} + +// Implements the io.Writer interface for the body. +func (res *response) Write(body []byte) (int, error) { + res.body = append(res.body, body...) + return len(body), nil +} + +// Implements the io.StringWriter interface for the body. +func (res *response) WriteString(body string) (int, error) { + res.body = append(res.body, body...) + return len(body), nil +} diff --git a/Server.go b/Server.go new file mode 100644 index 0000000..28f532c --- /dev/null +++ b/Server.go @@ -0,0 +1,248 @@ +package web + +import ( + "bufio" + "bytes" + "io" + "log" + "net" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + + router "git.sharkk.net/Go/Router" +) + +// Interface for an HTTP server. +type Server interface { + Get(path string, handler Handler) + Request(method string, path string, headers []Header, body io.Reader) Response + Router() *router.Router[Handler] + Run(address string) error + Use(handlers ...Handler) +} + +// HTTP server. +type server struct { + handlers []Handler + contextPool sync.Pool + router *router.Router[Handler] + errorHandler func(Context, error) +} + +// Creates a new HTTP server. +func NewServer() Server { + r := &router.Router[Handler]{} + s := &server{ + router: r, + handlers: []Handler{ + func(c Context) error { + ctx := c.(*context) + handler := r.LookupNoAlloc(ctx.request.method, ctx.request.path, ctx.request.addParameter) + + if handler == nil { + ctx.SetStatus(404) + return nil + } + + return handler(c) + }, + }, + errorHandler: func(ctx Context, err error) { + log.Println(ctx.Request().Path(), err) + }, + } + + s.contextPool.New = func() any { return s.newContext() } + return s +} + +// Registers a handler to be called when the given GET path has been requested. +func (s *server) Get(path string, handler Handler) { + s.Router().Add("GET", path, handler) +} + +// Performs a synthetic request and returns the response. +// This function keeps the response in memory so it's slightly slower than a real request. +// However it is very useful inside tests where you don't want to spin up a real web server. +func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response { + ctx := s.newContext() + ctx.request.headers = headers + s.handleRequest(ctx, method, url, io.Discard) + return ctx.Response() +} + +// Starts the server on the given address. +func (s *server) Run(address string) error { + listener, err := net.Listen("tcp", address) + + if err != nil { + return err + } + + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + + if err != nil { + continue + } + + go s.handleConnection(conn) + } + }() + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + <-stop + return nil +} + +// Returns the router used by the server. +func (s *server) Router() *router.Router[Handler] { + return s.router +} + +// Adds handlers to your handlers chain. +func (s *server) Use(handlers ...Handler) { + last := s.handlers[len(s.handlers)-1] + s.handlers = append(s.handlers[:len(s.handlers)-1], handlers...) + s.handlers = append(s.handlers, last) +} + +// Handles an accepted connection. +func (s *server) handleConnection(conn net.Conn) { + var ( + ctx = s.contextPool.Get().(*context) + method string + url string + ) + + ctx.reader.Reset(conn) + + defer conn.Close() + defer s.contextPool.Put(ctx) + + for { + // Read the HTTP request line + message, err := ctx.reader.ReadString('\n') + + if err != nil { + return + } + + space := strings.IndexByte(message, ' ') + + if space <= 0 { + io.WriteString(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + method = message[:space] + + if !isRequestMethod(method) { + io.WriteString(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + lastSpace := strings.LastIndexByte(message, ' ') + + if lastSpace == space { + lastSpace = len(message) - len("\r\n") + } + + url = message[space+1 : lastSpace] + + // Add headers until we meet an empty line + for { + message, err = ctx.reader.ReadString('\n') + + if err != nil { + return + } + + if message == "\r\n" { + break + } + + colon := strings.IndexByte(message, ':') + + if colon <= 0 { + continue + } + + key := message[:colon] + value := message[colon+2 : len(message)-2] + + ctx.request.headers = append(ctx.request.headers, Header{ + Key: key, + Value: value, + }) + } + + // Handle the request + s.handleRequest(ctx, method, url, conn) + + // Clean up the context + ctx.request.headers = ctx.request.headers[:0] + ctx.request.body = ctx.request.body[:0] + ctx.response.headers = ctx.response.headers[:0] + ctx.response.body = ctx.response.body[:0] + ctx.params = ctx.params[:0] + ctx.handlerCount = 0 + ctx.status = 200 + } +} + +// Handles the given request. +func (s *server) handleRequest(ctx *context, method string, url string, writer io.Writer) { + ctx.method = method + ctx.scheme, ctx.host, ctx.path, ctx.query = parseURL(url) + + err := s.handlers[0](ctx) + + if err != nil { + s.errorHandler(ctx, err) + } + + tmp := bytes.Buffer{} + tmp.WriteString("HTTP/1.1 ") + tmp.WriteString(strconv.Itoa(int(ctx.status))) + tmp.WriteString("\r\nContent-Length: ") + tmp.WriteString(strconv.Itoa(len(ctx.response.body))) + tmp.WriteString("\r\n") + + for _, header := range ctx.response.headers { + tmp.WriteString(header.Key) + tmp.WriteString(": ") + tmp.WriteString(header.Value) + tmp.WriteString("\r\n") + } + + tmp.WriteString("\r\n") + tmp.Write(ctx.response.body) + writer.Write(tmp.Bytes()) +} + +// Allocates a new context with the default state. +func (s *server) newContext() *context { + return &context{ + server: s, + request: request{ + reader: bufio.NewReader(nil), + body: make([]byte, 0), + headers: make([]Header, 0, 8), + params: make([]router.Parameter, 0, 8), + }, + response: response{ + body: make([]byte, 0, 1024), + headers: make([]Header, 0, 8), + status: 200, + }, + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..20cfb33 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.sharkk.net/Go/Web + +go 1.23.0 + +require git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c55c912 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b h1:RaKQ/5Uu4oD9HNIJfMLwT8JRY8S+0MSv7OhWtX99OE0= +git.sharkk.net/Go/Router v0.0.0-20240824032014-8d9ebd32141b/go.mod h1:xUexHGjhY7bDBD3RsurM0SoUseaKxDov3jNiZRT67fA=