Compare commits
No commits in common. "c70840271e387a773abd6f8787a6b7155421a962" and "8815ac8f752b3a954f2df8ef27fd52452a161ed6" have entirely different histories.
c70840271e
...
8815ac8f75
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -21,7 +21,3 @@
|
|||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# Test directories and files
|
||||
config.lua
|
||||
routes/
|
||||
static/
|
||||
|
|
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -1,3 +0,0 @@
|
|||
[submodule "luajit"]
|
||||
path = luajit
|
||||
url = https://git.sharkk.net/Sky/LuaJIT-to-Go.git
|
19
LICENSE
19
LICENSE
|
@ -1,19 +0,0 @@
|
|||
## Sharkk Open License
|
||||
|
||||
### Version 1.0, March 2025
|
||||
|
||||
Copyright (c) Sharkk, Skylear Johnson
|
||||
|
||||
Hey there, code surfer! You're free to ride this wave—use, modify, and share this software however you like, as long as you stick to these chill but important rules:
|
||||
|
||||
1. **Share Your Changes**: If you tweak, remix, or build on this software, you’ve gotta share your work with the world under the same license. That means making your modified source code available in a reasonable way—like linking to a public repo. Keep the stoke alive!
|
||||
|
||||
2. **Keep This License**: Whenever you pass this software along (whether you’ve changed it or not), you need to include this license in full. No sneaky restrictions that limit the freedom to ride the digital waves.
|
||||
|
||||
3. **Give Credit Where It’s Due**: Show some love to the original author(s) by keeping the copyright notice and, if possible, linking back to the original source. Good vibes and respect go a long way.
|
||||
|
||||
4. **Make It Your Own**: If you add your own original code or features, you’re totally free to monetize those additions. Sell it, license it, or turn it into the next big thing—just keep the original parts open for everyone.
|
||||
|
||||
5. **No Guarantees**: This software comes "as is." No promises, no warranties—just pure, unfiltered code. If things go sideways, you’re riding that wave at your own risk. The authors aren’t responsible for any wipeouts.
|
||||
|
||||
By using, modifying, or sharing this software, you’re agreeing to these terms. Keep it open, keep it flowing, and most of all—have fun!
|
|
@ -1,396 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Config represents a configuration loaded from a Lua file
|
||||
type Config struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// New creates a new empty configuration
|
||||
func New() *Config {
|
||||
return &Config{
|
||||
values: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from a Lua file
|
||||
func Load(filePath string) (*Config, error) {
|
||||
// Create a new Lua state
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Execute the Lua file
|
||||
if err := state.DoFile(filePath); err != nil {
|
||||
return nil, fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
|
||||
// Create the config instance
|
||||
config := New()
|
||||
|
||||
// Extract values from the Lua state
|
||||
if err := extractGlobals(state, config.values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// extractGlobals extracts global variables from the Lua state
|
||||
func extractGlobals(state *luajit.State, values map[string]any) error {
|
||||
// Get the globals table (_G)
|
||||
state.GetGlobal("_G")
|
||||
if !state.IsTable(-1) {
|
||||
state.Pop(1)
|
||||
return errors.New("failed to get globals table")
|
||||
}
|
||||
|
||||
// Pre-populate with standard globals for reference checking
|
||||
stdGlobals := map[string]bool{
|
||||
"_G": true, "_VERSION": true, "assert": true, "collectgarbage": true,
|
||||
"coroutine": true, "debug": true, "dofile": true, "error": true,
|
||||
"getmetatable": true, "io": true, "ipairs": true, "load": true,
|
||||
"loadfile": true, "loadstring": true, "math": true, "next": true,
|
||||
"os": true, "package": true, "pairs": true, "pcall": true,
|
||||
"print": true, "rawequal": true, "rawget": true, "rawset": true,
|
||||
"require": true, "select": true, "setmetatable": true, "string": true,
|
||||
"table": true, "tonumber": true, "tostring": true, "type": true,
|
||||
"unpack": true, "xpcall": true,
|
||||
// LuaJIT specific globals
|
||||
"jit": true, "bit": true, "ffi": true, "bit32": true,
|
||||
}
|
||||
|
||||
// First, let's get the original globals to compare with user globals
|
||||
originalGlobals := make(map[string]bool)
|
||||
|
||||
// Execute empty Lua state to get standard globals
|
||||
emptyState := luajit.New()
|
||||
if emptyState != nil {
|
||||
defer emptyState.Close()
|
||||
|
||||
emptyState.GetGlobal("_G")
|
||||
emptyState.PushNil() // Start iteration
|
||||
for emptyState.Next(-2) {
|
||||
if emptyState.IsString(-2) {
|
||||
key := emptyState.ToString(-2)
|
||||
originalGlobals[key] = true
|
||||
}
|
||||
emptyState.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
emptyState.Pop(1) // Pop _G
|
||||
}
|
||||
|
||||
// Iterate through the globals table
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
|
||||
// Get key as string
|
||||
if !state.IsString(-2) {
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
continue
|
||||
}
|
||||
|
||||
key := state.ToString(-2)
|
||||
|
||||
// Skip standard Lua globals, but only if they're not overridden by user
|
||||
// (standard globals will be functions or tables, user values usually aren't)
|
||||
valueType := state.GetType(-1)
|
||||
|
||||
// Skip functions, userdata, and threads regardless of origin
|
||||
if valueType == luajit.TypeFunction || valueType == luajit.TypeUserData || valueType == luajit.TypeThread {
|
||||
state.Pop(1)
|
||||
continue
|
||||
}
|
||||
|
||||
// For known Lua globals, we need to see if they're the original or user-defined
|
||||
if stdGlobals[key] {
|
||||
// For simple value types, assume user-defined
|
||||
if valueType == luajit.TypeBoolean || valueType == luajit.TypeNumber || valueType == luajit.TypeString {
|
||||
// These are probably user values with standard names
|
||||
} else if originalGlobals[key] {
|
||||
// If it's in the original globals and not a simple type, skip it
|
||||
state.Pop(1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle primitive types directly
|
||||
switch valueType {
|
||||
case luajit.TypeBoolean:
|
||||
values[key] = state.ToBoolean(-1)
|
||||
state.Pop(1)
|
||||
continue
|
||||
case luajit.TypeNumber:
|
||||
values[key] = state.ToNumber(-1)
|
||||
state.Pop(1)
|
||||
continue
|
||||
case luajit.TypeString:
|
||||
values[key] = state.ToString(-1)
|
||||
state.Pop(1)
|
||||
continue
|
||||
case luajit.TypeTable:
|
||||
// For tables, use the existing conversion logic
|
||||
default:
|
||||
// Skip unsupported types
|
||||
state.Pop(1)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle tables (arrays and maps)
|
||||
if valueType == luajit.TypeTable {
|
||||
// Check if it looks like an array first
|
||||
arrLen := state.GetTableLength(-1)
|
||||
if arrLen > 0 {
|
||||
// Process as array
|
||||
arr := make([]any, arrLen)
|
||||
for i := 1; i <= arrLen; i++ {
|
||||
state.PushNumber(float64(i))
|
||||
state.GetTable(-2) // Get t[i]
|
||||
|
||||
switch state.GetType(-1) {
|
||||
case luajit.TypeBoolean:
|
||||
arr[i-1] = state.ToBoolean(-1)
|
||||
case luajit.TypeNumber:
|
||||
arr[i-1] = state.ToNumber(-1)
|
||||
case luajit.TypeString:
|
||||
arr[i-1] = state.ToString(-1)
|
||||
default:
|
||||
// For complex elements, try to convert
|
||||
if val, err := state.ToValue(-1); err == nil {
|
||||
arr[i-1] = val
|
||||
}
|
||||
}
|
||||
|
||||
state.Pop(1) // Pop value
|
||||
}
|
||||
values[key] = arr
|
||||
state.Pop(1)
|
||||
continue
|
||||
}
|
||||
|
||||
// Try normal table conversion for non-array tables
|
||||
if table, err := state.ToTable(-1); err == nil {
|
||||
values[key] = table
|
||||
}
|
||||
}
|
||||
|
||||
// Pop value, leave key for next iteration
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// Pop the globals table
|
||||
state.Pop(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a configuration value by key
|
||||
func (c *Config) Get(key string) any {
|
||||
return c.values[key]
|
||||
}
|
||||
|
||||
// GetString returns a string configuration value
|
||||
func (c *Config) GetString(key string, defaultValue string) string {
|
||||
value, ok := c.values[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// GetInt returns an integer configuration value
|
||||
func (c *Config) GetInt(key string, defaultValue int) int {
|
||||
value, ok := c.values[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Handle both int and float64 (which is what Lua numbers become in Go)
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case float64:
|
||||
return int(v)
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
// GetFloat returns a float configuration value
|
||||
func (c *Config) GetFloat(key string, defaultValue float64) float64 {
|
||||
value, ok := c.values[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Handle both float64 and int
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v
|
||||
case int:
|
||||
return float64(v)
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
// GetBool returns a boolean configuration value
|
||||
func (c *Config) GetBool(key string, defaultValue bool) bool {
|
||||
value, ok := c.values[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
boolValue, ok := value.(bool)
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return boolValue
|
||||
}
|
||||
|
||||
// GetMap returns a map configuration value
|
||||
func (c *Config) GetMap(key string) map[string]any {
|
||||
value, ok := c.values[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
table, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
// GetArray returns an array of values from a Lua array
|
||||
func (c *Config) GetArray(key string) []any {
|
||||
value := c.Get(key)
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct array
|
||||
if arr, ok := value.([]any); ok {
|
||||
return arr
|
||||
}
|
||||
|
||||
// Arrays in Lua might also be represented as maps with an empty string key
|
||||
valueMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The array data is stored with an empty key
|
||||
arr, ok := valueMap[""]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a float64 array (common for Lua numeric arrays)
|
||||
if floatArr, ok := arr.([]float64); ok {
|
||||
// Convert to []any
|
||||
result := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
result[i] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Otherwise, try to return as is
|
||||
anyArr, ok := arr.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return anyArr
|
||||
}
|
||||
|
||||
// GetIntArray returns an array of integers from a Lua array
|
||||
func (c *Config) GetIntArray(key string) []int {
|
||||
value := c.Get(key)
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct array case
|
||||
if arr, ok := value.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
if num, ok := v.(float64); ok {
|
||||
result = append(result, int(num))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Arrays in Lua might also be represented as maps with an empty string key
|
||||
valueMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The array data is stored with an empty key
|
||||
arr, ok := valueMap[""]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For numeric arrays, LuaJIT returns []float64
|
||||
floatArr, ok := arr.([]float64)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to int slice
|
||||
result := make([]int, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
result[i] = int(v)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetStringArray returns an array of strings from a Lua array
|
||||
func (c *Config) GetStringArray(key string) []string {
|
||||
arr := c.GetArray(key)
|
||||
if arr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
if str, ok := v.(string); ok {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Values returns all configuration values
|
||||
// Note: The returned map should not be modified
|
||||
func (c *Config) Values() map[string]any {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// Set sets a configuration value
|
||||
func (c *Config) Set(key string, value any) {
|
||||
c.values[key] = value
|
||||
}
|
|
@ -1,351 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoad verifies we can successfully load configuration values from a Lua file.
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
content := `
|
||||
-- Basic configuration values
|
||||
host = "localhost"
|
||||
port = 8080
|
||||
debug = true
|
||||
pi = 3.14159
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify values were loaded correctly
|
||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||
}
|
||||
|
||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||
t.Errorf("Expected port to be 8080, got %d", port)
|
||||
}
|
||||
|
||||
if debug := cfg.GetBool("debug", false); !debug {
|
||||
t.Errorf("Expected debug to be true")
|
||||
}
|
||||
|
||||
if pi := cfg.GetFloat("pi", 0); pi != 3.14159 {
|
||||
t.Errorf("Expected pi to be 3.14159, got %f", pi)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadErrors ensures the package properly handles loading errors.
|
||||
func TestLoadErrors(t *testing.T) {
|
||||
// Test with non-existent file
|
||||
_, err := Load("nonexistent.lua")
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading non-existent file, got nil")
|
||||
}
|
||||
|
||||
// Test with invalid Lua
|
||||
content := `
|
||||
-- This is invalid Lua
|
||||
host = "localhost
|
||||
port = 8080)
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
_, err = Load(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading invalid Lua, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalVsGlobal verifies only global variables are exported, not locals.
|
||||
func TestLocalVsGlobal(t *testing.T) {
|
||||
// Create a temporary config file with both local and global variables
|
||||
content := `
|
||||
-- Local variables should not be exported
|
||||
local local_var = "hidden"
|
||||
|
||||
-- Global variables should be exported
|
||||
global_var = "visible"
|
||||
|
||||
-- A function that uses both
|
||||
function test_func()
|
||||
return local_var .. " " .. global_var
|
||||
end
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Check that global_var exists
|
||||
if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" {
|
||||
t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar)
|
||||
}
|
||||
|
||||
// Check that local_var does not exist
|
||||
if localVar := cfg.GetString("local_var", "default"); localVar != "default" {
|
||||
t.Errorf("Expected local_var to use default, got '%s'", localVar)
|
||||
}
|
||||
|
||||
// Check that functions are not exported
|
||||
if val := cfg.Get("test_func"); val != nil {
|
||||
t.Errorf("Expected function to not be exported, got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// TestArrayHandling verifies correct handling of Lua arrays.
|
||||
func TestArrayHandling(t *testing.T) {
|
||||
// Create a temporary config file with arrays
|
||||
content := `
|
||||
-- Numeric array
|
||||
numbers = {10, 20, 30, 40, 50}
|
||||
|
||||
-- String array
|
||||
strings = {"apple", "banana", "cherry"}
|
||||
|
||||
-- Mixed array
|
||||
mixed = {1, "two", true, 4.5}
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test GetIntArray
|
||||
intArray := cfg.GetIntArray("numbers")
|
||||
expectedInts := []int{10, 20, 30, 40, 50}
|
||||
if !reflect.DeepEqual(intArray, expectedInts) {
|
||||
t.Errorf("Expected int array %v, got %v", expectedInts, intArray)
|
||||
}
|
||||
|
||||
// Test GetStringArray
|
||||
strArray := cfg.GetStringArray("strings")
|
||||
expectedStrs := []string{"apple", "banana", "cherry"}
|
||||
if !reflect.DeepEqual(strArray, expectedStrs) {
|
||||
t.Errorf("Expected string array %v, got %v", expectedStrs, strArray)
|
||||
}
|
||||
|
||||
// Test GetArray with mixed types
|
||||
mixedArray := cfg.GetArray("mixed")
|
||||
if len(mixedArray) != 4 {
|
||||
t.Errorf("Expected mixed array length 4, got %d", len(mixedArray))
|
||||
// Skip further tests if array is empty to avoid panic
|
||||
return
|
||||
}
|
||||
|
||||
// Check types - carefully to avoid panics
|
||||
if len(mixedArray) > 0 {
|
||||
if num, ok := mixedArray[0].(float64); !ok || num != 1 {
|
||||
t.Errorf("Expected first element to be 1, got %v", mixedArray[0])
|
||||
}
|
||||
}
|
||||
|
||||
if len(mixedArray) > 1 {
|
||||
if str, ok := mixedArray[1].(string); !ok || str != "two" {
|
||||
t.Errorf("Expected second element to be 'two', got %v", mixedArray[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestComplexTable tests handling of complex nested tables.
|
||||
func TestComplexTable(t *testing.T) {
|
||||
// Create a temporary config file with complex tables
|
||||
content := `
|
||||
-- Nested table structure
|
||||
server = {
|
||||
host = "localhost",
|
||||
port = 8080,
|
||||
settings = {
|
||||
timeout = 30,
|
||||
retries = 3
|
||||
}
|
||||
}
|
||||
|
||||
-- Table with mixed array and map elements
|
||||
mixed_table = {
|
||||
list = {1, 2, 3},
|
||||
mapping = {
|
||||
a = "apple",
|
||||
b = "banana"
|
||||
}
|
||||
}
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test getting nested values
|
||||
serverMap := cfg.GetMap("server")
|
||||
if serverMap == nil {
|
||||
t.Fatal("Expected server map to exist")
|
||||
}
|
||||
|
||||
// Check first level values
|
||||
if host, ok := serverMap["host"].(string); !ok || host != "localhost" {
|
||||
t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"])
|
||||
}
|
||||
|
||||
if port, ok := serverMap["port"].(float64); !ok || port != 8080 {
|
||||
t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"])
|
||||
}
|
||||
|
||||
// Check nested settings
|
||||
settings, ok := serverMap["settings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected server.settings to be a map")
|
||||
}
|
||||
|
||||
if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 {
|
||||
t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultValues verifies default values work correctly when keys don't exist.
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
content := `
|
||||
-- Just one value
|
||||
existing = "value"
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test defaults for non-existent keys
|
||||
if val := cfg.GetString("nonexistent", "default"); val != "default" {
|
||||
t.Errorf("Expected default string, got '%s'", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetInt("nonexistent", 42); val != 42 {
|
||||
t.Errorf("Expected default int 42, got %d", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 {
|
||||
t.Errorf("Expected default float 3.14, got %f", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetBool("nonexistent", true); !val {
|
||||
t.Errorf("Expected default bool true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestModifyConfig tests the ability to modify configuration values.
|
||||
func TestModifyConfig(t *testing.T) {
|
||||
// Create a config manually
|
||||
cfg := New()
|
||||
|
||||
// Set some values
|
||||
cfg.Set("host", "localhost")
|
||||
cfg.Set("port", 8080)
|
||||
|
||||
// Verify the values were set
|
||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||
}
|
||||
|
||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||
t.Errorf("Expected port to be 8080, got %d", port)
|
||||
}
|
||||
|
||||
// Modify a value
|
||||
cfg.Set("host", "127.0.0.1")
|
||||
|
||||
// Verify the change
|
||||
if host := cfg.GetString("host", ""); host != "127.0.0.1" {
|
||||
t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTypeConversion tests the type conversion in getter methods.
|
||||
func TestTypeConversion(t *testing.T) {
|
||||
// Create a temporary config file with values that need conversion
|
||||
content := `
|
||||
-- Numbers that can be integers
|
||||
int_as_float = 42.0
|
||||
|
||||
-- Floats that should remain floats
|
||||
float_val = 3.14159
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test GetInt with a float value
|
||||
if val := cfg.GetInt("int_as_float", 0); val != 42 {
|
||||
t.Errorf("Expected int 42, got %d", val)
|
||||
}
|
||||
|
||||
// Test GetFloat with an int value
|
||||
if val := cfg.GetFloat("int_as_float", 0); val != 42.0 {
|
||||
t.Errorf("Expected float 42.0, got %f", val)
|
||||
}
|
||||
|
||||
// Test incorrect type handling
|
||||
cfg.Set("string_val", "not a number")
|
||||
|
||||
if val := cfg.GetInt("string_val", 99); val != 99 {
|
||||
t.Errorf("Expected default int 99 for string value, got %d", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetFloat("string_val", 99.9); val != 99.9 {
|
||||
t.Errorf("Expected default float 99.9 for string value, got %f", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetBool("float_val", false); val != false {
|
||||
t.Errorf("Expected default false for non-bool value, got true")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a temporary Lua file with content
|
||||
func createTempLuaFile(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
|
||||
tempFile, err := os.CreateTemp("", "config-test-*.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
if _, err := tempFile.WriteString(content); err != nil {
|
||||
os.Remove(tempFile.Name())
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
os.Remove(tempFile.Name())
|
||||
t.Fatalf("Failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
return tempFile.Name()
|
||||
}
|
|
@ -1,136 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Maximum form parse size (16MB)
|
||||
const maxFormSize = 16 << 20
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrFormSizeTooLarge = errors.New("form size too large")
|
||||
ErrInvalidFormType = errors.New("invalid form content type")
|
||||
)
|
||||
|
||||
// ParseForm parses a POST request body into a map of values
|
||||
// Supports both application/x-www-form-urlencoded and multipart/form-data content types
|
||||
func ParseForm(r *http.Request) (map[string]any, error) {
|
||||
// Only handle POST, PUT, PATCH
|
||||
if r.Method != http.MethodPost &&
|
||||
r.Method != http.MethodPut &&
|
||||
r.Method != http.MethodPatch {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// Parse the media type
|
||||
mediaType, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidFormType
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch {
|
||||
case mediaType == "application/x-www-form-urlencoded":
|
||||
// Handle URL-encoded form
|
||||
if err := parseURLEncodedForm(r, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case strings.HasPrefix(mediaType, "multipart/form-data"):
|
||||
// Handle multipart form
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, ErrInvalidFormType
|
||||
}
|
||||
|
||||
if err := parseMultipartForm(r, boundary, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
// Unrecognized content type
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseURLEncodedForm handles application/x-www-form-urlencoded forms
|
||||
func parseURLEncodedForm(r *http.Request, result map[string]any) error {
|
||||
// Enforce size limit
|
||||
r.Body = http.MaxBytesReader(nil, r.Body, maxFormSize)
|
||||
|
||||
// Read the entire body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "http: request body too large") {
|
||||
return ErrFormSizeTooLarge
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse form values
|
||||
form, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert to map[string]any
|
||||
for key, values := range form {
|
||||
if len(values) == 1 {
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
// Multiple values
|
||||
result[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMultipartForm handles multipart/form-data forms
|
||||
func parseMultipartForm(r *http.Request, boundary string, result map[string]any) error {
|
||||
// Limit the form size
|
||||
if err := r.ParseMultipartForm(maxFormSize); err != nil {
|
||||
if strings.Contains(err.Error(), "http: request body too large") {
|
||||
return ErrFormSizeTooLarge
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Process form values
|
||||
for key, values := range r.MultipartForm.Value {
|
||||
if len(values) == 1 {
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
// Multiple values
|
||||
result[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
// We don't handle file uploads here - could be extended in the future
|
||||
// if needed to support file uploads to Lua
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Usage:
|
||||
// After parsing the form with ParseForm, you can add it to the context with:
|
||||
// ctx.Set("form", formData)
|
||||
//
|
||||
// This makes the form data accessible in Lua as ctx.form.field_name
|
|
@ -1,44 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
)
|
||||
|
||||
// StatusColors for different status code ranges
|
||||
const (
|
||||
colorGreen = "\033[32m" // 2xx - Success
|
||||
colorCyan = "\033[36m" // 3xx - Redirection
|
||||
colorYellow = "\033[33m" // 4xx - Client Errors
|
||||
colorRed = "\033[31m" // 5xx - Server Errors
|
||||
colorReset = "\033[0m" // Reset color
|
||||
)
|
||||
|
||||
// LogRequest logs an HTTP request with custom formatting
|
||||
func LogRequest(log *logger.Logger, statusCode int, r *http.Request, duration time.Duration) {
|
||||
statusColor := getStatusColor(statusCode)
|
||||
|
||||
// Use the logger's raw message writer to bypass the standard format
|
||||
log.LogRaw("%s [%s%d%s] %s %s (%v)",
|
||||
time.Now().Format(log.TimeFormat()),
|
||||
statusColor, statusCode, colorReset,
|
||||
r.Method, r.URL.Path, duration)
|
||||
}
|
||||
|
||||
// getStatusColor returns the ANSI color code for a status code
|
||||
func getStatusColor(code int) string {
|
||||
switch {
|
||||
case code >= 200 && code < 300:
|
||||
return colorGreen
|
||||
case code >= 300 && code < 400:
|
||||
return colorCyan
|
||||
case code >= 400 && code < 500:
|
||||
return colorYellow
|
||||
case code >= 500:
|
||||
return colorRed
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
|
||||
// Single value parameters are stored as strings.
|
||||
// Multi-value parameters are converted to []any arrays.
|
||||
func QueryToLua(r *http.Request) map[string]any {
|
||||
if r == nil || r.URL == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
if len(query) == 0 {
|
||||
return nil // Avoid allocation for empty queries
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(query))
|
||||
for key, values := range query {
|
||||
switch len(values) {
|
||||
case 0:
|
||||
// Skip empty values
|
||||
case 1:
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
default:
|
||||
// Multiple values - convert to []any
|
||||
arr := make([]any, len(values))
|
||||
for i, v := range values {
|
||||
arr[i] = v
|
||||
}
|
||||
result[key] = arr
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
"git.sharkk.net/Sky/Moonshark/core/routers"
|
||||
"git.sharkk.net/Sky/Moonshark/core/workers"
|
||||
)
|
||||
|
||||
// Server handles HTTP requests using Lua and static file routers
|
||||
type Server struct {
|
||||
luaRouter *routers.LuaRouter
|
||||
staticRouter *routers.StaticRouter
|
||||
workerPool *workers.Pool
|
||||
logger *logger.Logger
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
// New creates a new HTTP server
|
||||
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, pool *workers.Pool, log *logger.Logger) *Server {
|
||||
server := &Server{
|
||||
luaRouter: luaRouter,
|
||||
staticRouter: staticRouter,
|
||||
workerPool: pool,
|
||||
logger: log,
|
||||
httpServer: &http.Server{},
|
||||
}
|
||||
server.httpServer.Handler = server
|
||||
return server
|
||||
}
|
||||
|
||||
// ListenAndServe starts the server on the given address
|
||||
func (s *Server) ListenAndServe(addr string) error {
|
||||
s.httpServer.Addr = addr
|
||||
s.logger.Info("Server starting on %s", addr)
|
||||
return s.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.logger.Info("Server shutting down...")
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// ServeHTTP handles HTTP requests
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Wrap the ResponseWriter to capture status code
|
||||
wrappedWriter := newStatusCaptureWriter(w)
|
||||
|
||||
// Process the request
|
||||
s.handleRequest(wrappedWriter, r)
|
||||
|
||||
// Calculate request duration
|
||||
duration := time.Since(start)
|
||||
|
||||
// Get the status code
|
||||
statusCode := wrappedWriter.StatusCode()
|
||||
|
||||
// Log the request with our custom format
|
||||
LogRequest(s.logger, statusCode, r, duration)
|
||||
}
|
||||
|
||||
// handleRequest processes the actual request
|
||||
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.Debug("Processing request %s %s", r.Method, r.URL.Path)
|
||||
|
||||
// Try Lua routes first
|
||||
params := &routers.Params{}
|
||||
if bytecode, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params); found {
|
||||
s.logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count)
|
||||
s.handleLuaRoute(w, r, bytecode, params)
|
||||
return
|
||||
}
|
||||
|
||||
// Then try static files
|
||||
if filePath, found := s.staticRouter.Match(r.URL.Path); found {
|
||||
http.ServeFile(w, r, filePath)
|
||||
return
|
||||
}
|
||||
|
||||
// No route found
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
// handleLuaRoute executes a Lua route
|
||||
func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, params *routers.Params) {
|
||||
ctx := workers.NewContext()
|
||||
|
||||
// Log bytecode size
|
||||
s.logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
|
||||
|
||||
// Add request info directly to context
|
||||
ctx.Set("method", r.Method)
|
||||
ctx.Set("path", r.URL.Path)
|
||||
ctx.Set("host", r.Host)
|
||||
ctx.Set("headers", makeHeaderMap(r.Header))
|
||||
|
||||
// Add URL parameters
|
||||
if params.Count > 0 {
|
||||
paramMap := make(map[string]any, params.Count)
|
||||
for i := 0; i < params.Count; i++ {
|
||||
paramMap[params.Keys[i]] = params.Values[i]
|
||||
}
|
||||
ctx.Set("params", paramMap)
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
if queryParams := QueryToLua(r); queryParams != nil {
|
||||
ctx.Set("query", queryParams)
|
||||
}
|
||||
|
||||
// Add form data
|
||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||
if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
|
||||
ctx.Set("form", formData)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute Lua script
|
||||
result, err := s.workerPool.Submit(bytecode, ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Error executing Lua route: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeResponse(w, result, s.logger)
|
||||
}
|
||||
|
||||
// makeHeaderMap converts HTTP headers to a map
|
||||
func makeHeaderMap(header http.Header) map[string]any {
|
||||
result := make(map[string]any, len(header))
|
||||
for name, values := range header {
|
||||
if len(values) == 1 {
|
||||
result[name] = values[0]
|
||||
} else {
|
||||
result[name] = values
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// writeResponse writes the Lua result to the HTTP response
|
||||
func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||
if result == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
switch res := result.(type) {
|
||||
case string:
|
||||
// String result
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write([]byte(res))
|
||||
|
||||
case map[string]any:
|
||||
// Table result - convert to JSON
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
data, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal response: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
|
||||
default:
|
||||
// Other result types - convert to JSON
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal response: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
}
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// statusCaptureWriter is a ResponseWriter that captures the status code
|
||||
type statusCaptureWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and passes it to the wrapped ResponseWriter
|
||||
func (w *statusCaptureWriter) WriteHeader(code int) {
|
||||
w.statusCode = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// StatusCode returns the captured status code
|
||||
func (w *statusCaptureWriter) StatusCode() int {
|
||||
if w.statusCode == 0 {
|
||||
return http.StatusOK // Default to 200 if not explicitly set
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// newStatusCaptureWriter creates a new statusCaptureWriter
|
||||
func newStatusCaptureWriter(w http.ResponseWriter) *statusCaptureWriter {
|
||||
return &statusCaptureWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: 0,
|
||||
}
|
||||
}
|
|
@ -1,338 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ANSI color codes
|
||||
const (
|
||||
colorReset = "\033[0m"
|
||||
colorRed = "\033[31m"
|
||||
colorGreen = "\033[32m"
|
||||
colorYellow = "\033[33m"
|
||||
colorBlue = "\033[34m"
|
||||
colorPurple = "\033[35m"
|
||||
colorCyan = "\033[36m"
|
||||
colorWhite = "\033[37m"
|
||||
)
|
||||
|
||||
// Log levels
|
||||
const (
|
||||
LevelDebug = iota
|
||||
LevelInfo
|
||||
LevelWarning
|
||||
LevelError
|
||||
LevelFatal
|
||||
)
|
||||
|
||||
// Level names and colors
|
||||
var levelProps = map[int]struct {
|
||||
tag string
|
||||
color string
|
||||
}{
|
||||
LevelDebug: {"DBG", colorCyan},
|
||||
LevelInfo: {"INF", colorBlue},
|
||||
LevelWarning: {"WRN", colorYellow},
|
||||
LevelError: {"ERR", colorRed},
|
||||
LevelFatal: {"FTL", colorPurple},
|
||||
}
|
||||
|
||||
// Time format for log messages
|
||||
const timeFormat = "15:04:05"
|
||||
|
||||
// logMessage represents a message to be logged
|
||||
type logMessage struct {
|
||||
level int
|
||||
message string
|
||||
rawMode bool // Indicates if raw formatting should be used
|
||||
}
|
||||
|
||||
// Logger handles logging operations
|
||||
type Logger struct {
|
||||
writer io.Writer
|
||||
messages chan logMessage
|
||||
wg sync.WaitGroup
|
||||
level int
|
||||
useColors bool
|
||||
done chan struct{}
|
||||
timeFormat string
|
||||
mu sync.Mutex // Mutex for thread-safe writing
|
||||
}
|
||||
|
||||
// New creates a new logger
|
||||
func New(minLevel int, useColors bool) *Logger {
|
||||
l := &Logger{
|
||||
writer: os.Stdout,
|
||||
messages: make(chan logMessage, 100), // Buffer 100 messages
|
||||
level: minLevel,
|
||||
useColors: useColors,
|
||||
done: make(chan struct{}),
|
||||
timeFormat: timeFormat,
|
||||
}
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.processLogs()
|
||||
return l
|
||||
}
|
||||
|
||||
// SetOutput changes the output destination
|
||||
func (l *Logger) SetOutput(w io.Writer) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.writer = w
|
||||
}
|
||||
|
||||
// TimeFormat returns the current time format
|
||||
func (l *Logger) TimeFormat() string {
|
||||
return l.timeFormat
|
||||
}
|
||||
|
||||
// SetTimeFormat changes the time format string
|
||||
func (l *Logger) SetTimeFormat(format string) {
|
||||
l.timeFormat = format
|
||||
}
|
||||
|
||||
// SetLevel changes the minimum log level
|
||||
func (l *Logger) SetLevel(level int) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
// EnableColors enables ANSI color codes in the output
|
||||
func (l *Logger) EnableColors() {
|
||||
l.useColors = true
|
||||
}
|
||||
|
||||
// DisableColors disables ANSI color codes in the output
|
||||
func (l *Logger) DisableColors() {
|
||||
l.useColors = false
|
||||
}
|
||||
|
||||
// processLogs processes incoming log messages
|
||||
func (l *Logger) processLogs() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.messages:
|
||||
if msg.level >= l.level {
|
||||
l.writeMessage(msg)
|
||||
}
|
||||
case <-l.done:
|
||||
// Process remaining messages
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.messages:
|
||||
if msg.level >= l.level {
|
||||
l.writeMessage(msg)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeMessage writes a formatted log message
|
||||
func (l *Logger) writeMessage(msg logMessage) {
|
||||
var logLine string
|
||||
|
||||
if msg.rawMode {
|
||||
// Raw mode - message is already formatted, just append newline
|
||||
logLine = msg.message + "\n"
|
||||
} else {
|
||||
// Standard format with timestamp, level tag, and message
|
||||
now := time.Now().Format(l.timeFormat)
|
||||
props := levelProps[msg.level]
|
||||
|
||||
if l.useColors {
|
||||
logLine = fmt.Sprintf("%s %s[%s]%s %s\n",
|
||||
now, props.color, props.tag, colorReset, msg.message)
|
||||
} else {
|
||||
logLine = fmt.Sprintf("%s [%s] %s\n",
|
||||
now, props.tag, msg.message)
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronize writing
|
||||
l.mu.Lock()
|
||||
_, _ = fmt.Fprint(l.writer, logLine)
|
||||
l.mu.Unlock()
|
||||
|
||||
// Auto-flush for fatal errors
|
||||
if msg.level == LevelFatal {
|
||||
if f, ok := l.writer.(*os.File); ok {
|
||||
_ = f.Sync()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// log sends a message to the logger goroutine
|
||||
func (l *Logger) log(level int, format string, args ...any) {
|
||||
if level < l.level {
|
||||
return
|
||||
}
|
||||
|
||||
var message string
|
||||
if len(args) > 0 {
|
||||
message = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
message = format
|
||||
}
|
||||
|
||||
// Don't block if channel is full
|
||||
select {
|
||||
case l.messages <- logMessage{level: level, message: message, rawMode: false}:
|
||||
// Message sent
|
||||
default:
|
||||
// Channel full, write directly
|
||||
l.writeMessage(logMessage{level: level, message: message, rawMode: false})
|
||||
}
|
||||
|
||||
// Exit on fatal errors
|
||||
if level == LevelFatal {
|
||||
l.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// LogRaw logs a message with raw formatting, bypassing the standard format
|
||||
func (l *Logger) LogRaw(format string, args ...any) {
|
||||
// Use info level for filtering
|
||||
if LevelInfo < l.level {
|
||||
return
|
||||
}
|
||||
|
||||
var message string
|
||||
if len(args) > 0 {
|
||||
message = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
message = format
|
||||
}
|
||||
|
||||
// Don't apply colors if disabled
|
||||
if !l.useColors {
|
||||
// Strip ANSI color codes if colors are disabled
|
||||
// Simple approach to strip common ANSI codes
|
||||
message = removeAnsiColors(message)
|
||||
}
|
||||
|
||||
// Don't block if channel is full
|
||||
select {
|
||||
case l.messages <- logMessage{level: LevelInfo, message: message, rawMode: true}:
|
||||
// Message sent
|
||||
default:
|
||||
// Channel full, write directly
|
||||
l.writeMessage(logMessage{level: LevelInfo, message: message, rawMode: true})
|
||||
}
|
||||
}
|
||||
|
||||
// Simple helper to remove ANSI color codes
|
||||
func removeAnsiColors(s string) string {
|
||||
result := ""
|
||||
inEscape := false
|
||||
|
||||
for _, c := range s {
|
||||
if inEscape {
|
||||
if c == 'm' {
|
||||
inEscape = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\033' {
|
||||
inEscape = true
|
||||
continue
|
||||
}
|
||||
|
||||
result += string(c)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
|
||||
// Info logs an informational message
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
|
||||
// Warning logs a warning message
|
||||
func (l *Logger) Warning(format string, args ...any) {
|
||||
l.log(LevelWarning, format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal error message and exits
|
||||
func (l *Logger) Fatal(format string, args ...any) {
|
||||
l.log(LevelFatal, format, args...)
|
||||
// No need for os.Exit here as it's handled in log()
|
||||
}
|
||||
|
||||
// Close shuts down the logger goroutine
|
||||
func (l *Logger) Close() {
|
||||
close(l.done)
|
||||
l.wg.Wait()
|
||||
close(l.messages)
|
||||
}
|
||||
|
||||
// Default global logger
|
||||
var defaultLogger = New(LevelInfo, true)
|
||||
|
||||
// Debug logs a debug message to the default logger
|
||||
func Debug(format string, args ...any) {
|
||||
defaultLogger.Debug(format, args...)
|
||||
}
|
||||
|
||||
// Info logs an informational message to the default logger
|
||||
func Info(format string, args ...any) {
|
||||
defaultLogger.Info(format, args...)
|
||||
}
|
||||
|
||||
// Warning logs a warning message to the default logger
|
||||
func Warning(format string, args ...any) {
|
||||
defaultLogger.Warning(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message to the default logger
|
||||
func Error(format string, args ...any) {
|
||||
defaultLogger.Error(format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal error message to the default logger and exits
|
||||
func Fatal(format string, args ...any) {
|
||||
defaultLogger.Fatal(format, args...)
|
||||
}
|
||||
|
||||
// LogRaw logs a raw message to the default logger
|
||||
func LogRaw(format string, args ...any) {
|
||||
defaultLogger.LogRaw(format, args...)
|
||||
}
|
||||
|
||||
// SetLevel changes the minimum log level of the default logger
|
||||
func SetLevel(level int) {
|
||||
defaultLogger.SetLevel(level)
|
||||
}
|
||||
|
||||
// SetOutput changes the output destination of the default logger
|
||||
func SetOutput(w io.Writer) {
|
||||
defaultLogger.SetOutput(w)
|
||||
}
|
||||
|
||||
// Close shuts down the default logger
|
||||
func Close() {
|
||||
defaultLogger.Close()
|
||||
}
|
|
@ -1,172 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoggerLevels(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Debug should be below threshold
|
||||
logger.Debug("This should not appear")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if buf.Len() > 0 {
|
||||
t.Error("Debug message appeared when it should be filtered")
|
||||
}
|
||||
|
||||
// Info and above should appear
|
||||
logger.Info("Info message")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if !strings.Contains(buf.String(), "[INF]") {
|
||||
t.Errorf("Info message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
logger.Warning("Warning message")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if !strings.Contains(buf.String(), "[WRN]") {
|
||||
t.Errorf("Warning message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
logger.Error("Error message")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if !strings.Contains(buf.String(), "[ERR]") {
|
||||
t.Errorf("Error message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test format strings
|
||||
logger.Info("Count: %d", 42)
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if !strings.Contains(buf.String(), "Count: 42") {
|
||||
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test changing level
|
||||
logger.SetLevel(LevelError)
|
||||
logger.Info("This should not appear")
|
||||
logger.Warning("This should not appear")
|
||||
if buf.Len() > 0 {
|
||||
t.Error("Messages below threshold appeared")
|
||||
}
|
||||
|
||||
logger.Error("Error should appear")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
if !strings.Contains(buf.String(), "[ERR]") {
|
||||
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
||||
}
|
||||
|
||||
logger.Close()
|
||||
}
|
||||
|
||||
func TestLoggerConcurrency(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelDebug, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Log a bunch of messages concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
logger.Info("Concurrent message %d", n)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Wait for processing
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Check all messages were logged
|
||||
content := buf.String()
|
||||
for i := 0; i < 100; i++ {
|
||||
msg := "Concurrent message " + strconv.Itoa(i)
|
||||
if !strings.Contains(content, msg) && !strings.Contains(content, "Concurrent message") {
|
||||
t.Errorf("Missing concurrent messages")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
logger.Close()
|
||||
}
|
||||
|
||||
func TestLoggerColors(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, true)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Test with color
|
||||
logger.Info("Colored message")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
|
||||
content := buf.String()
|
||||
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
||||
if !strings.Contains(content, "\033[") {
|
||||
t.Errorf("Color codes not present when enabled, got: %q", content)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
logger.DisableColors()
|
||||
logger.Info("Non-colored message")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
|
||||
content = buf.String()
|
||||
if strings.Contains(content, "\033[") {
|
||||
t.Errorf("Color codes present when disabled, got: %q", content)
|
||||
}
|
||||
|
||||
logger.Close()
|
||||
}
|
||||
|
||||
func TestDefaultLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
SetOutput(&buf)
|
||||
|
||||
Info("Test default logger")
|
||||
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||
|
||||
content := buf.String()
|
||||
if !strings.Contains(content, "[INF]") {
|
||||
t.Errorf("Default logger not working, got: %q", content)
|
||||
}
|
||||
|
||||
Close()
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Info("Benchmark message %d", i)
|
||||
}
|
||||
logger.Close()
|
||||
}
|
||||
|
||||
func BenchmarkLoggerParallel(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
logger.Info("Parallel benchmark message %d", i)
|
||||
i++
|
||||
}
|
||||
})
|
||||
logger.Close()
|
||||
}
|
|
@ -1,276 +0,0 @@
|
|||
package routers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Maximum number of URL parameters per route
|
||||
const maxParams = 20
|
||||
|
||||
// LuaRouter is a filesystem-based HTTP router for Lua files
|
||||
type LuaRouter struct {
|
||||
routesDir string // Root directory containing route files
|
||||
routes map[string]*node // Method -> route tree
|
||||
mu sync.RWMutex // Lock for concurrent access to routes
|
||||
}
|
||||
|
||||
// node represents a node in the routing trie
|
||||
type node struct {
|
||||
handler string // Path to Lua file (empty if not an endpoint)
|
||||
bytecode []byte // Pre-compiled Lua bytecode
|
||||
paramName string // Parameter name (if this is a parameter node)
|
||||
staticChild map[string]*node // Static children by segment name
|
||||
paramChild *node // Parameter/wildcard child
|
||||
}
|
||||
|
||||
// Params holds URL parameters with fixed-size arrays to avoid allocations
|
||||
type Params struct {
|
||||
Keys [maxParams]string
|
||||
Values [maxParams]string
|
||||
Count int
|
||||
}
|
||||
|
||||
// Get returns a parameter value by name
|
||||
func (p *Params) Get(name string) string {
|
||||
for i := 0; i < p.Count; i++ {
|
||||
if p.Keys[i] == name {
|
||||
return p.Values[i]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewLuaRouter creates a new LuaRouter instance
|
||||
func NewLuaRouter(routesDir string) (*LuaRouter, error) {
|
||||
// Verify routes directory exists
|
||||
info, err := os.Stat(routesDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, errors.New("routes path is not a directory")
|
||||
}
|
||||
|
||||
r := &LuaRouter{
|
||||
routesDir: routesDir,
|
||||
routes: make(map[string]*node),
|
||||
}
|
||||
|
||||
// Initialize method trees
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
||||
for _, method := range methods {
|
||||
r.routes[method] = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
|
||||
// Build routes
|
||||
if err := r.buildRoutes(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// buildRoutes scans the routes directory and builds the routing tree
|
||||
func (r *LuaRouter) buildRoutes() error {
|
||||
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only process .lua files
|
||||
if !strings.HasSuffix(info.Name(), ".lua") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract method from filename
|
||||
method := strings.ToUpper(strings.TrimSuffix(info.Name(), ".lua"))
|
||||
|
||||
// Check if valid method
|
||||
root, exists := r.routes[method]
|
||||
if !exists {
|
||||
return nil // Skip invalid methods
|
||||
}
|
||||
|
||||
// Get relative path for URL
|
||||
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build URL path
|
||||
urlPath := "/"
|
||||
if relDir != "." {
|
||||
urlPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
||||
}
|
||||
|
||||
// Add route to tree
|
||||
return r.addRoute(root, urlPath, path)
|
||||
})
|
||||
}
|
||||
|
||||
// addRoute adds a route to the routing tree and compiles the Lua file to bytecode
|
||||
func (r *LuaRouter) addRoute(root *node, urlPath, handlerPath string) error {
|
||||
segments := strings.Split(strings.Trim(urlPath, "/"), "/")
|
||||
current := root
|
||||
|
||||
for _, segment := range segments {
|
||||
if len(segment) >= 2 && segment[0] == '[' && segment[len(segment)-1] == ']' {
|
||||
if current.paramChild == nil {
|
||||
current.paramChild = &node{
|
||||
paramName: segment[1 : len(segment)-1],
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
current = current.paramChild
|
||||
} else {
|
||||
// Create or get static child
|
||||
child, exists := current.staticChild[segment]
|
||||
if !exists {
|
||||
child = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
current.staticChild[segment] = child
|
||||
}
|
||||
current = child
|
||||
}
|
||||
}
|
||||
|
||||
// Set handler path
|
||||
current.handler = handlerPath
|
||||
|
||||
// Compile Lua file to bytecode
|
||||
if err := r.compileHandler(current); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Match finds a handler for the given method and path
|
||||
// Uses the pre-allocated params struct to avoid allocations
|
||||
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
|
||||
// Reset params
|
||||
params.Count = 0
|
||||
|
||||
// Get route tree for method
|
||||
r.mu.RLock()
|
||||
root, exists := r.routes[method]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Split path
|
||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
||||
|
||||
// Match path
|
||||
return r.matchPath(root, segments, params, 0)
|
||||
}
|
||||
|
||||
// matchPath recursively matches a path against the routing tree
|
||||
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
|
||||
// Base case: no more segments
|
||||
if len(segments) == 0 {
|
||||
if current.handler != "" {
|
||||
return current, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
segment := segments[0]
|
||||
remaining := segments[1:]
|
||||
|
||||
// Try static child first (exact match takes precedence)
|
||||
if child, exists := current.staticChild[segment]; exists {
|
||||
if node, found := r.matchPath(child, remaining, params, depth+1); found {
|
||||
return node, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try parameter child
|
||||
if current.paramChild != nil {
|
||||
// Store parameter
|
||||
if params.Count < maxParams {
|
||||
params.Keys[params.Count] = current.paramChild.paramName
|
||||
params.Values[params.Count] = segment
|
||||
params.Count++
|
||||
}
|
||||
|
||||
if node, found := r.matchPath(current.paramChild, remaining, params, depth+1); found {
|
||||
return node, true
|
||||
}
|
||||
|
||||
// Backtrack: remove parameter if no match
|
||||
params.Count--
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// compileHandler compiles a Lua file to bytecode
|
||||
func (r *LuaRouter) compileHandler(n *node) error {
|
||||
if n.handler == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the Lua file
|
||||
content, err := os.ReadFile(n.handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compile to bytecode
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return errors.New("failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(string(content), n.handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store bytecode in the node
|
||||
n.bytecode = bytecode
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBytecode returns the compiled bytecode for a matched route
|
||||
func (r *LuaRouter) GetBytecode(method, path string, params *Params) ([]byte, bool) {
|
||||
node, found := r.Match(method, path, params)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
return node.bytecode, true
|
||||
}
|
||||
|
||||
// Refresh rebuilds the router by rescanning the routes directory
|
||||
func (r *LuaRouter) Refresh() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Reset routes
|
||||
for method := range r.routes {
|
||||
r.routes[method] = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild routes
|
||||
return r.buildRoutes()
|
||||
}
|
|
@ -1,256 +0,0 @@
|
|||
package routers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupTestRoutes(t *testing.T) (string, func()) {
|
||||
// Create a temporary directory for test routes
|
||||
tempDir, err := os.MkdirTemp("", "fsrouter-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
// Create route structure with valid Lua code
|
||||
routes := map[string]string{
|
||||
"get.lua": "return { path = '/' }",
|
||||
"post.lua": "return { path = '/' }",
|
||||
"api/get.lua": "return { path = '/api' }",
|
||||
"api/users/get.lua": "return { path = '/api/users' }",
|
||||
"api/users/[id]/get.lua": "return { path = '/api/users/[id]' }",
|
||||
"api/users/[id]/posts/get.lua": "return { path = '/api/users/[id]/posts' }",
|
||||
"api/[version]/docs/get.lua": "return { path = '/api/[version]/docs' }",
|
||||
}
|
||||
|
||||
for path, content := range routes {
|
||||
routePath := filepath.Join(tempDir, path)
|
||||
|
||||
// Create directories
|
||||
err := os.MkdirAll(filepath.Dir(routePath), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create directory %s: %v", filepath.Dir(routePath), err)
|
||||
}
|
||||
|
||||
// Create file
|
||||
err = os.WriteFile(routePath, []byte(content), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file %s: %v", routePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
cleanup := func() {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return tempDir, cleanup
|
||||
}
|
||||
|
||||
func TestRouterInitialization(t *testing.T) {
|
||||
routesDir, cleanup := setupTestRoutes(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
if router == nil {
|
||||
t.Fatal("Router is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteMatching(t *testing.T) {
|
||||
routesDir, cleanup := setupTestRoutes(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
wantFound bool
|
||||
wantParams map[string]string
|
||||
wantHandler string
|
||||
}{
|
||||
// Static routes
|
||||
{"GET", "/", true, nil, filepath.Join(routesDir, "get.lua")},
|
||||
{"POST", "/", true, nil, filepath.Join(routesDir, "post.lua")},
|
||||
{"GET", "/api", true, nil, filepath.Join(routesDir, "api/get.lua")},
|
||||
{"GET", "/api/users", true, nil, filepath.Join(routesDir, "api/users/get.lua")},
|
||||
|
||||
// Parameterized routes
|
||||
{"GET", "/api/users/123", true, map[string]string{"id": "123"}, filepath.Join(routesDir, "api/users/[id]/get.lua")},
|
||||
{"GET", "/api/users/456/posts", true, map[string]string{"id": "456"}, filepath.Join(routesDir, "api/users/[id]/posts/get.lua")},
|
||||
{"GET", "/api/v1/docs", true, map[string]string{"version": "v1"}, filepath.Join(routesDir, "api/[version]/docs/get.lua")},
|
||||
|
||||
// Non-existent routes
|
||||
{"PUT", "/", false, nil, ""},
|
||||
{"GET", "/nonexistent", false, nil, ""},
|
||||
{"GET", "/api/nonexistent", false, nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.method+" "+tt.path, func(t *testing.T) {
|
||||
var params Params
|
||||
node, found := router.Match(tt.method, tt.path, ¶ms)
|
||||
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("Match() found = %v, want %v", found, tt.wantFound)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if node.handler != tt.wantHandler {
|
||||
t.Errorf("Match() handler = %v, want %v", node.handler, tt.wantHandler)
|
||||
}
|
||||
|
||||
// Verify bytecode was compiled
|
||||
if len(node.bytecode) == 0 {
|
||||
t.Errorf("No bytecode found for handler: %s", node.handler)
|
||||
}
|
||||
|
||||
// Verify parameters
|
||||
if tt.wantParams != nil {
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue := params.Get(key)
|
||||
if gotValue != wantValue {
|
||||
t.Errorf("Parameter %s = %s, want %s", key, gotValue, wantValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamExtraction(t *testing.T) {
|
||||
routesDir, cleanup := setupTestRoutes(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
var params Params
|
||||
_, found := router.Match("GET", "/api/v2/docs", ¶ms)
|
||||
|
||||
if !found {
|
||||
t.Fatalf("Route not found")
|
||||
}
|
||||
|
||||
if params.Count != 1 {
|
||||
t.Errorf("Expected 1 parameter, got %d", params.Count)
|
||||
}
|
||||
|
||||
if params.Keys[0] != "version" {
|
||||
t.Errorf("Expected parameter key 'version', got '%s'", params.Keys[0])
|
||||
}
|
||||
|
||||
if params.Values[0] != "v2" {
|
||||
t.Errorf("Expected parameter value 'v2', got '%s'", params.Values[0])
|
||||
}
|
||||
|
||||
if params.Get("version") != "v2" {
|
||||
t.Errorf("Get(\"version\") returned '%s', expected 'v2'", params.Get("version"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBytecode(t *testing.T) {
|
||||
routesDir, cleanup := setupTestRoutes(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
var params Params
|
||||
bytecode, found := router.GetBytecode("GET", "/api/users/123", ¶ms)
|
||||
|
||||
if !found {
|
||||
t.Fatalf("Route not found")
|
||||
}
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
t.Errorf("Expected non-empty bytecode")
|
||||
}
|
||||
|
||||
// Check parameters were extracted
|
||||
if params.Get("id") != "123" {
|
||||
t.Errorf("Expected id parameter '123', got '%s'", params.Get("id"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
routesDir, cleanup := setupTestRoutes(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
// Add a new route file
|
||||
newRoutePath := filepath.Join(routesDir, "new", "get.lua")
|
||||
err = os.MkdirAll(filepath.Dir(newRoutePath), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(newRoutePath, []byte("return { path = '/new' }"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file: %v", err)
|
||||
}
|
||||
|
||||
// Before refresh, route should not be found
|
||||
var params Params
|
||||
_, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
if found {
|
||||
t.Errorf("New route should not be found before refresh")
|
||||
}
|
||||
|
||||
// Refresh router
|
||||
err = router.Refresh()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to refresh router: %v", err)
|
||||
}
|
||||
|
||||
// After refresh, route should be found
|
||||
bytecode, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
if !found {
|
||||
t.Errorf("New route should be found after refresh")
|
||||
}
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
t.Errorf("Expected non-empty bytecode for new route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidRoutesDir(t *testing.T) {
|
||||
// Non-existent directory
|
||||
_, err := NewLuaRouter("/non/existent/directory")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent directory, got nil")
|
||||
}
|
||||
|
||||
// Create a file instead of a directory
|
||||
tmpFile, err := os.CreateTemp("", "fsrouter-test-file")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
defer tmpFile.Close()
|
||||
|
||||
_, err = NewLuaRouter(tmpFile.Name())
|
||||
if err == nil {
|
||||
t.Error("Expected error for file as routes dir, got nil")
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package routers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// StaticRouter is a filesystem-based router for static files
|
||||
type StaticRouter struct {
|
||||
rootDir string // Root directory containing files
|
||||
routes map[string]string // Direct mapping from URL path to file path
|
||||
mu sync.RWMutex // Lock for concurrent access to routes
|
||||
}
|
||||
|
||||
// NewStaticRouter creates a new StaticRouter instance
|
||||
func NewStaticRouter(rootDir string) (*StaticRouter, error) {
|
||||
// Verify root directory exists
|
||||
info, err := os.Stat(rootDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, errors.New("root path is not a directory")
|
||||
}
|
||||
|
||||
r := &StaticRouter{
|
||||
rootDir: rootDir,
|
||||
routes: make(map[string]string),
|
||||
}
|
||||
|
||||
// Build routes
|
||||
if err := r.buildRoutes(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// buildRoutes scans the root directory and builds the routing map
|
||||
func (r *StaticRouter) buildRoutes() error {
|
||||
return filepath.Walk(r.rootDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get relative path for URL
|
||||
relPath, err := filepath.Rel(r.rootDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert to URL path with forward slashes for consistency
|
||||
urlPath := "/" + strings.ReplaceAll(relPath, "\\", "/")
|
||||
|
||||
// Add to routes map
|
||||
r.routes[urlPath] = path
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Match finds a file path for the given URL path
|
||||
func (r *StaticRouter) Match(path string) (string, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
filePath, found := r.routes[path]
|
||||
return filePath, found
|
||||
}
|
||||
|
||||
// Refresh rebuilds the router by rescanning the root directory
|
||||
func (r *StaticRouter) Refresh() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Clear routes
|
||||
r.routes = make(map[string]string)
|
||||
|
||||
// Rebuild routes
|
||||
return r.buildRoutes()
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
package routers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupStaticFiles(t *testing.T) (string, func()) {
|
||||
// Create a temporary directory
|
||||
tempDir, err := os.MkdirTemp("", "staticrouter-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
// Create file structure
|
||||
files := map[string]string{
|
||||
"index.html": "<html>Home</html>",
|
||||
"about.html": "<html>About</html>",
|
||||
"api/index.json": `{"version": "1.0"}`,
|
||||
"users/index.html": "<html>Users</html>",
|
||||
"users/123/profile.html": "<html>User Profile</html>",
|
||||
"posts/hello-world/comments.html": "<html>Post Comments</html>",
|
||||
"docs/v1/api.html": "<html>API Docs</html>",
|
||||
}
|
||||
|
||||
for path, content := range files {
|
||||
filePath := filepath.Join(tempDir, path)
|
||||
|
||||
// Create directories
|
||||
err := os.MkdirAll(filepath.Dir(filePath), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create directory %s: %v", filepath.Dir(filePath), err)
|
||||
}
|
||||
|
||||
// Create file
|
||||
err = os.WriteFile(filePath, []byte(content), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file %s: %v", filePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
cleanup := func() {
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return tempDir, cleanup
|
||||
}
|
||||
|
||||
func TestStaticRouterInitialization(t *testing.T) {
|
||||
rootDir, cleanup := setupStaticFiles(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewStaticRouter(rootDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
if router == nil {
|
||||
t.Fatal("Router is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticRouteMatching(t *testing.T) {
|
||||
rootDir, cleanup := setupStaticFiles(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewStaticRouter(rootDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
wantFound bool
|
||||
wantHandler string
|
||||
}{
|
||||
{"/index.html", true, filepath.Join(rootDir, "index.html")},
|
||||
{"/about.html", true, filepath.Join(rootDir, "about.html")},
|
||||
{"/api/index.json", true, filepath.Join(rootDir, "api/index.json")},
|
||||
{"/users/index.html", true, filepath.Join(rootDir, "users/index.html")},
|
||||
{"/users/123/profile.html", true, filepath.Join(rootDir, "users/123/profile.html")},
|
||||
{"/posts/hello-world/comments.html", true, filepath.Join(rootDir, "posts/hello-world/comments.html")},
|
||||
{"/docs/v1/api.html", true, filepath.Join(rootDir, "docs/v1/api.html")},
|
||||
|
||||
// Non-existent routes
|
||||
{"/nonexistent.html", false, ""},
|
||||
{"/api/nonexistent.json", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
filePath, found := router.Match(tt.path)
|
||||
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("Match() found = %v, want %v", found, tt.wantFound)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if filePath != tt.wantHandler {
|
||||
t.Errorf("Match() handler = %v, want %v", filePath, tt.wantHandler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//TestStaticParamExtraction has been removed since we no longer extract parameters
|
||||
|
||||
func TestStaticRefresh(t *testing.T) {
|
||||
rootDir, cleanup := setupStaticFiles(t)
|
||||
defer cleanup()
|
||||
|
||||
router, err := NewStaticRouter(rootDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create router: %v", err)
|
||||
}
|
||||
|
||||
// Add a new file
|
||||
newFilePath := filepath.Join(rootDir, "new.html")
|
||||
err = os.WriteFile(newFilePath, []byte("<html>New</html>"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file: %v", err)
|
||||
}
|
||||
|
||||
// Before refresh, file should not be found
|
||||
_, found := router.Match("/new.html")
|
||||
if found {
|
||||
t.Errorf("New file should not be found before refresh")
|
||||
}
|
||||
|
||||
// Refresh router
|
||||
err = router.Refresh()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to refresh router: %v", err)
|
||||
}
|
||||
|
||||
// After refresh, file should be found
|
||||
filePath, found := router.Match("/new.html")
|
||||
if !found {
|
||||
t.Errorf("New file should be found after refresh")
|
||||
}
|
||||
|
||||
if filePath != newFilePath {
|
||||
t.Errorf("Expected path %s, got %s", newFilePath, filePath)
|
||||
}
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// EnsureDir checks if a directory exists and creates it if it doesn't.
|
||||
// Returns any error encountered during directory creation.
|
||||
func EnsureDir(path string) error {
|
||||
// Clean the path to handle any malformed input
|
||||
path = filepath.Clean(path)
|
||||
|
||||
// Check if the directory exists
|
||||
info, err := os.Stat(path)
|
||||
|
||||
// If no error, check if it's a directory
|
||||
if err == nil {
|
||||
if info.IsDir() {
|
||||
return nil // Directory already exists
|
||||
}
|
||||
return os.ErrExist // Path exists but is not a directory
|
||||
}
|
||||
|
||||
// If the error is not that the path doesn't exist, return it
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the directory with default permissions (0755)
|
||||
return os.MkdirAll(path, 0755)
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
# Worker Pool
|
||||
|
||||
### Pool
|
||||
|
||||
```go
|
||||
type Pool struct { ... }
|
||||
|
||||
// Create a pool with specified number of workers
|
||||
func NewPool(numWorkers int) (*Pool, error)
|
||||
|
||||
// Submit a job with default context
|
||||
func (p *Pool) Submit(bytecode []byte, ctx *Context) (any, error)
|
||||
|
||||
// Submit with timeout/cancellation support
|
||||
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error)
|
||||
|
||||
// Shutdown the pool
|
||||
func (p *Pool) Shutdown() error
|
||||
|
||||
// Get number of active workers
|
||||
func (p *Pool) ActiveWorkers() uint32
|
||||
```
|
||||
|
||||
### Context
|
||||
|
||||
```go
|
||||
type Context struct { ... }
|
||||
|
||||
// Create a new execution context
|
||||
func NewContext() *Context
|
||||
|
||||
// Set a value
|
||||
func (c *Context) Set(key string, value any)
|
||||
|
||||
// Get a value
|
||||
func (c *Context) Get(key string) any
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```go
|
||||
// Create worker pool
|
||||
pool, err := workers.NewPool(4)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Compile bytecode (typically done once and reused)
|
||||
state := luajit.New()
|
||||
bytecode, err := state.CompileBytecode(`
|
||||
return ctx.message .. " from Lua"
|
||||
`, "script")
|
||||
state.Close()
|
||||
|
||||
// Set up execution context
|
||||
ctx := workers.NewContext()
|
||||
ctx.Set("message", "Hello")
|
||||
ctx.Set("params", map[string]any{"id": "123"})
|
||||
|
||||
// Execute bytecode
|
||||
result, err := pool.Submit(bytecode, ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(result) // "Hello from Lua"
|
||||
```
|
||||
|
||||
## With Timeout
|
||||
|
||||
```go
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Execute with timeout
|
||||
result, err := pool.SubmitWithContext(ctx, bytecode, execCtx)
|
||||
if err != nil {
|
||||
// Handle timeout or error
|
||||
}
|
||||
```
|
||||
|
||||
## In Lua Scripts
|
||||
|
||||
Inside Lua, the context is available as the global `ctx` table:
|
||||
|
||||
```lua
|
||||
-- Access a simple value
|
||||
local msg = ctx.message
|
||||
|
||||
-- Access nested values
|
||||
local id = ctx.params.id
|
||||
|
||||
-- Return a result to Go
|
||||
return {
|
||||
status = "success",
|
||||
data = msg
|
||||
}
|
||||
```
|
||||
|
||||
## Important Notes
|
||||
|
||||
- The pool is thread-safe; multiple goroutines can submit jobs concurrently
|
||||
- Each execution is isolated; global state is reset between executions
|
||||
- Bytecode should be compiled once and reused for better performance
|
||||
- Context values should be serializable to Lua (numbers, strings, booleans, maps, slices)
|
|
@ -1,24 +0,0 @@
|
|||
package workers
|
||||
|
||||
// Context represents execution context for a Lua script
|
||||
type Context struct {
|
||||
// Generic map for any context values (route params, HTTP request info, etc.)
|
||||
Values map[string]any
|
||||
}
|
||||
|
||||
// NewContext creates a new context with initialized maps
|
||||
func NewContext() *Context {
|
||||
return &Context{
|
||||
Values: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds a value to the context
|
||||
func (c *Context) Set(key string, value any) {
|
||||
c.Values[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value from the context
|
||||
func (c *Context) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
package workers
|
||||
|
||||
// JobResult represents the result of a Lua script execution
|
||||
type JobResult struct {
|
||||
Value any // Return value from Lua
|
||||
Error error // Error if any
|
||||
}
|
||||
|
||||
// job represents a Lua script execution request
|
||||
type job struct {
|
||||
Bytecode []byte // Compiled LuaJIT bytecode
|
||||
Context *Context // Execution context
|
||||
Result chan<- JobResult // Channel to send result back
|
||||
}
|
|
@ -1,103 +0,0 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Pool manages a pool of Lua worker goroutines
|
||||
type Pool struct {
|
||||
workers uint32 // Number of workers
|
||||
jobs chan job // Channel to send jobs to workers
|
||||
wg sync.WaitGroup // WaitGroup to track active workers
|
||||
quit chan struct{} // Channel to signal shutdown
|
||||
isRunning atomic.Bool // Flag to track if pool is running
|
||||
}
|
||||
|
||||
// NewPool creates a new worker pool with the specified number of workers
|
||||
func NewPool(numWorkers int) (*Pool, error) {
|
||||
if numWorkers <= 0 {
|
||||
return nil, ErrNoWorkers
|
||||
}
|
||||
|
||||
p := &Pool{
|
||||
workers: uint32(numWorkers),
|
||||
jobs: make(chan job, numWorkers), // Buffer equal to worker count
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
p.isRunning.Store(true)
|
||||
|
||||
// Start workers
|
||||
p.wg.Add(numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
w := &worker{
|
||||
pool: p,
|
||||
id: uint32(i),
|
||||
}
|
||||
go w.run()
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// SubmitWithContext sends a job to the worker pool with context
|
||||
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
||||
if !p.isRunning.Load() {
|
||||
return nil, ErrPoolClosed
|
||||
}
|
||||
|
||||
resultChan := make(chan JobResult, 1)
|
||||
j := job{
|
||||
Bytecode: bytecode,
|
||||
Context: execCtx,
|
||||
Result: resultChan,
|
||||
}
|
||||
|
||||
// Submit job with context
|
||||
select {
|
||||
case p.jobs <- j:
|
||||
// Job submitted
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for result with context
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.Value, result.Error
|
||||
case <-ctx.Done():
|
||||
// Note: The job will still be processed by a worker,
|
||||
// but the result will be discarded
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Submit sends a job to the worker pool
|
||||
func (p *Pool) Submit(bytecode []byte, execCtx *Context) (any, error) {
|
||||
return p.SubmitWithContext(context.Background(), bytecode, execCtx)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the worker pool
|
||||
func (p *Pool) Shutdown() error {
|
||||
if !p.isRunning.Load() {
|
||||
return ErrPoolClosed
|
||||
}
|
||||
p.isRunning.Store(false)
|
||||
|
||||
// Signal workers to quit
|
||||
close(p.quit)
|
||||
|
||||
// Wait for workers to finish
|
||||
p.wg.Wait()
|
||||
|
||||
// Close jobs channel
|
||||
close(p.jobs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActiveWorkers returns the number of active workers
|
||||
func (p *Pool) ActiveWorkers() uint32 {
|
||||
return atomic.LoadUint32(&p.workers)
|
||||
}
|
|
@ -1,160 +0,0 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrPoolClosed = errors.New("worker pool is closed")
|
||||
ErrNoWorkers = errors.New("no workers available")
|
||||
)
|
||||
|
||||
// worker represents a single Lua execution worker
|
||||
type worker struct {
|
||||
pool *Pool // Reference to the pool
|
||||
state *luajit.State // Lua state
|
||||
id uint32 // Worker ID
|
||||
}
|
||||
|
||||
// run is the main worker function that processes jobs
|
||||
func (w *worker) run() {
|
||||
defer w.pool.wg.Done()
|
||||
|
||||
// Initialize Lua state
|
||||
w.state = luajit.New()
|
||||
if w.state == nil {
|
||||
// Worker failed to initialize, decrement counter
|
||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||
return
|
||||
}
|
||||
defer w.state.Close()
|
||||
|
||||
// Set up reset function for clearing state between requests
|
||||
if err := w.setupResetFunction(); err != nil {
|
||||
// Worker failed to initialize reset function, decrement counter
|
||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||
return
|
||||
}
|
||||
|
||||
// Main worker loop
|
||||
for {
|
||||
select {
|
||||
case job, ok := <-w.pool.jobs:
|
||||
if !ok {
|
||||
// Jobs channel closed, exit
|
||||
return
|
||||
}
|
||||
|
||||
// Execute job
|
||||
result := w.executeJob(job)
|
||||
job.Result <- result
|
||||
|
||||
case <-w.pool.quit:
|
||||
// Quit signal received, exit
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setupResetFunction initializes the reset function for clearing globals
|
||||
func (w *worker) setupResetFunction() error {
|
||||
resetScript := `
|
||||
-- Create reset function to efficiently clear globals after each request
|
||||
function __reset_globals()
|
||||
-- Only keep builtin globals, remove all user-defined globals
|
||||
local preserve = {
|
||||
["_G"] = true, ["_VERSION"] = true, ["__reset_globals"] = true,
|
||||
["assert"] = true, ["collectgarbage"] = true, ["coroutine"] = true,
|
||||
["debug"] = true, ["dofile"] = true, ["error"] = true,
|
||||
["getmetatable"] = true, ["io"] = true, ["ipairs"] = true,
|
||||
["load"] = true, ["loadfile"] = true, ["loadstring"] = true,
|
||||
["math"] = true, ["next"] = true, ["os"] = true,
|
||||
["package"] = true, ["pairs"] = true, ["pcall"] = true,
|
||||
["print"] = true, ["rawequal"] = true, ["rawget"] = true,
|
||||
["rawset"] = true, ["require"] = true, ["select"] = true,
|
||||
["setmetatable"] = true, ["string"] = true, ["table"] = true,
|
||||
["tonumber"] = true, ["tostring"] = true, ["type"] = true,
|
||||
["unpack"] = true, ["xpcall"] = true
|
||||
}
|
||||
|
||||
-- Clear all non-standard globals
|
||||
for name in pairs(_G) do
|
||||
if not preserve[name] then
|
||||
_G[name] = nil
|
||||
end
|
||||
end
|
||||
|
||||
-- Run garbage collection to release memory
|
||||
collectgarbage('collect')
|
||||
end
|
||||
`
|
||||
|
||||
return w.state.DoString(resetScript)
|
||||
}
|
||||
|
||||
// resetState prepares the Lua state for a new job
|
||||
func (w *worker) resetState() {
|
||||
w.state.DoString("__reset_globals()")
|
||||
}
|
||||
|
||||
// setContext sets job context as global tables in Lua state
|
||||
func (w *worker) setContext(ctx *Context) error {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create context table
|
||||
w.state.NewTable()
|
||||
|
||||
// Add values to context table
|
||||
for key, value := range ctx.Values {
|
||||
// Push key
|
||||
w.state.PushString(key)
|
||||
|
||||
// Push value
|
||||
if err := w.state.PushValue(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set table[key] = value
|
||||
w.state.SetTable(-3)
|
||||
}
|
||||
|
||||
// Set the table as global 'ctx'
|
||||
w.state.SetGlobal("ctx")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeJob executes a Lua job in the worker's state
|
||||
func (w *worker) executeJob(j job) JobResult {
|
||||
// Reset state before execution
|
||||
w.resetState()
|
||||
|
||||
// Set context
|
||||
if j.Context != nil {
|
||||
if err := w.setContext(j.Context); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
}
|
||||
|
||||
// Load bytecode
|
||||
if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Execute script with one result
|
||||
if err := w.state.RunBytecodeWithResults(1); err != nil {
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
||||
// Get result
|
||||
value, err := w.state.ToValue(-1)
|
||||
w.state.Pop(1) // Pop result
|
||||
|
||||
return JobResult{value, err}
|
||||
}
|
|
@ -1,445 +0,0 @@
|
|||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// This helper function creates real LuaJIT bytecode for our tests. Instead of using
|
||||
// mocks, we compile actual Lua code into bytecode just like we would in production.
|
||||
func createTestBytecode(t *testing.T, code string) []byte {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile test bytecode: %v", err)
|
||||
}
|
||||
|
||||
return bytecode
|
||||
}
|
||||
|
||||
// This test makes sure we can create a worker pool with a valid number of workers,
|
||||
// and that we properly reject attempts to create a pool with zero or negative workers.
|
||||
func TestNewPool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workers int
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid workers", 4, false},
|
||||
{"zero workers", 0, true},
|
||||
{"negative workers", -1, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pool, err := NewPool(tt.workers)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %d workers, got nil", tt.workers)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Errorf("Expected non-nil pool")
|
||||
} else {
|
||||
pool.Shutdown()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Here we're testing the basic job submission flow. We run a simple Lua script
|
||||
// that returns the number 42 and make sure we get that same value back from the worker pool.
|
||||
func TestPoolSubmit(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if num != 42 {
|
||||
t.Errorf("Expected 42, got %f", num)
|
||||
}
|
||||
}
|
||||
|
||||
// This test checks how our worker pool handles timeouts. We run a script that takes
|
||||
// some time to complete and verify two scenarios: one where the timeout is long enough
|
||||
// for successful completion, and another where we expect the operation to be canceled
|
||||
// due to a short timeout.
|
||||
func TestPoolSubmitWithContext(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create bytecode that sleeps
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Sleep for 500ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.5 do end
|
||||
return "done"
|
||||
`)
|
||||
|
||||
// Test with timeout that should succeed
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := pool.SubmitWithContext(ctx, bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
||||
}
|
||||
if result != "done" {
|
||||
t.Errorf("Expected 'done', got %v", result)
|
||||
}
|
||||
|
||||
// Test with timeout that should fail
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = pool.SubmitWithContext(ctx, bytecode, nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// We need to make sure we can pass different types of context values from Go to Lua and
|
||||
// get them back properly. This test sends numbers, strings, booleans, and arrays to
|
||||
// a Lua script and verifies they're all handled correctly in both directions.
|
||||
func TestContextValues(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
num = ctx.number,
|
||||
str = ctx.text,
|
||||
flag = ctx.enabled,
|
||||
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("number", 42.5)
|
||||
execCtx.Set("text", "hello")
|
||||
execCtx.Set("enabled", true)
|
||||
execCtx.Set("table", []float64{10, 20, 30})
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check values
|
||||
if resultMap["num"] != 42.5 {
|
||||
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
||||
}
|
||||
if resultMap["str"] != "hello" {
|
||||
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
||||
}
|
||||
if resultMap["flag"] != true {
|
||||
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
||||
}
|
||||
|
||||
arr, ok := resultMap["list"].([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
||||
}
|
||||
|
||||
expected := []float64{10, 20, 30}
|
||||
for i, v := range expected {
|
||||
if arr[i] != v {
|
||||
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test context with nested data structures
|
||||
func TestNestedContext(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
id = ctx.params.id,
|
||||
name = ctx.params.name,
|
||||
method = ctx.request.method,
|
||||
path = ctx.request.path
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
|
||||
// Set nested params
|
||||
params := map[string]any{
|
||||
"id": "123",
|
||||
"name": "test",
|
||||
}
|
||||
execCtx.Set("params", params)
|
||||
|
||||
// Set nested request info
|
||||
request := map[string]any{
|
||||
"method": "GET",
|
||||
"path": "/api/test",
|
||||
}
|
||||
execCtx.Set("request", request)
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["id"] != "123" {
|
||||
t.Errorf("Expected id=123, got %v", resultMap["id"])
|
||||
}
|
||||
if resultMap["name"] != "test" {
|
||||
t.Errorf("Expected name=test, got %v", resultMap["name"])
|
||||
}
|
||||
if resultMap["method"] != "GET" {
|
||||
t.Errorf("Expected method=GET, got %v", resultMap["method"])
|
||||
}
|
||||
if resultMap["path"] != "/api/test" {
|
||||
t.Errorf("Expected path=/api/test, got %v", resultMap["path"])
|
||||
}
|
||||
}
|
||||
|
||||
// A key requirement for our worker pool is that we don't leak state between executions.
|
||||
// This test confirms that by setting a global variable in one job and then checking
|
||||
// that it's been cleared before the next job runs on the same worker.
|
||||
func TestStateReset(t *testing.T) {
|
||||
pool, err := NewPool(1) // Use 1 worker to ensure same state is reused
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// First job sets a global
|
||||
bytecode1 := createTestBytecode(t, `
|
||||
global_var = "should be cleared"
|
||||
return true
|
||||
`)
|
||||
|
||||
// Second job checks if global exists
|
||||
bytecode2 := createTestBytecode(t, `
|
||||
return global_var ~= nil
|
||||
`)
|
||||
|
||||
// Run first job
|
||||
_, err = pool.Submit(bytecode1, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit first job: %v", err)
|
||||
}
|
||||
|
||||
// Run second job
|
||||
result, err := pool.Submit(bytecode2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit second job: %v", err)
|
||||
}
|
||||
|
||||
// Global should be cleared
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected global_var to be cleared, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
// Let's make sure our pool shuts down cleanly. This test confirms that jobs work
|
||||
// before shutdown, that we get the right error when trying to submit after shutdown,
|
||||
// and that we properly handle attempts to shut down an already closed pool.
|
||||
func TestPoolShutdown(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
|
||||
// Submit a job to verify pool works
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to submit job: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown
|
||||
if err := pool.Shutdown(); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
// Submit after shutdown should fail
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err != ErrPoolClosed {
|
||||
t.Errorf("Expected ErrPoolClosed, got %v", err)
|
||||
}
|
||||
|
||||
// Second shutdown should return error
|
||||
if err := pool.Shutdown(); err != ErrPoolClosed {
|
||||
t.Errorf("Expected ErrPoolClosed on second shutdown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// A robust worker pool needs to handle errors gracefully. This test checks various
|
||||
// error scenarios: invalid bytecode, Lua runtime errors, nil context (which
|
||||
// should work fine), and unsupported parameter types (which should properly error out).
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
pool, err := NewPool(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = pool.Submit([]byte("not valid bytecode"), nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
|
||||
// Test Lua runtime error
|
||||
bytecode := createTestBytecode(t, `
|
||||
error("intentional error")
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = pool.Submit(bytecode, nil)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Lua error() call, got nil")
|
||||
}
|
||||
|
||||
// Test with nil context
|
||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||
result, err := pool.Submit(bytecode, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error with nil context: %v", err)
|
||||
}
|
||||
if result.(bool) != true {
|
||||
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
||||
}
|
||||
|
||||
// Test invalid context value
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
||||
|
||||
bytecode = createTestBytecode(t, "return ctx.param")
|
||||
_, err = pool.Submit(bytecode, execCtx)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for unsupported context value type, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// The whole point of a worker pool is concurrent processing, so we need to verify
|
||||
// it works under load. This test submits multiple jobs simultaneously and makes sure
|
||||
// they all complete correctly with their own unique results.
|
||||
func TestConcurrentExecution(t *testing.T) {
|
||||
const workers = 4
|
||||
const jobs = 20
|
||||
|
||||
pool, err := NewPool(workers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create pool: %v", err)
|
||||
}
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Create bytecode that returns its input
|
||||
bytecode := createTestBytecode(t, "return ctx.n")
|
||||
|
||||
// Run multiple jobs concurrently
|
||||
results := make(chan int, jobs)
|
||||
for i := 0; i < jobs; i++ {
|
||||
i := i // Capture loop variable
|
||||
go func() {
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("n", float64(i))
|
||||
|
||||
result, err := pool.Submit(bytecode, execCtx)
|
||||
if err != nil {
|
||||
t.Errorf("Job %d failed: %v", i, err)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Errorf("Job %d: expected float64, got %T", i, result)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
results <- int(num)
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
counts := make(map[int]bool)
|
||||
for i := 0; i < jobs; i++ {
|
||||
result := <-results
|
||||
if result != -1 {
|
||||
counts[result] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all jobs were processed
|
||||
if len(counts) != jobs {
|
||||
t.Errorf("Expected %d unique results, got %d", jobs, len(counts))
|
||||
}
|
||||
}
|
||||
|
||||
// Test context operations
|
||||
func TestContext(t *testing.T) {
|
||||
ctx := NewContext()
|
||||
|
||||
// Test Set and Get
|
||||
ctx.Set("key", "value")
|
||||
if ctx.Get("key") != "value" {
|
||||
t.Errorf("Expected value, got %v", ctx.Get("key"))
|
||||
}
|
||||
|
||||
// Test overwriting
|
||||
ctx.Set("key", 123)
|
||||
if ctx.Get("key") != 123 {
|
||||
t.Errorf("Expected 123, got %v", ctx.Get("key"))
|
||||
}
|
||||
|
||||
// Test missing key
|
||||
if ctx.Get("missing") != nil {
|
||||
t.Errorf("Expected nil for missing key, got %v", ctx.Get("missing"))
|
||||
}
|
||||
}
|
7
go.mod
7
go.mod
|
@ -1,7 +0,0 @@
|
|||
module git.sharkk.net/Sky/Moonshark
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require git.sharkk.net/Sky/LuaJIT-to-Go v0.0.0
|
||||
|
||||
replace git.sharkk.net/Sky/LuaJIT-to-Go => ./luajit
|
1
luajit
1
luajit
|
@ -1 +0,0 @@
|
|||
Subproject commit 98ca857d73956bf69a07641710b678c11681319f
|
105
moonshark.go
105
moonshark.go
|
@ -1,105 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Sky/Moonshark/core/config"
|
||||
"git.sharkk.net/Sky/Moonshark/core/http"
|
||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||
"git.sharkk.net/Sky/Moonshark/core/routers"
|
||||
"git.sharkk.net/Sky/Moonshark/core/utils"
|
||||
"git.sharkk.net/Sky/Moonshark/core/workers"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize logger
|
||||
log := logger.New(logger.LevelDebug, true)
|
||||
defer log.Close()
|
||||
|
||||
log.Info("Starting Moonshark server")
|
||||
|
||||
// Load configuration from config.lua
|
||||
cfg, err := config.Load("config.lua")
|
||||
if err != nil {
|
||||
log.Warning("Failed to load config.lua: %v", err)
|
||||
log.Info("Using default configuration")
|
||||
cfg = config.New()
|
||||
}
|
||||
|
||||
// Get port from config or use default
|
||||
port := cfg.GetInt("port", 3117)
|
||||
|
||||
// Initialize routers
|
||||
routesDir := cfg.GetString("routes_dir", "./routes")
|
||||
staticDir := cfg.GetString("static_dir", "./static")
|
||||
|
||||
// Get worker pool size from config or use default
|
||||
workerPoolSize := cfg.GetInt("worker_pool_size", 4)
|
||||
|
||||
// Ensure directories exist
|
||||
if err = utils.EnsureDir(routesDir); err != nil {
|
||||
log.Fatal("Routes directory doesn't exist, and could not create it: %v", err)
|
||||
}
|
||||
if err = utils.EnsureDir(staticDir); err != nil {
|
||||
log.Fatal("Static directory doesn't exist, and could not create it: %v", err)
|
||||
}
|
||||
|
||||
// Initialize worker pool
|
||||
pool, err := workers.NewPool(workerPoolSize)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize worker pool: %v", err)
|
||||
}
|
||||
log.Info("Worker pool initialized with %d workers", workerPoolSize)
|
||||
defer pool.Shutdown()
|
||||
|
||||
// Initialize Lua router for dynamic routes
|
||||
luaRouter, err := routers.NewLuaRouter(routesDir)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize Lua router: %v", err)
|
||||
}
|
||||
log.Info("Lua router initialized with routes from %s", routesDir)
|
||||
|
||||
// Initialize static file router
|
||||
staticRouter, err := routers.NewStaticRouter(staticDir)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize static router: %v", err)
|
||||
}
|
||||
log.Info("Static router initialized with files from %s", staticDir)
|
||||
|
||||
// Create HTTP server
|
||||
server := http.New(luaRouter, staticRouter, pool, log)
|
||||
|
||||
// Handle graceful shutdown
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
log.Info("Server listening on http://localhost%s", addr)
|
||||
if err := server.ListenAndServe(addr); err != nil {
|
||||
if err.Error() != "http: Server closed" {
|
||||
log.Error("Server error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
<-stop
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shut down the server
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Error("Server shutdown error: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Server stopped")
|
||||
}
|
Loading…
Reference in New Issue
Block a user