Compare commits
No commits in common. "0259d1a135aac788980925b07056a57fd08f023a" and "7bc5194b10c90f9b0889cbf50426398bac0e29fc" have entirely different histories.
0259d1a135
...
7bc5194b10
351
core/config/config_test.go
Normal file
351
core/config/config_test.go
Normal file
|
@ -0,0 +1,351 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLoad verifies we can successfully load configuration values from a Lua file.
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
content := `
|
||||||
|
-- Basic configuration values
|
||||||
|
host = "localhost"
|
||||||
|
port = 8080
|
||||||
|
debug = true
|
||||||
|
pi = 3.14159
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify values were loaded correctly
|
||||||
|
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||||
|
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||||
|
t.Errorf("Expected port to be 8080, got %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug := cfg.GetBool("debug", false); !debug {
|
||||||
|
t.Errorf("Expected debug to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pi := cfg.GetFloat("pi", 0); pi != 3.14159 {
|
||||||
|
t.Errorf("Expected pi to be 3.14159, got %f", pi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadErrors ensures the package properly handles loading errors.
|
||||||
|
func TestLoadErrors(t *testing.T) {
|
||||||
|
// Test with non-existent file
|
||||||
|
_, err := Load("nonexistent.lua")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when loading non-existent file, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with invalid Lua
|
||||||
|
content := `
|
||||||
|
-- This is invalid Lua
|
||||||
|
host = "localhost
|
||||||
|
port = 8080)
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
_, err = Load(configFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when loading invalid Lua, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalVsGlobal verifies only global variables are exported, not locals.
|
||||||
|
func TestLocalVsGlobal(t *testing.T) {
|
||||||
|
// Create a temporary config file with both local and global variables
|
||||||
|
content := `
|
||||||
|
-- Local variables should not be exported
|
||||||
|
local local_var = "hidden"
|
||||||
|
|
||||||
|
-- Global variables should be exported
|
||||||
|
global_var = "visible"
|
||||||
|
|
||||||
|
-- A function that uses both
|
||||||
|
function test_func()
|
||||||
|
return local_var .. " " .. global_var
|
||||||
|
end
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that global_var exists
|
||||||
|
if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" {
|
||||||
|
t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that local_var does not exist
|
||||||
|
if localVar := cfg.GetString("local_var", "default"); localVar != "default" {
|
||||||
|
t.Errorf("Expected local_var to use default, got '%s'", localVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that functions are not exported
|
||||||
|
if val := cfg.Get("test_func"); val != nil {
|
||||||
|
t.Errorf("Expected function to not be exported, got %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestArrayHandling verifies correct handling of Lua arrays.
|
||||||
|
func TestArrayHandling(t *testing.T) {
|
||||||
|
// Create a temporary config file with arrays
|
||||||
|
content := `
|
||||||
|
-- Numeric array
|
||||||
|
numbers = {10, 20, 30, 40, 50}
|
||||||
|
|
||||||
|
-- String array
|
||||||
|
strings = {"apple", "banana", "cherry"}
|
||||||
|
|
||||||
|
-- Mixed array
|
||||||
|
mixed = {1, "two", true, 4.5}
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetIntArray
|
||||||
|
intArray := cfg.GetIntArray("numbers")
|
||||||
|
expectedInts := []int{10, 20, 30, 40, 50}
|
||||||
|
if !reflect.DeepEqual(intArray, expectedInts) {
|
||||||
|
t.Errorf("Expected int array %v, got %v", expectedInts, intArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetStringArray
|
||||||
|
strArray := cfg.GetStringArray("strings")
|
||||||
|
expectedStrs := []string{"apple", "banana", "cherry"}
|
||||||
|
if !reflect.DeepEqual(strArray, expectedStrs) {
|
||||||
|
t.Errorf("Expected string array %v, got %v", expectedStrs, strArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetArray with mixed types
|
||||||
|
mixedArray := cfg.GetArray("mixed")
|
||||||
|
if len(mixedArray) != 4 {
|
||||||
|
t.Errorf("Expected mixed array length 4, got %d", len(mixedArray))
|
||||||
|
// Skip further tests if array is empty to avoid panic
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check types - carefully to avoid panics
|
||||||
|
if len(mixedArray) > 0 {
|
||||||
|
if num, ok := mixedArray[0].(float64); !ok || num != 1 {
|
||||||
|
t.Errorf("Expected first element to be 1, got %v", mixedArray[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(mixedArray) > 1 {
|
||||||
|
if str, ok := mixedArray[1].(string); !ok || str != "two" {
|
||||||
|
t.Errorf("Expected second element to be 'two', got %v", mixedArray[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComplexTable tests handling of complex nested tables.
|
||||||
|
func TestComplexTable(t *testing.T) {
|
||||||
|
// Create a temporary config file with complex tables
|
||||||
|
content := `
|
||||||
|
-- Nested table structure
|
||||||
|
server = {
|
||||||
|
host = "localhost",
|
||||||
|
port = 8080,
|
||||||
|
settings = {
|
||||||
|
timeout = 30,
|
||||||
|
retries = 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
-- Table with mixed array and map elements
|
||||||
|
mixed_table = {
|
||||||
|
list = {1, 2, 3},
|
||||||
|
mapping = {
|
||||||
|
a = "apple",
|
||||||
|
b = "banana"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test getting nested values
|
||||||
|
serverMap := cfg.GetMap("server")
|
||||||
|
if serverMap == nil {
|
||||||
|
t.Fatal("Expected server map to exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check first level values
|
||||||
|
if host, ok := serverMap["host"].(string); !ok || host != "localhost" {
|
||||||
|
t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if port, ok := serverMap["port"].(float64); !ok || port != 8080 {
|
||||||
|
t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check nested settings
|
||||||
|
settings, ok := serverMap["settings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected server.settings to be a map")
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 {
|
||||||
|
t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultValues verifies default values work correctly when keys don't exist.
|
||||||
|
func TestDefaultValues(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
content := `
|
||||||
|
-- Just one value
|
||||||
|
existing = "value"
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test defaults for non-existent keys
|
||||||
|
if val := cfg.GetString("nonexistent", "default"); val != "default" {
|
||||||
|
t.Errorf("Expected default string, got '%s'", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := cfg.GetInt("nonexistent", 42); val != 42 {
|
||||||
|
t.Errorf("Expected default int 42, got %d", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 {
|
||||||
|
t.Errorf("Expected default float 3.14, got %f", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := cfg.GetBool("nonexistent", true); !val {
|
||||||
|
t.Errorf("Expected default bool true, got false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModifyConfig tests the ability to modify configuration values.
|
||||||
|
func TestModifyConfig(t *testing.T) {
|
||||||
|
// Create a config manually
|
||||||
|
cfg := New()
|
||||||
|
|
||||||
|
// Set some values
|
||||||
|
cfg.Set("host", "localhost")
|
||||||
|
cfg.Set("port", 8080)
|
||||||
|
|
||||||
|
// Verify the values were set
|
||||||
|
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||||
|
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||||
|
t.Errorf("Expected port to be 8080, got %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify a value
|
||||||
|
cfg.Set("host", "127.0.0.1")
|
||||||
|
|
||||||
|
// Verify the change
|
||||||
|
if host := cfg.GetString("host", ""); host != "127.0.0.1" {
|
||||||
|
t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTypeConversion tests the type conversion in getter methods.
|
||||||
|
func TestTypeConversion(t *testing.T) {
|
||||||
|
// Create a temporary config file with values that need conversion
|
||||||
|
content := `
|
||||||
|
-- Numbers that can be integers
|
||||||
|
int_as_float = 42.0
|
||||||
|
|
||||||
|
-- Floats that should remain floats
|
||||||
|
float_val = 3.14159
|
||||||
|
`
|
||||||
|
configFile := createTempLuaFile(t, content)
|
||||||
|
defer os.Remove(configFile)
|
||||||
|
|
||||||
|
// Load the config
|
||||||
|
cfg, err := Load(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetInt with a float value
|
||||||
|
if val := cfg.GetInt("int_as_float", 0); val != 42 {
|
||||||
|
t.Errorf("Expected int 42, got %d", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetFloat with an int value
|
||||||
|
if val := cfg.GetFloat("int_as_float", 0); val != 42.0 {
|
||||||
|
t.Errorf("Expected float 42.0, got %f", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test incorrect type handling
|
||||||
|
cfg.Set("string_val", "not a number")
|
||||||
|
|
||||||
|
if val := cfg.GetInt("string_val", 99); val != 99 {
|
||||||
|
t.Errorf("Expected default int 99 for string value, got %d", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := cfg.GetFloat("string_val", 99.9); val != 99.9 {
|
||||||
|
t.Errorf("Expected default float 99.9 for string value, got %f", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := cfg.GetBool("float_val", false); val != false {
|
||||||
|
t.Errorf("Expected default false for non-bool value, got true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a temporary Lua file with content
|
||||||
|
func createTempLuaFile(t *testing.T, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tempFile, err := os.CreateTemp("", "config-test-*.lua")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tempFile.WriteString(content); err != nil {
|
||||||
|
os.Remove(tempFile.Name())
|
||||||
|
t.Fatalf("Failed to write to temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tempFile.Close(); err != nil {
|
||||||
|
os.Remove(tempFile.Name())
|
||||||
|
t.Fatalf("Failed to close temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tempFile.Name()
|
||||||
|
}
|
|
@ -170,32 +170,15 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for HTTPResponse type
|
|
||||||
if httpResp, ok := result.(*runner.HTTPResponse); ok {
|
|
||||||
// Set response headers
|
|
||||||
for name, value := range httpResp.Headers {
|
|
||||||
w.Header().Set(name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set status code
|
|
||||||
w.WriteHeader(httpResp.Status)
|
|
||||||
|
|
||||||
// Process the body based on its type
|
|
||||||
if httpResp.Body == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result = httpResp.Body // Set result to body for processing below
|
|
||||||
}
|
|
||||||
|
|
||||||
switch res := result.(type) {
|
switch res := result.(type) {
|
||||||
case string:
|
case string:
|
||||||
// String result - plain text
|
// String result
|
||||||
setContentTypeIfMissing(w, contentTypePlain)
|
w.Header().Set("Content-Type", contentTypePlain)
|
||||||
w.Write([]byte(res))
|
w.Write([]byte(res))
|
||||||
default:
|
|
||||||
// All other types - convert to JSON
|
case map[string]any, []any:
|
||||||
setContentTypeIfMissing(w, contentTypeJSON)
|
// Table or array result - convert to JSON
|
||||||
|
w.Header().Set("Content-Type", contentTypeJSON)
|
||||||
data, err := json.Marshal(res)
|
data, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Failed to marshal response: %v", err)
|
log.Error("Failed to marshal response: %v", err)
|
||||||
|
@ -203,11 +186,16 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Other result types - convert to JSON
|
||||||
|
w.Header().Set("Content-Type", contentTypeJSON)
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal response: %v", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
|
|
||||||
if w.Header().Get("Content-Type") == "" {
|
|
||||||
w.Header().Set("Content-Type", contentType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
214
core/logger/logger_test.go
Normal file
214
core/logger/logger_test.go
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoggerLevels(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelInfo, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Debug should be below threshold
|
||||||
|
logger.Debug("This should not appear")
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
t.Error("Debug message appeared when it should be filtered")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info and above should appear
|
||||||
|
logger.Info("Info message")
|
||||||
|
if !strings.Contains(buf.String(), "INFO") {
|
||||||
|
t.Errorf("Info message not logged, got: %q", buf.String())
|
||||||
|
}
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
logger.Warning("Warning message")
|
||||||
|
if !strings.Contains(buf.String(), "WARN") {
|
||||||
|
t.Errorf("Warning message not logged, got: %q", buf.String())
|
||||||
|
}
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
logger.Error("Error message")
|
||||||
|
if !strings.Contains(buf.String(), "ERROR") {
|
||||||
|
t.Errorf("Error message not logged, got: %q", buf.String())
|
||||||
|
}
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
// Test format strings
|
||||||
|
logger.Info("Count: %d", 42)
|
||||||
|
if !strings.Contains(buf.String(), "Count: 42") {
|
||||||
|
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
||||||
|
}
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
// Test changing level
|
||||||
|
logger.SetLevel(LevelError)
|
||||||
|
logger.Info("This should not appear")
|
||||||
|
logger.Warning("This should not appear")
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
t.Error("Messages below threshold appeared")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("Error should appear")
|
||||||
|
if !strings.Contains(buf.String(), "ERROR") {
|
||||||
|
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerRateLimit(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelDebug, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Override max logs per second to something small for testing
|
||||||
|
logger.maxLogsPerSec = 5
|
||||||
|
logger.limitDuration = 1 * time.Second
|
||||||
|
|
||||||
|
// Send debug messages (should get limited)
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
logger.Debug("Debug message %d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error messages should always go through
|
||||||
|
logger.Error("Error message should appear")
|
||||||
|
|
||||||
|
content := buf.String()
|
||||||
|
|
||||||
|
// We should see some debug messages, then a warning about rate limiting,
|
||||||
|
// and finally the error message
|
||||||
|
if !strings.Contains(content, "Debug message 0") {
|
||||||
|
t.Error("First debug message should appear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(content, "Rate limiting logger") {
|
||||||
|
t.Error("Rate limiting message should appear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(content, "ERROR") {
|
||||||
|
t.Error("Error message should always appear despite rate limiting")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerConcurrency(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelDebug, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Increase log threshold for this test
|
||||||
|
logger.maxLogsPerSec = 1000
|
||||||
|
|
||||||
|
// Log a bunch of messages concurrently
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
logger.Info("Concurrent message %d", n)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Check logs were processed
|
||||||
|
content := buf.String()
|
||||||
|
if !strings.Contains(content, "Concurrent message") {
|
||||||
|
t.Error("Concurrent messages should appear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerColors(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelInfo, true)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Test with color
|
||||||
|
logger.Info("Colored message")
|
||||||
|
|
||||||
|
content := buf.String()
|
||||||
|
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
||||||
|
if !strings.Contains(content, "\033[") {
|
||||||
|
t.Errorf("Color codes not present when enabled, got: %q", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
logger.DisableColors()
|
||||||
|
logger.Info("Non-colored message")
|
||||||
|
|
||||||
|
content = buf.String()
|
||||||
|
if strings.Contains(content, "\033[") {
|
||||||
|
t.Errorf("Color codes present when disabled, got: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultLogger(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
SetOutput(&buf)
|
||||||
|
|
||||||
|
Info("Test default logger")
|
||||||
|
|
||||||
|
content := buf.String()
|
||||||
|
if !strings.Contains(content, "INFO") {
|
||||||
|
t.Errorf("Default logger not working, got: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelInfo, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Set very high threshold to avoid rate limiting during benchmark
|
||||||
|
logger.maxLogsPerSec = int64(b.N + 1)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Info("Benchmark message %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoggerWithRateLimit(b *testing.B) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelDebug, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Set threshold to allow about 10% of messages through
|
||||||
|
logger.maxLogsPerSec = int64(b.N / 10)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Debug("Benchmark message %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelDebug, false)
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
// Set very high threshold to avoid rate limiting during benchmark
|
||||||
|
logger.maxLogsPerSec = int64(b.N + 1)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Debug("Parallel benchmark message %d", i)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkProductionLevels(b *testing.B) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := New(LevelWarning, false) // Only log warnings and above
|
||||||
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// This should be filtered out before any processing
|
||||||
|
logger.Debug("Debug message that won't be logged %d", i)
|
||||||
|
}
|
||||||
|
}
|
|
@ -103,25 +103,18 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize HTTP module BEFORE sandbox setup
|
|
||||||
httpInit := HTTPModuleInitFunc()
|
|
||||||
if err := httpInit(state); err != nil {
|
|
||||||
state.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up sandbox AFTER HTTP module is initialized
|
|
||||||
if err := runner.sandbox.Setup(state); err != nil {
|
|
||||||
state.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preload all modules into package.loaded
|
// Preload all modules into package.loaded
|
||||||
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
|
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
|
||||||
state.Close()
|
state.Close()
|
||||||
return nil, errors.New("failed to preload modules")
|
return nil, errors.New("failed to preload modules")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up sandbox
|
||||||
|
if err := runner.sandbox.Setup(state); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
// Run init function if provided
|
// Run init function if provided
|
||||||
if runner.initFunc != nil {
|
if runner.initFunc != nil {
|
||||||
if err := runner.initFunc(state); err != nil {
|
if err := runner.initFunc(state); err != nil {
|
||||||
|
|
373
core/runner/luarunner_test.go
Normal file
373
core/runner/luarunner_test.go
Normal file
|
@ -0,0 +1,373 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to create bytecode for testing
|
||||||
|
func createTestBytecode(t *testing.T, code string) []byte {
|
||||||
|
state := luajit.New()
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer state.Close()
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(code, "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to compile test bytecode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytecode
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerBasic(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
bytecode := createTestBytecode(t, "return 42")
|
||||||
|
|
||||||
|
result, err := runner.Run(bytecode, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num, ok := result.(float64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected float64 result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if num != 42 {
|
||||||
|
t.Errorf("Expected 42, got %f", num)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerWithContext(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
bytecode := createTestBytecode(t, `
|
||||||
|
return {
|
||||||
|
num = ctx.number,
|
||||||
|
str = ctx.text,
|
||||||
|
flag = ctx.enabled,
|
||||||
|
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
|
||||||
|
execCtx := NewContext()
|
||||||
|
execCtx.Set("number", 42.5)
|
||||||
|
execCtx.Set("text", "hello")
|
||||||
|
execCtx.Set("enabled", true)
|
||||||
|
execCtx.Set("table", []float64{10, 20, 30})
|
||||||
|
|
||||||
|
result, err := runner.Run(bytecode, execCtx, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run job: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result should be a map
|
||||||
|
resultMap, ok := result.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected map result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check values
|
||||||
|
if resultMap["num"] != 42.5 {
|
||||||
|
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
||||||
|
}
|
||||||
|
if resultMap["str"] != "hello" {
|
||||||
|
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
||||||
|
}
|
||||||
|
if resultMap["flag"] != true {
|
||||||
|
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
||||||
|
}
|
||||||
|
|
||||||
|
arr, ok := resultMap["list"].([]float64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []float64{10, 20, 30}
|
||||||
|
for i, v := range expected {
|
||||||
|
if arr[i] != v {
|
||||||
|
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerWithTimeout(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
// Create bytecode that sleeps
|
||||||
|
bytecode := createTestBytecode(t, `
|
||||||
|
-- Sleep for 500ms
|
||||||
|
local start = os.time()
|
||||||
|
while os.difftime(os.time(), start) < 0.5 do end
|
||||||
|
return "done"
|
||||||
|
`)
|
||||||
|
|
||||||
|
// Test with timeout that should succeed
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, err := runner.RunWithContext(ctx, bytecode, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
||||||
|
}
|
||||||
|
if result != "done" {
|
||||||
|
t.Errorf("Expected 'done', got %v", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with timeout that should fail
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = runner.RunWithContext(ctx, bytecode, nil, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected timeout error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxIsolation(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
// Create a script that tries to modify a global variable
|
||||||
|
bytecode1 := createTestBytecode(t, `
|
||||||
|
-- Set a "global" variable
|
||||||
|
my_global = "test value"
|
||||||
|
return true
|
||||||
|
`)
|
||||||
|
|
||||||
|
_, err = runner.Run(bytecode1, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute first script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now try to access that variable from another script
|
||||||
|
bytecode2 := createTestBytecode(t, `
|
||||||
|
-- Try to access the previously set global
|
||||||
|
return my_global ~= nil
|
||||||
|
`)
|
||||||
|
|
||||||
|
result, err := runner.Run(bytecode2, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute second script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The variable should not be accessible (sandbox isolation)
|
||||||
|
if result.(bool) {
|
||||||
|
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerWithInit(t *testing.T) {
|
||||||
|
// Define an init function that registers a simple "math" module
|
||||||
|
mathInit := func(state *luajit.State) error {
|
||||||
|
// Register the "add" function
|
||||||
|
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a + b)
|
||||||
|
return 1 // Return one result
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a whole module
|
||||||
|
mathFuncs := map[string]luajit.GoFunction{
|
||||||
|
"multiply": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a * b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
"subtract": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a - b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return RegisterModule(state, "math2", mathFuncs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a runner with our init function
|
||||||
|
runner, err := NewRunner(WithInitFunc(mathInit))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
// Test the add function
|
||||||
|
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||||
|
result1, err := runner.Run(bytecode1, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call add function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num1, ok := result1.(float64)
|
||||||
|
if !ok || num1 != 12 {
|
||||||
|
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the math2 module
|
||||||
|
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||||
|
result2, err := runner.Run(bytecode2, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num2, ok := result2.(float64)
|
||||||
|
if !ok || num2 != 48 {
|
||||||
|
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentExecution(t *testing.T) {
|
||||||
|
const jobs = 20
|
||||||
|
|
||||||
|
runner, err := NewRunner(WithBufferSize(20))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
// Create bytecode that returns its input
|
||||||
|
bytecode := createTestBytecode(t, "return ctx.n")
|
||||||
|
|
||||||
|
// Run multiple jobs concurrently
|
||||||
|
results := make(chan int, jobs)
|
||||||
|
for i := 0; i < jobs; i++ {
|
||||||
|
i := i // Capture loop variable
|
||||||
|
go func() {
|
||||||
|
execCtx := NewContext()
|
||||||
|
execCtx.Set("n", float64(i))
|
||||||
|
|
||||||
|
result, err := runner.Run(bytecode, execCtx, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Job %d failed: %v", i, err)
|
||||||
|
results <- -1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
num, ok := result.(float64)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Job %d: expected float64, got %T", i, result)
|
||||||
|
results <- -1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results <- int(num)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect results
|
||||||
|
seen := make(map[int]bool)
|
||||||
|
for i := 0; i < jobs; i++ {
|
||||||
|
result := <-results
|
||||||
|
if result != -1 {
|
||||||
|
seen[result] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all jobs were processed
|
||||||
|
if len(seen) != jobs {
|
||||||
|
t.Errorf("Expected %d unique results, got %d", jobs, len(seen))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunnerClose(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit a job to verify runner works
|
||||||
|
bytecode := createTestBytecode(t, "return 42")
|
||||||
|
_, err = runner.Run(bytecode, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run job: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close
|
||||||
|
if err := runner.Close(); err != nil {
|
||||||
|
t.Errorf("Close failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run after close should fail
|
||||||
|
_, err = runner.Run(bytecode, nil, "")
|
||||||
|
if err != ErrRunnerClosed {
|
||||||
|
t.Errorf("Expected ErrRunnerClosed, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second close should return error
|
||||||
|
if err := runner.Close(); err != ErrRunnerClosed {
|
||||||
|
t.Errorf("Expected ErrRunnerClosed on second close, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorHandling(t *testing.T) {
|
||||||
|
runner, err := NewRunner()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create runner: %v", err)
|
||||||
|
}
|
||||||
|
defer runner.Close()
|
||||||
|
|
||||||
|
// Test invalid bytecode
|
||||||
|
_, err = runner.Run([]byte("not valid bytecode"), nil, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Lua runtime error
|
||||||
|
bytecode := createTestBytecode(t, `
|
||||||
|
error("intentional error")
|
||||||
|
return true
|
||||||
|
`)
|
||||||
|
|
||||||
|
_, err = runner.Run(bytecode, nil, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from Lua error() call, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil context
|
||||||
|
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||||
|
result, err := runner.Run(bytecode, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error with nil context: %v", err)
|
||||||
|
}
|
||||||
|
if result.(bool) != true {
|
||||||
|
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid context value
|
||||||
|
execCtx := NewContext()
|
||||||
|
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
||||||
|
|
||||||
|
bytecode = createTestBytecode(t, "return ctx.param")
|
||||||
|
_, err = runner.Run(bytecode, execCtx, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error for unsupported context value type, got nil")
|
||||||
|
}
|
||||||
|
}
|
281
core/runner/require_test.go
Normal file
281
core/runner/require_test.go
Normal file
|
@ -0,0 +1,281 @@
|
||||||
|
package runner_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
"git.sharkk.net/Sky/Moonshark/core/runner"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequireFunctionality(t *testing.T) {
|
||||||
|
// Create temporary directories for test
|
||||||
|
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create script directory and lib directory
|
||||||
|
scriptDir := filepath.Join(tempDir, "scripts")
|
||||||
|
libDir := filepath.Join(tempDir, "libs")
|
||||||
|
|
||||||
|
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create script directory: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Mkdir(libDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create lib directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a module in the lib directory
|
||||||
|
libModule := `
|
||||||
|
local lib = {}
|
||||||
|
|
||||||
|
function lib.add(a, b)
|
||||||
|
return a + b
|
||||||
|
end
|
||||||
|
|
||||||
|
function lib.mul(a, b)
|
||||||
|
return a * b
|
||||||
|
end
|
||||||
|
|
||||||
|
return lib
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write lib module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a helper module in the script directory
|
||||||
|
helperModule := `
|
||||||
|
local helper = {}
|
||||||
|
|
||||||
|
function helper.square(x)
|
||||||
|
return x * x
|
||||||
|
end
|
||||||
|
|
||||||
|
function helper.cube(x)
|
||||||
|
return x * x * x
|
||||||
|
end
|
||||||
|
|
||||||
|
return helper
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write helper module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create main script that requires both modules
|
||||||
|
mainScript := `
|
||||||
|
-- Require from the same directory
|
||||||
|
local helper = require("helper")
|
||||||
|
|
||||||
|
-- Require from the lib directory
|
||||||
|
local mathlib = require("mathlib")
|
||||||
|
|
||||||
|
-- Use both modules
|
||||||
|
local result = {
|
||||||
|
add = mathlib.add(10, 5),
|
||||||
|
mul = mathlib.mul(10, 5),
|
||||||
|
square = helper.square(5),
|
||||||
|
cube = helper.cube(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
`
|
||||||
|
mainScriptPath := filepath.Join(scriptDir, "main.lua")
|
||||||
|
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write main script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create LuaRunner
|
||||||
|
luaRunner, err := runner.NewRunner(
|
||||||
|
runner.WithScriptDir(scriptDir),
|
||||||
|
runner.WithLibDirs(libDir),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create LuaRunner: %v", err)
|
||||||
|
}
|
||||||
|
defer luaRunner.Close()
|
||||||
|
|
||||||
|
// Compile the main script
|
||||||
|
state := luajit.New()
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("Failed to create Lua state")
|
||||||
|
}
|
||||||
|
defer state.Close()
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to compile script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the script
|
||||||
|
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check result
|
||||||
|
resultMap, ok := result.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected map result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
expectedResults := map[string]float64{
|
||||||
|
"add": 15, // 10 + 5
|
||||||
|
"mul": 50, // 10 * 5
|
||||||
|
"square": 25, // 5^2
|
||||||
|
"cube": 27, // 3^3
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, expected := range expectedResults {
|
||||||
|
if val, ok := resultMap[key]; !ok {
|
||||||
|
t.Errorf("Missing result key: %s", key)
|
||||||
|
} else if val != expected {
|
||||||
|
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireSecurityBoundaries(t *testing.T) {
|
||||||
|
// Create temporary directories for test
|
||||||
|
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create script directory and lib directory
|
||||||
|
scriptDir := filepath.Join(tempDir, "scripts")
|
||||||
|
libDir := filepath.Join(tempDir, "libs")
|
||||||
|
secretDir := filepath.Join(tempDir, "secret")
|
||||||
|
|
||||||
|
err = os.MkdirAll(scriptDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create script directory: %v", err)
|
||||||
|
}
|
||||||
|
err = os.MkdirAll(libDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create lib directory: %v", err)
|
||||||
|
}
|
||||||
|
err = os.MkdirAll(secretDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create secret directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a "secret" module that should not be accessible
|
||||||
|
secretModule := `
|
||||||
|
local secret = "TOP SECRET"
|
||||||
|
return secret
|
||||||
|
`
|
||||||
|
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write secret module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a normal module in lib
|
||||||
|
normalModule := `return "normal module"`
|
||||||
|
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write normal module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a compile-and-run function that takes care of both compilation and execution
|
||||||
|
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
|
||||||
|
// Compile
|
||||||
|
state := luajit.New()
|
||||||
|
if state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
defer state.Close()
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(scriptText, scriptName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and configure a new runner each time
|
||||||
|
r, err := runner.NewRunner(
|
||||||
|
runner.WithScriptDir(scriptDir),
|
||||||
|
runner.WithLibDirs(libDir),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
// Run
|
||||||
|
return r.Run(bytecode, nil, scriptPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that normal require works
|
||||||
|
normalScript := `
|
||||||
|
local normal = require("normal")
|
||||||
|
return normal
|
||||||
|
`
|
||||||
|
normalPath := filepath.Join(scriptDir, "normal_test.lua")
|
||||||
|
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write normal script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to run normal script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != "normal module" {
|
||||||
|
t.Errorf("Expected 'normal module', got %v", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test path traversal attempts
|
||||||
|
pathTraversalTests := []struct {
|
||||||
|
name string
|
||||||
|
script string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Direct path traversal",
|
||||||
|
script: `
|
||||||
|
-- Try path traversal
|
||||||
|
local secret = require("../secret/secret")
|
||||||
|
return secret ~= nil
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Double dot traversal",
|
||||||
|
script: `
|
||||||
|
local secret = require("..secret.secret")
|
||||||
|
return secret ~= nil
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Absolute path traversal",
|
||||||
|
script: `
|
||||||
|
local secret = require("` + filepath.Join(secretDir, "secret") + `")
|
||||||
|
return secret ~= nil
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range pathTraversalTests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
||||||
|
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write test script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
|
||||||
|
// If there's an error, that's expected and good
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no error, then the script should have returned false (couldn't get the module)
|
||||||
|
if result == true {
|
||||||
|
t.Errorf("Security breach! Script was able to access restricted module")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,12 +12,10 @@ type Sandbox struct {
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox
|
// NewSandbox creates a new sandbox
|
||||||
func NewSandbox() *Sandbox {
|
func NewSandbox() *Sandbox {
|
||||||
s := &Sandbox{
|
return &Sandbox{
|
||||||
modules: make(map[string]any),
|
modules: make(map[string]any),
|
||||||
initialized: false,
|
initialized: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
// AddModule adds a module to the sandbox environment
|
||||||
|
@ -34,69 +32,63 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
|
|
||||||
// Create high-performance persistent environment
|
// Create high-performance persistent environment
|
||||||
return state.DoString(`
|
return state.DoString(`
|
||||||
-- Global shared environment (created once)
|
-- Global shared environment (created once)
|
||||||
__env_system = __env_system or {
|
__env_system = __env_system or {
|
||||||
base_env = nil, -- Template environment
|
base_env = nil, -- Template environment
|
||||||
initialized = false, -- Initialization flag
|
initialized = false, -- Initialization flag
|
||||||
env_pool = {}, -- Pre-allocated environment pool
|
env_pool = {}, -- Pre-allocated environment pool
|
||||||
pool_size = 0, -- Current pool size
|
pool_size = 0, -- Current pool size
|
||||||
max_pool_size = 8 -- Maximum pool size
|
max_pool_size = 8 -- Maximum pool size
|
||||||
}
|
}
|
||||||
|
|
||||||
-- Initialize base environment once
|
-- Initialize base environment once
|
||||||
if not __env_system.initialized then
|
if not __env_system.initialized then
|
||||||
-- Create base environment with all standard libraries
|
-- Create base environment with all standard libraries
|
||||||
local base = {}
|
local base = {}
|
||||||
|
|
||||||
-- Safe standard libraries
|
-- Safe standard libraries
|
||||||
base.string = string
|
base.string = string
|
||||||
base.table = table
|
base.table = table
|
||||||
base.math = math
|
base.math = math
|
||||||
base.os = {
|
base.os = {
|
||||||
time = os.time,
|
time = os.time,
|
||||||
date = os.date,
|
date = os.date,
|
||||||
difftime = os.difftime,
|
difftime = os.difftime,
|
||||||
clock = os.clock
|
clock = os.clock
|
||||||
}
|
}
|
||||||
|
|
||||||
-- Basic functions
|
-- Basic functions
|
||||||
base.tonumber = tonumber
|
base.tonumber = tonumber
|
||||||
base.tostring = tostring
|
base.tostring = tostring
|
||||||
base.type = type
|
base.type = type
|
||||||
base.pairs = pairs
|
base.pairs = pairs
|
||||||
base.ipairs = ipairs
|
base.ipairs = ipairs
|
||||||
base.next = next
|
base.next = next
|
||||||
base.select = select
|
base.select = select
|
||||||
base.unpack = unpack
|
base.unpack = unpack
|
||||||
base.pcall = pcall
|
base.pcall = pcall
|
||||||
base.xpcall = xpcall
|
base.xpcall = xpcall
|
||||||
base.error = error
|
base.error = error
|
||||||
base.assert = assert
|
base.assert = assert
|
||||||
|
|
||||||
-- Package system is shared for performance
|
-- Package system is shared for performance
|
||||||
base.package = {
|
base.package = {
|
||||||
loaded = package.loaded,
|
loaded = package.loaded,
|
||||||
path = package.path,
|
path = package.path,
|
||||||
preload = package.preload
|
preload = package.preload
|
||||||
}
|
}
|
||||||
|
|
||||||
-- Add HTTP module explicitly to the base environment
|
-- Add registered custom modules
|
||||||
base.http = http
|
if __sandbox_modules then
|
||||||
|
for name, mod in pairs(__sandbox_modules) do
|
||||||
|
base[name] = mod
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
-- Add registered custom modules
|
-- Store base environment
|
||||||
if __sandbox_modules then
|
__env_system.base_env = base
|
||||||
for name, mod in pairs(__sandbox_modules) do
|
__env_system.initialized = true
|
||||||
base[name] = mod
|
end
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Store base environment
|
|
||||||
__env_system.base_env = base
|
|
||||||
__env_system.initialized = true
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Global variable for tracking current environment
|
|
||||||
__last_env = nil
|
|
||||||
|
|
||||||
-- Fast environment creation with pre-allocation
|
-- Fast environment creation with pre-allocation
|
||||||
function __get_sandbox_env(ctx)
|
function __get_sandbox_env(ctx)
|
||||||
|
@ -109,8 +101,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
|
|
||||||
-- Clear any previous context
|
-- Clear any previous context
|
||||||
env.ctx = ctx or nil
|
env.ctx = ctx or nil
|
||||||
-- Clear any previous response
|
|
||||||
env._response = nil
|
|
||||||
else
|
else
|
||||||
-- Create new environment with metatable inheritance
|
-- Create new environment with metatable inheritance
|
||||||
env = setmetatable({}, {
|
env = setmetatable({}, {
|
||||||
|
@ -128,9 +118,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Store reference to current environment
|
|
||||||
__last_env = env
|
|
||||||
|
|
||||||
return env
|
return env
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -140,7 +127,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
if __env_system.pool_size < __env_system.max_pool_size then
|
if __env_system.pool_size < __env_system.max_pool_size then
|
||||||
-- Clear context reference to avoid memory leaks
|
-- Clear context reference to avoid memory leaks
|
||||||
env.ctx = nil
|
env.ctx = nil
|
||||||
-- Don't clear response data - we need it for extraction
|
|
||||||
|
|
||||||
-- Add to pool
|
-- Add to pool
|
||||||
table.insert(__env_system.env_pool, env)
|
table.insert(__env_system.env_pool, env)
|
||||||
|
@ -253,14 +239,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
||||||
|
|
||||||
// Get result
|
// Get result
|
||||||
result, err := state.ToValue(-1)
|
result, err := state.ToValue(-1)
|
||||||
state.Pop(1) // Pop result
|
state.Pop(1)
|
||||||
|
|
||||||
// Check if HTTP response was set
|
|
||||||
httpResponse, hasHTTPResponse := GetHTTPResponse(state)
|
|
||||||
if hasHTTPResponse {
|
|
||||||
httpResponse.Body = result
|
|
||||||
return httpResponse, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user