Compare commits
2 Commits
7bc5194b10
...
0259d1a135
Author | SHA1 | Date | |
---|---|---|---|
0259d1a135 | |||
780533bd76 |
|
@ -1,351 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestLoad verifies we can successfully load configuration values from a Lua file.
|
|
||||||
func TestLoad(t *testing.T) {
|
|
||||||
// Create a temporary config file
|
|
||||||
content := `
|
|
||||||
-- Basic configuration values
|
|
||||||
host = "localhost"
|
|
||||||
port = 8080
|
|
||||||
debug = true
|
|
||||||
pi = 3.14159
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify values were loaded correctly
|
|
||||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
|
||||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
|
||||||
t.Errorf("Expected port to be 8080, got %d", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if debug := cfg.GetBool("debug", false); !debug {
|
|
||||||
t.Errorf("Expected debug to be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pi := cfg.GetFloat("pi", 0); pi != 3.14159 {
|
|
||||||
t.Errorf("Expected pi to be 3.14159, got %f", pi)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestLoadErrors ensures the package properly handles loading errors.
|
|
||||||
func TestLoadErrors(t *testing.T) {
|
|
||||||
// Test with non-existent file
|
|
||||||
_, err := Load("nonexistent.lua")
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error when loading non-existent file, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with invalid Lua
|
|
||||||
content := `
|
|
||||||
-- This is invalid Lua
|
|
||||||
host = "localhost
|
|
||||||
port = 8080)
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
_, err = Load(configFile)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error when loading invalid Lua, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestLocalVsGlobal verifies only global variables are exported, not locals.
|
|
||||||
func TestLocalVsGlobal(t *testing.T) {
|
|
||||||
// Create a temporary config file with both local and global variables
|
|
||||||
content := `
|
|
||||||
-- Local variables should not be exported
|
|
||||||
local local_var = "hidden"
|
|
||||||
|
|
||||||
-- Global variables should be exported
|
|
||||||
global_var = "visible"
|
|
||||||
|
|
||||||
-- A function that uses both
|
|
||||||
function test_func()
|
|
||||||
return local_var .. " " .. global_var
|
|
||||||
end
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that global_var exists
|
|
||||||
if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" {
|
|
||||||
t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that local_var does not exist
|
|
||||||
if localVar := cfg.GetString("local_var", "default"); localVar != "default" {
|
|
||||||
t.Errorf("Expected local_var to use default, got '%s'", localVar)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that functions are not exported
|
|
||||||
if val := cfg.Get("test_func"); val != nil {
|
|
||||||
t.Errorf("Expected function to not be exported, got %v", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestArrayHandling verifies correct handling of Lua arrays.
|
|
||||||
func TestArrayHandling(t *testing.T) {
|
|
||||||
// Create a temporary config file with arrays
|
|
||||||
content := `
|
|
||||||
-- Numeric array
|
|
||||||
numbers = {10, 20, 30, 40, 50}
|
|
||||||
|
|
||||||
-- String array
|
|
||||||
strings = {"apple", "banana", "cherry"}
|
|
||||||
|
|
||||||
-- Mixed array
|
|
||||||
mixed = {1, "two", true, 4.5}
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test GetIntArray
|
|
||||||
intArray := cfg.GetIntArray("numbers")
|
|
||||||
expectedInts := []int{10, 20, 30, 40, 50}
|
|
||||||
if !reflect.DeepEqual(intArray, expectedInts) {
|
|
||||||
t.Errorf("Expected int array %v, got %v", expectedInts, intArray)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test GetStringArray
|
|
||||||
strArray := cfg.GetStringArray("strings")
|
|
||||||
expectedStrs := []string{"apple", "banana", "cherry"}
|
|
||||||
if !reflect.DeepEqual(strArray, expectedStrs) {
|
|
||||||
t.Errorf("Expected string array %v, got %v", expectedStrs, strArray)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test GetArray with mixed types
|
|
||||||
mixedArray := cfg.GetArray("mixed")
|
|
||||||
if len(mixedArray) != 4 {
|
|
||||||
t.Errorf("Expected mixed array length 4, got %d", len(mixedArray))
|
|
||||||
// Skip further tests if array is empty to avoid panic
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check types - carefully to avoid panics
|
|
||||||
if len(mixedArray) > 0 {
|
|
||||||
if num, ok := mixedArray[0].(float64); !ok || num != 1 {
|
|
||||||
t.Errorf("Expected first element to be 1, got %v", mixedArray[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(mixedArray) > 1 {
|
|
||||||
if str, ok := mixedArray[1].(string); !ok || str != "two" {
|
|
||||||
t.Errorf("Expected second element to be 'two', got %v", mixedArray[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestComplexTable tests handling of complex nested tables.
|
|
||||||
func TestComplexTable(t *testing.T) {
|
|
||||||
// Create a temporary config file with complex tables
|
|
||||||
content := `
|
|
||||||
-- Nested table structure
|
|
||||||
server = {
|
|
||||||
host = "localhost",
|
|
||||||
port = 8080,
|
|
||||||
settings = {
|
|
||||||
timeout = 30,
|
|
||||||
retries = 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Table with mixed array and map elements
|
|
||||||
mixed_table = {
|
|
||||||
list = {1, 2, 3},
|
|
||||||
mapping = {
|
|
||||||
a = "apple",
|
|
||||||
b = "banana"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test getting nested values
|
|
||||||
serverMap := cfg.GetMap("server")
|
|
||||||
if serverMap == nil {
|
|
||||||
t.Fatal("Expected server map to exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check first level values
|
|
||||||
if host, ok := serverMap["host"].(string); !ok || host != "localhost" {
|
|
||||||
t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if port, ok := serverMap["port"].(float64); !ok || port != 8080 {
|
|
||||||
t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check nested settings
|
|
||||||
settings, ok := serverMap["settings"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("Expected server.settings to be a map")
|
|
||||||
}
|
|
||||||
|
|
||||||
if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 {
|
|
||||||
t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDefaultValues verifies default values work correctly when keys don't exist.
|
|
||||||
func TestDefaultValues(t *testing.T) {
|
|
||||||
// Create a temporary config file
|
|
||||||
content := `
|
|
||||||
-- Just one value
|
|
||||||
existing = "value"
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test defaults for non-existent keys
|
|
||||||
if val := cfg.GetString("nonexistent", "default"); val != "default" {
|
|
||||||
t.Errorf("Expected default string, got '%s'", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := cfg.GetInt("nonexistent", 42); val != 42 {
|
|
||||||
t.Errorf("Expected default int 42, got %d", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 {
|
|
||||||
t.Errorf("Expected default float 3.14, got %f", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := cfg.GetBool("nonexistent", true); !val {
|
|
||||||
t.Errorf("Expected default bool true, got false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestModifyConfig tests the ability to modify configuration values.
|
|
||||||
func TestModifyConfig(t *testing.T) {
|
|
||||||
// Create a config manually
|
|
||||||
cfg := New()
|
|
||||||
|
|
||||||
// Set some values
|
|
||||||
cfg.Set("host", "localhost")
|
|
||||||
cfg.Set("port", 8080)
|
|
||||||
|
|
||||||
// Verify the values were set
|
|
||||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
|
||||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
|
||||||
t.Errorf("Expected port to be 8080, got %d", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modify a value
|
|
||||||
cfg.Set("host", "127.0.0.1")
|
|
||||||
|
|
||||||
// Verify the change
|
|
||||||
if host := cfg.GetString("host", ""); host != "127.0.0.1" {
|
|
||||||
t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTypeConversion tests the type conversion in getter methods.
|
|
||||||
func TestTypeConversion(t *testing.T) {
|
|
||||||
// Create a temporary config file with values that need conversion
|
|
||||||
content := `
|
|
||||||
-- Numbers that can be integers
|
|
||||||
int_as_float = 42.0
|
|
||||||
|
|
||||||
-- Floats that should remain floats
|
|
||||||
float_val = 3.14159
|
|
||||||
`
|
|
||||||
configFile := createTempLuaFile(t, content)
|
|
||||||
defer os.Remove(configFile)
|
|
||||||
|
|
||||||
// Load the config
|
|
||||||
cfg, err := Load(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test GetInt with a float value
|
|
||||||
if val := cfg.GetInt("int_as_float", 0); val != 42 {
|
|
||||||
t.Errorf("Expected int 42, got %d", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test GetFloat with an int value
|
|
||||||
if val := cfg.GetFloat("int_as_float", 0); val != 42.0 {
|
|
||||||
t.Errorf("Expected float 42.0, got %f", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test incorrect type handling
|
|
||||||
cfg.Set("string_val", "not a number")
|
|
||||||
|
|
||||||
if val := cfg.GetInt("string_val", 99); val != 99 {
|
|
||||||
t.Errorf("Expected default int 99 for string value, got %d", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := cfg.GetFloat("string_val", 99.9); val != 99.9 {
|
|
||||||
t.Errorf("Expected default float 99.9 for string value, got %f", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if val := cfg.GetBool("float_val", false); val != false {
|
|
||||||
t.Errorf("Expected default false for non-bool value, got true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to create a temporary Lua file with content
|
|
||||||
func createTempLuaFile(t *testing.T, content string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
tempFile, err := os.CreateTemp("", "config-test-*.lua")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temp file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tempFile.WriteString(content); err != nil {
|
|
||||||
os.Remove(tempFile.Name())
|
|
||||||
t.Fatalf("Failed to write to temp file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tempFile.Close(); err != nil {
|
|
||||||
os.Remove(tempFile.Name())
|
|
||||||
t.Fatalf("Failed to close temp file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tempFile.Name()
|
|
||||||
}
|
|
|
@ -170,15 +170,32 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for HTTPResponse type
|
||||||
|
if httpResp, ok := result.(*runner.HTTPResponse); ok {
|
||||||
|
// Set response headers
|
||||||
|
for name, value := range httpResp.Headers {
|
||||||
|
w.Header().Set(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set status code
|
||||||
|
w.WriteHeader(httpResp.Status)
|
||||||
|
|
||||||
|
// Process the body based on its type
|
||||||
|
if httpResp.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result = httpResp.Body // Set result to body for processing below
|
||||||
|
}
|
||||||
|
|
||||||
switch res := result.(type) {
|
switch res := result.(type) {
|
||||||
case string:
|
case string:
|
||||||
// String result
|
// String result - plain text
|
||||||
w.Header().Set("Content-Type", contentTypePlain)
|
setContentTypeIfMissing(w, contentTypePlain)
|
||||||
w.Write([]byte(res))
|
w.Write([]byte(res))
|
||||||
|
default:
|
||||||
case map[string]any, []any:
|
// All other types - convert to JSON
|
||||||
// Table or array result - convert to JSON
|
setContentTypeIfMissing(w, contentTypeJSON)
|
||||||
w.Header().Set("Content-Type", contentTypeJSON)
|
|
||||||
data, err := json.Marshal(res)
|
data, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Failed to marshal response: %v", err)
|
log.Error("Failed to marshal response: %v", err)
|
||||||
|
@ -186,16 +203,11 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
|
||||||
// Other result types - convert to JSON
|
if w.Header().Get("Content-Type") == "" {
|
||||||
w.Header().Set("Content-Type", contentTypeJSON)
|
w.Header().Set("Content-Type", contentType)
|
||||||
data, err := json.Marshal(result)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Failed to marshal response: %v", err)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write(data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,214 +0,0 @@
|
||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoggerLevels(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelInfo, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Debug should be below threshold
|
|
||||||
logger.Debug("This should not appear")
|
|
||||||
if buf.Len() > 0 {
|
|
||||||
t.Error("Debug message appeared when it should be filtered")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info and above should appear
|
|
||||||
logger.Info("Info message")
|
|
||||||
if !strings.Contains(buf.String(), "INFO") {
|
|
||||||
t.Errorf("Info message not logged, got: %q", buf.String())
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
logger.Warning("Warning message")
|
|
||||||
if !strings.Contains(buf.String(), "WARN") {
|
|
||||||
t.Errorf("Warning message not logged, got: %q", buf.String())
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
logger.Error("Error message")
|
|
||||||
if !strings.Contains(buf.String(), "ERROR") {
|
|
||||||
t.Errorf("Error message not logged, got: %q", buf.String())
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
// Test format strings
|
|
||||||
logger.Info("Count: %d", 42)
|
|
||||||
if !strings.Contains(buf.String(), "Count: 42") {
|
|
||||||
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
// Test changing level
|
|
||||||
logger.SetLevel(LevelError)
|
|
||||||
logger.Info("This should not appear")
|
|
||||||
logger.Warning("This should not appear")
|
|
||||||
if buf.Len() > 0 {
|
|
||||||
t.Error("Messages below threshold appeared")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Error("Error should appear")
|
|
||||||
if !strings.Contains(buf.String(), "ERROR") {
|
|
||||||
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerRateLimit(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Override max logs per second to something small for testing
|
|
||||||
logger.maxLogsPerSec = 5
|
|
||||||
logger.limitDuration = 1 * time.Second
|
|
||||||
|
|
||||||
// Send debug messages (should get limited)
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
logger.Debug("Debug message %d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error messages should always go through
|
|
||||||
logger.Error("Error message should appear")
|
|
||||||
|
|
||||||
content := buf.String()
|
|
||||||
|
|
||||||
// We should see some debug messages, then a warning about rate limiting,
|
|
||||||
// and finally the error message
|
|
||||||
if !strings.Contains(content, "Debug message 0") {
|
|
||||||
t.Error("First debug message should appear")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(content, "Rate limiting logger") {
|
|
||||||
t.Error("Rate limiting message should appear")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(content, "ERROR") {
|
|
||||||
t.Error("Error message should always appear despite rate limiting")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerConcurrency(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Increase log threshold for this test
|
|
||||||
logger.maxLogsPerSec = 1000
|
|
||||||
|
|
||||||
// Log a bunch of messages concurrently
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(n int) {
|
|
||||||
defer wg.Done()
|
|
||||||
logger.Info("Concurrent message %d", n)
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Check logs were processed
|
|
||||||
content := buf.String()
|
|
||||||
if !strings.Contains(content, "Concurrent message") {
|
|
||||||
t.Error("Concurrent messages should appear")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerColors(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelInfo, true)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Test with color
|
|
||||||
logger.Info("Colored message")
|
|
||||||
|
|
||||||
content := buf.String()
|
|
||||||
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
|
||||||
if !strings.Contains(content, "\033[") {
|
|
||||||
t.Errorf("Color codes not present when enabled, got: %q", content)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.Reset()
|
|
||||||
logger.DisableColors()
|
|
||||||
logger.Info("Non-colored message")
|
|
||||||
|
|
||||||
content = buf.String()
|
|
||||||
if strings.Contains(content, "\033[") {
|
|
||||||
t.Errorf("Color codes present when disabled, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultLogger(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
SetOutput(&buf)
|
|
||||||
|
|
||||||
Info("Test default logger")
|
|
||||||
|
|
||||||
content := buf.String()
|
|
||||||
if !strings.Contains(content, "INFO") {
|
|
||||||
t.Errorf("Default logger not working, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLogger(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelInfo, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Set very high threshold to avoid rate limiting during benchmark
|
|
||||||
logger.maxLogsPerSec = int64(b.N + 1)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.Info("Benchmark message %d", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLoggerWithRateLimit(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Set threshold to allow about 10% of messages through
|
|
||||||
logger.maxLogsPerSec = int64(b.N / 10)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.Debug("Benchmark message %d", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLoggerParallel(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Set very high threshold to avoid rate limiting during benchmark
|
|
||||||
logger.maxLogsPerSec = int64(b.N + 1)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
logger.Debug("Parallel benchmark message %d", i)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkProductionLevels(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelWarning, false) // Only log warnings and above
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// This should be filtered out before any processing
|
|
||||||
logger.Debug("Debug message that won't be logged %d", i)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -103,18 +103,25 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize HTTP module BEFORE sandbox setup
|
||||||
|
httpInit := HTTPModuleInitFunc()
|
||||||
|
if err := httpInit(state); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up sandbox AFTER HTTP module is initialized
|
||||||
|
if err := runner.sandbox.Setup(state); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
// Preload all modules into package.loaded
|
// Preload all modules into package.loaded
|
||||||
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
|
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
|
||||||
state.Close()
|
state.Close()
|
||||||
return nil, errors.New("failed to preload modules")
|
return nil, errors.New("failed to preload modules")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up sandbox
|
|
||||||
if err := runner.sandbox.Setup(state); err != nil {
|
|
||||||
state.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run init function if provided
|
// Run init function if provided
|
||||||
if runner.initFunc != nil {
|
if runner.initFunc != nil {
|
||||||
if err := runner.initFunc(state); err != nil {
|
if err := runner.initFunc(state); err != nil {
|
||||||
|
|
|
@ -1,373 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Helper function to create bytecode for testing
|
|
||||||
func createTestBytecode(t *testing.T, code string) []byte {
|
|
||||||
state := luajit.New()
|
|
||||||
if state == nil {
|
|
||||||
t.Fatal("Failed to create Lua state")
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(code, "test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile test bytecode: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytecode
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunnerBasic(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
bytecode := createTestBytecode(t, "return 42")
|
|
||||||
|
|
||||||
result, err := runner.Run(bytecode, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
num, ok := result.(float64)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected float64 result, got %T", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if num != 42 {
|
|
||||||
t.Errorf("Expected 42, got %f", num)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunnerWithContext(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
bytecode := createTestBytecode(t, `
|
|
||||||
return {
|
|
||||||
num = ctx.number,
|
|
||||||
str = ctx.text,
|
|
||||||
flag = ctx.enabled,
|
|
||||||
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
|
|
||||||
execCtx := NewContext()
|
|
||||||
execCtx.Set("number", 42.5)
|
|
||||||
execCtx.Set("text", "hello")
|
|
||||||
execCtx.Set("enabled", true)
|
|
||||||
execCtx.Set("table", []float64{10, 20, 30})
|
|
||||||
|
|
||||||
result, err := runner.Run(bytecode, execCtx, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run job: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result should be a map
|
|
||||||
resultMap, ok := result.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected map result, got %T", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check values
|
|
||||||
if resultMap["num"] != 42.5 {
|
|
||||||
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
|
||||||
}
|
|
||||||
if resultMap["str"] != "hello" {
|
|
||||||
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
|
||||||
}
|
|
||||||
if resultMap["flag"] != true {
|
|
||||||
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
|
||||||
}
|
|
||||||
|
|
||||||
arr, ok := resultMap["list"].([]float64)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []float64{10, 20, 30}
|
|
||||||
for i, v := range expected {
|
|
||||||
if arr[i] != v {
|
|
||||||
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunnerWithTimeout(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
// Create bytecode that sleeps
|
|
||||||
bytecode := createTestBytecode(t, `
|
|
||||||
-- Sleep for 500ms
|
|
||||||
local start = os.time()
|
|
||||||
while os.difftime(os.time(), start) < 0.5 do end
|
|
||||||
return "done"
|
|
||||||
`)
|
|
||||||
|
|
||||||
// Test with timeout that should succeed
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
result, err := runner.RunWithContext(ctx, bytecode, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
|
||||||
}
|
|
||||||
if result != "done" {
|
|
||||||
t.Errorf("Expected 'done', got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with timeout that should fail
|
|
||||||
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err = runner.RunWithContext(ctx, bytecode, nil, "")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected timeout error, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSandboxIsolation(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
// Create a script that tries to modify a global variable
|
|
||||||
bytecode1 := createTestBytecode(t, `
|
|
||||||
-- Set a "global" variable
|
|
||||||
my_global = "test value"
|
|
||||||
return true
|
|
||||||
`)
|
|
||||||
|
|
||||||
_, err = runner.Run(bytecode1, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to execute first script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now try to access that variable from another script
|
|
||||||
bytecode2 := createTestBytecode(t, `
|
|
||||||
-- Try to access the previously set global
|
|
||||||
return my_global ~= nil
|
|
||||||
`)
|
|
||||||
|
|
||||||
result, err := runner.Run(bytecode2, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to execute second script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The variable should not be accessible (sandbox isolation)
|
|
||||||
if result.(bool) {
|
|
||||||
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunnerWithInit(t *testing.T) {
|
|
||||||
// Define an init function that registers a simple "math" module
|
|
||||||
mathInit := func(state *luajit.State) error {
|
|
||||||
// Register the "add" function
|
|
||||||
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
|
||||||
a := s.ToNumber(1)
|
|
||||||
b := s.ToNumber(2)
|
|
||||||
s.PushNumber(a + b)
|
|
||||||
return 1 // Return one result
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register a whole module
|
|
||||||
mathFuncs := map[string]luajit.GoFunction{
|
|
||||||
"multiply": func(s *luajit.State) int {
|
|
||||||
a := s.ToNumber(1)
|
|
||||||
b := s.ToNumber(2)
|
|
||||||
s.PushNumber(a * b)
|
|
||||||
return 1
|
|
||||||
},
|
|
||||||
"subtract": func(s *luajit.State) int {
|
|
||||||
a := s.ToNumber(1)
|
|
||||||
b := s.ToNumber(2)
|
|
||||||
s.PushNumber(a - b)
|
|
||||||
return 1
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return RegisterModule(state, "math2", mathFuncs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a runner with our init function
|
|
||||||
runner, err := NewRunner(WithInitFunc(mathInit))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
// Test the add function
|
|
||||||
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
|
||||||
result1, err := runner.Run(bytecode1, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to call add function: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
num1, ok := result1.(float64)
|
|
||||||
if !ok || num1 != 12 {
|
|
||||||
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test the math2 module
|
|
||||||
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
|
||||||
result2, err := runner.Run(bytecode2, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to call math2.multiply: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
num2, ok := result2.(float64)
|
|
||||||
if !ok || num2 != 48 {
|
|
||||||
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrentExecution(t *testing.T) {
|
|
||||||
const jobs = 20
|
|
||||||
|
|
||||||
runner, err := NewRunner(WithBufferSize(20))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
// Create bytecode that returns its input
|
|
||||||
bytecode := createTestBytecode(t, "return ctx.n")
|
|
||||||
|
|
||||||
// Run multiple jobs concurrently
|
|
||||||
results := make(chan int, jobs)
|
|
||||||
for i := 0; i < jobs; i++ {
|
|
||||||
i := i // Capture loop variable
|
|
||||||
go func() {
|
|
||||||
execCtx := NewContext()
|
|
||||||
execCtx.Set("n", float64(i))
|
|
||||||
|
|
||||||
result, err := runner.Run(bytecode, execCtx, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Job %d failed: %v", i, err)
|
|
||||||
results <- -1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
num, ok := result.(float64)
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("Job %d: expected float64, got %T", i, result)
|
|
||||||
results <- -1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
results <- int(num)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect results
|
|
||||||
seen := make(map[int]bool)
|
|
||||||
for i := 0; i < jobs; i++ {
|
|
||||||
result := <-results
|
|
||||||
if result != -1 {
|
|
||||||
seen[result] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify all jobs were processed
|
|
||||||
if len(seen) != jobs {
|
|
||||||
t.Errorf("Expected %d unique results, got %d", jobs, len(seen))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunnerClose(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Submit a job to verify runner works
|
|
||||||
bytecode := createTestBytecode(t, "return 42")
|
|
||||||
_, err = runner.Run(bytecode, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run job: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close
|
|
||||||
if err := runner.Close(); err != nil {
|
|
||||||
t.Errorf("Close failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run after close should fail
|
|
||||||
_, err = runner.Run(bytecode, nil, "")
|
|
||||||
if err != ErrRunnerClosed {
|
|
||||||
t.Errorf("Expected ErrRunnerClosed, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second close should return error
|
|
||||||
if err := runner.Close(); err != ErrRunnerClosed {
|
|
||||||
t.Errorf("Expected ErrRunnerClosed on second close, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorHandling(t *testing.T) {
|
|
||||||
runner, err := NewRunner()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create runner: %v", err)
|
|
||||||
}
|
|
||||||
defer runner.Close()
|
|
||||||
|
|
||||||
// Test invalid bytecode
|
|
||||||
_, err = runner.Run([]byte("not valid bytecode"), nil, "")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test Lua runtime error
|
|
||||||
bytecode := createTestBytecode(t, `
|
|
||||||
error("intentional error")
|
|
||||||
return true
|
|
||||||
`)
|
|
||||||
|
|
||||||
_, err = runner.Run(bytecode, nil, "")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error from Lua error() call, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with nil context
|
|
||||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
|
||||||
result, err := runner.Run(bytecode, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected error with nil context: %v", err)
|
|
||||||
}
|
|
||||||
if result.(bool) != true {
|
|
||||||
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test invalid context value
|
|
||||||
execCtx := NewContext()
|
|
||||||
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
|
||||||
|
|
||||||
bytecode = createTestBytecode(t, "return ctx.param")
|
|
||||||
_, err = runner.Run(bytecode, execCtx, "")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error for unsupported context value type, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,281 +0,0 @@
|
||||||
package runner_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
"git.sharkk.net/Sky/Moonshark/core/runner"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRequireFunctionality(t *testing.T) {
|
|
||||||
// Create temporary directories for test
|
|
||||||
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temp directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
// Create script directory and lib directory
|
|
||||||
scriptDir := filepath.Join(tempDir, "scripts")
|
|
||||||
libDir := filepath.Join(tempDir, "libs")
|
|
||||||
|
|
||||||
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
|
||||||
t.Fatalf("Failed to create script directory: %v", err)
|
|
||||||
}
|
|
||||||
if err := os.Mkdir(libDir, 0755); err != nil {
|
|
||||||
t.Fatalf("Failed to create lib directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a module in the lib directory
|
|
||||||
libModule := `
|
|
||||||
local lib = {}
|
|
||||||
|
|
||||||
function lib.add(a, b)
|
|
||||||
return a + b
|
|
||||||
end
|
|
||||||
|
|
||||||
function lib.mul(a, b)
|
|
||||||
return a * b
|
|
||||||
end
|
|
||||||
|
|
||||||
return lib
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write lib module: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a helper module in the script directory
|
|
||||||
helperModule := `
|
|
||||||
local helper = {}
|
|
||||||
|
|
||||||
function helper.square(x)
|
|
||||||
return x * x
|
|
||||||
end
|
|
||||||
|
|
||||||
function helper.cube(x)
|
|
||||||
return x * x * x
|
|
||||||
end
|
|
||||||
|
|
||||||
return helper
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write helper module: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create main script that requires both modules
|
|
||||||
mainScript := `
|
|
||||||
-- Require from the same directory
|
|
||||||
local helper = require("helper")
|
|
||||||
|
|
||||||
-- Require from the lib directory
|
|
||||||
local mathlib = require("mathlib")
|
|
||||||
|
|
||||||
-- Use both modules
|
|
||||||
local result = {
|
|
||||||
add = mathlib.add(10, 5),
|
|
||||||
mul = mathlib.mul(10, 5),
|
|
||||||
square = helper.square(5),
|
|
||||||
cube = helper.cube(3)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
`
|
|
||||||
mainScriptPath := filepath.Join(scriptDir, "main.lua")
|
|
||||||
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write main script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create LuaRunner
|
|
||||||
luaRunner, err := runner.NewRunner(
|
|
||||||
runner.WithScriptDir(scriptDir),
|
|
||||||
runner.WithLibDirs(libDir),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create LuaRunner: %v", err)
|
|
||||||
}
|
|
||||||
defer luaRunner.Close()
|
|
||||||
|
|
||||||
// Compile the main script
|
|
||||||
state := luajit.New()
|
|
||||||
if state == nil {
|
|
||||||
t.Fatal("Failed to create Lua state")
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to compile script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the script
|
|
||||||
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check result
|
|
||||||
resultMap, ok := result.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Expected map result, got %T", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate results
|
|
||||||
expectedResults := map[string]float64{
|
|
||||||
"add": 15, // 10 + 5
|
|
||||||
"mul": 50, // 10 * 5
|
|
||||||
"square": 25, // 5^2
|
|
||||||
"cube": 27, // 3^3
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, expected := range expectedResults {
|
|
||||||
if val, ok := resultMap[key]; !ok {
|
|
||||||
t.Errorf("Missing result key: %s", key)
|
|
||||||
} else if val != expected {
|
|
||||||
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequireSecurityBoundaries(t *testing.T) {
|
|
||||||
// Create temporary directories for test
|
|
||||||
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temp directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
// Create script directory and lib directory
|
|
||||||
scriptDir := filepath.Join(tempDir, "scripts")
|
|
||||||
libDir := filepath.Join(tempDir, "libs")
|
|
||||||
secretDir := filepath.Join(tempDir, "secret")
|
|
||||||
|
|
||||||
err = os.MkdirAll(scriptDir, 0755)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create script directory: %v", err)
|
|
||||||
}
|
|
||||||
err = os.MkdirAll(libDir, 0755)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create lib directory: %v", err)
|
|
||||||
}
|
|
||||||
err = os.MkdirAll(secretDir, 0755)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create secret directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a "secret" module that should not be accessible
|
|
||||||
secretModule := `
|
|
||||||
local secret = "TOP SECRET"
|
|
||||||
return secret
|
|
||||||
`
|
|
||||||
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to write secret module: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a normal module in lib
|
|
||||||
normalModule := `return "normal module"`
|
|
||||||
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to write normal module: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a compile-and-run function that takes care of both compilation and execution
|
|
||||||
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
|
|
||||||
// Compile
|
|
||||||
state := luajit.New()
|
|
||||||
if state == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
defer state.Close()
|
|
||||||
|
|
||||||
bytecode, err := state.CompileBytecode(scriptText, scriptName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and configure a new runner each time
|
|
||||||
r, err := runner.NewRunner(
|
|
||||||
runner.WithScriptDir(scriptDir),
|
|
||||||
runner.WithLibDirs(libDir),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer r.Close()
|
|
||||||
|
|
||||||
// Run
|
|
||||||
return r.Run(bytecode, nil, scriptPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that normal require works
|
|
||||||
normalScript := `
|
|
||||||
local normal = require("normal")
|
|
||||||
return normal
|
|
||||||
`
|
|
||||||
normalPath := filepath.Join(scriptDir, "normal_test.lua")
|
|
||||||
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to write normal script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to run normal script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != "normal module" {
|
|
||||||
t.Errorf("Expected 'normal module', got %v", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test path traversal attempts
|
|
||||||
pathTraversalTests := []struct {
|
|
||||||
name string
|
|
||||||
script string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Direct path traversal",
|
|
||||||
script: `
|
|
||||||
-- Try path traversal
|
|
||||||
local secret = require("../secret/secret")
|
|
||||||
return secret ~= nil
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Double dot traversal",
|
|
||||||
script: `
|
|
||||||
local secret = require("..secret.secret")
|
|
||||||
return secret ~= nil
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Absolute path traversal",
|
|
||||||
script: `
|
|
||||||
local secret = require("` + filepath.Join(secretDir, "secret") + `")
|
|
||||||
return secret ~= nil
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range pathTraversalTests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
|
||||||
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to write test script: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
|
|
||||||
// If there's an error, that's expected and good
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no error, then the script should have returned false (couldn't get the module)
|
|
||||||
if result == true {
|
|
||||||
t.Errorf("Security breach! Script was able to access restricted module")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -12,10 +12,12 @@ type Sandbox struct {
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox
|
// NewSandbox creates a new sandbox
|
||||||
func NewSandbox() *Sandbox {
|
func NewSandbox() *Sandbox {
|
||||||
return &Sandbox{
|
s := &Sandbox{
|
||||||
modules: make(map[string]any),
|
modules: make(map[string]any),
|
||||||
initialized: false,
|
initialized: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
// AddModule adds a module to the sandbox environment
|
||||||
|
@ -78,6 +80,9 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
preload = package.preload
|
preload = package.preload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- Add HTTP module explicitly to the base environment
|
||||||
|
base.http = http
|
||||||
|
|
||||||
-- Add registered custom modules
|
-- Add registered custom modules
|
||||||
if __sandbox_modules then
|
if __sandbox_modules then
|
||||||
for name, mod in pairs(__sandbox_modules) do
|
for name, mod in pairs(__sandbox_modules) do
|
||||||
|
@ -90,6 +95,9 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
__env_system.initialized = true
|
__env_system.initialized = true
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Global variable for tracking current environment
|
||||||
|
__last_env = nil
|
||||||
|
|
||||||
-- Fast environment creation with pre-allocation
|
-- Fast environment creation with pre-allocation
|
||||||
function __get_sandbox_env(ctx)
|
function __get_sandbox_env(ctx)
|
||||||
local env
|
local env
|
||||||
|
@ -101,6 +109,8 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
|
|
||||||
-- Clear any previous context
|
-- Clear any previous context
|
||||||
env.ctx = ctx or nil
|
env.ctx = ctx or nil
|
||||||
|
-- Clear any previous response
|
||||||
|
env._response = nil
|
||||||
else
|
else
|
||||||
-- Create new environment with metatable inheritance
|
-- Create new environment with metatable inheritance
|
||||||
env = setmetatable({}, {
|
env = setmetatable({}, {
|
||||||
|
@ -118,6 +128,9 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Store reference to current environment
|
||||||
|
__last_env = env
|
||||||
|
|
||||||
return env
|
return env
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -127,6 +140,7 @@ func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
if __env_system.pool_size < __env_system.max_pool_size then
|
if __env_system.pool_size < __env_system.max_pool_size then
|
||||||
-- Clear context reference to avoid memory leaks
|
-- Clear context reference to avoid memory leaks
|
||||||
env.ctx = nil
|
env.ctx = nil
|
||||||
|
-- Don't clear response data - we need it for extraction
|
||||||
|
|
||||||
-- Add to pool
|
-- Add to pool
|
||||||
table.insert(__env_system.env_pool, env)
|
table.insert(__env_system.env_pool, env)
|
||||||
|
@ -239,7 +253,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
||||||
|
|
||||||
// Get result
|
// Get result
|
||||||
result, err := state.ToValue(-1)
|
result, err := state.ToValue(-1)
|
||||||
state.Pop(1)
|
state.Pop(1) // Pop result
|
||||||
|
|
||||||
|
// Check if HTTP response was set
|
||||||
|
httpResponse, hasHTTPResponse := GetHTTPResponse(state)
|
||||||
|
if hasHTTPResponse {
|
||||||
|
httpResponse.Body = result
|
||||||
|
return httpResponse, err
|
||||||
|
}
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user