Compare commits

..

No commits in common. "0259d1a135aac788980925b07056a57fd08f023a" and "7bc5194b10c90f9b0889cbf50426398bac0e29fc" have entirely different histories.

7 changed files with 1295 additions and 116 deletions

351
core/config/config_test.go Normal file
View File

@ -0,0 +1,351 @@
package config
import (
"os"
"reflect"
"testing"
)
// TestLoad verifies we can successfully load configuration values from a Lua file.
func TestLoad(t *testing.T) {
// Create a temporary config file
content := `
-- Basic configuration values
host = "localhost"
port = 8080
debug = true
pi = 3.14159
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify values were loaded correctly
if host := cfg.GetString("host", ""); host != "localhost" {
t.Errorf("Expected host to be 'localhost', got '%s'", host)
}
if port := cfg.GetInt("port", 0); port != 8080 {
t.Errorf("Expected port to be 8080, got %d", port)
}
if debug := cfg.GetBool("debug", false); !debug {
t.Errorf("Expected debug to be true")
}
if pi := cfg.GetFloat("pi", 0); pi != 3.14159 {
t.Errorf("Expected pi to be 3.14159, got %f", pi)
}
}
// TestLoadErrors ensures the package properly handles loading errors.
func TestLoadErrors(t *testing.T) {
// Test with non-existent file
_, err := Load("nonexistent.lua")
if err == nil {
t.Error("Expected error when loading non-existent file, got nil")
}
// Test with invalid Lua
content := `
-- This is invalid Lua
host = "localhost
port = 8080)
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
_, err = Load(configFile)
if err == nil {
t.Error("Expected error when loading invalid Lua, got nil")
}
}
// TestLocalVsGlobal verifies only global variables are exported, not locals.
func TestLocalVsGlobal(t *testing.T) {
// Create a temporary config file with both local and global variables
content := `
-- Local variables should not be exported
local local_var = "hidden"
-- Global variables should be exported
global_var = "visible"
-- A function that uses both
function test_func()
return local_var .. " " .. global_var
end
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Check that global_var exists
if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" {
t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar)
}
// Check that local_var does not exist
if localVar := cfg.GetString("local_var", "default"); localVar != "default" {
t.Errorf("Expected local_var to use default, got '%s'", localVar)
}
// Check that functions are not exported
if val := cfg.Get("test_func"); val != nil {
t.Errorf("Expected function to not be exported, got %v", val)
}
}
// TestArrayHandling verifies correct handling of Lua arrays.
func TestArrayHandling(t *testing.T) {
// Create a temporary config file with arrays
content := `
-- Numeric array
numbers = {10, 20, 30, 40, 50}
-- String array
strings = {"apple", "banana", "cherry"}
-- Mixed array
mixed = {1, "two", true, 4.5}
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Test GetIntArray
intArray := cfg.GetIntArray("numbers")
expectedInts := []int{10, 20, 30, 40, 50}
if !reflect.DeepEqual(intArray, expectedInts) {
t.Errorf("Expected int array %v, got %v", expectedInts, intArray)
}
// Test GetStringArray
strArray := cfg.GetStringArray("strings")
expectedStrs := []string{"apple", "banana", "cherry"}
if !reflect.DeepEqual(strArray, expectedStrs) {
t.Errorf("Expected string array %v, got %v", expectedStrs, strArray)
}
// Test GetArray with mixed types
mixedArray := cfg.GetArray("mixed")
if len(mixedArray) != 4 {
t.Errorf("Expected mixed array length 4, got %d", len(mixedArray))
// Skip further tests if array is empty to avoid panic
return
}
// Check types - carefully to avoid panics
if len(mixedArray) > 0 {
if num, ok := mixedArray[0].(float64); !ok || num != 1 {
t.Errorf("Expected first element to be 1, got %v", mixedArray[0])
}
}
if len(mixedArray) > 1 {
if str, ok := mixedArray[1].(string); !ok || str != "two" {
t.Errorf("Expected second element to be 'two', got %v", mixedArray[1])
}
}
}
// TestComplexTable tests handling of complex nested tables.
func TestComplexTable(t *testing.T) {
// Create a temporary config file with complex tables
content := `
-- Nested table structure
server = {
host = "localhost",
port = 8080,
settings = {
timeout = 30,
retries = 3
}
}
-- Table with mixed array and map elements
mixed_table = {
list = {1, 2, 3},
mapping = {
a = "apple",
b = "banana"
}
}
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Test getting nested values
serverMap := cfg.GetMap("server")
if serverMap == nil {
t.Fatal("Expected server map to exist")
}
// Check first level values
if host, ok := serverMap["host"].(string); !ok || host != "localhost" {
t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"])
}
if port, ok := serverMap["port"].(float64); !ok || port != 8080 {
t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"])
}
// Check nested settings
settings, ok := serverMap["settings"].(map[string]any)
if !ok {
t.Fatal("Expected server.settings to be a map")
}
if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 {
t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"])
}
}
// TestDefaultValues verifies default values work correctly when keys don't exist.
func TestDefaultValues(t *testing.T) {
// Create a temporary config file
content := `
-- Just one value
existing = "value"
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Test defaults for non-existent keys
if val := cfg.GetString("nonexistent", "default"); val != "default" {
t.Errorf("Expected default string, got '%s'", val)
}
if val := cfg.GetInt("nonexistent", 42); val != 42 {
t.Errorf("Expected default int 42, got %d", val)
}
if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 {
t.Errorf("Expected default float 3.14, got %f", val)
}
if val := cfg.GetBool("nonexistent", true); !val {
t.Errorf("Expected default bool true, got false")
}
}
// TestModifyConfig tests the ability to modify configuration values.
func TestModifyConfig(t *testing.T) {
// Create a config manually
cfg := New()
// Set some values
cfg.Set("host", "localhost")
cfg.Set("port", 8080)
// Verify the values were set
if host := cfg.GetString("host", ""); host != "localhost" {
t.Errorf("Expected host to be 'localhost', got '%s'", host)
}
if port := cfg.GetInt("port", 0); port != 8080 {
t.Errorf("Expected port to be 8080, got %d", port)
}
// Modify a value
cfg.Set("host", "127.0.0.1")
// Verify the change
if host := cfg.GetString("host", ""); host != "127.0.0.1" {
t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host)
}
}
// TestTypeConversion tests the type conversion in getter methods.
func TestTypeConversion(t *testing.T) {
// Create a temporary config file with values that need conversion
content := `
-- Numbers that can be integers
int_as_float = 42.0
-- Floats that should remain floats
float_val = 3.14159
`
configFile := createTempLuaFile(t, content)
defer os.Remove(configFile)
// Load the config
cfg, err := Load(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Test GetInt with a float value
if val := cfg.GetInt("int_as_float", 0); val != 42 {
t.Errorf("Expected int 42, got %d", val)
}
// Test GetFloat with an int value
if val := cfg.GetFloat("int_as_float", 0); val != 42.0 {
t.Errorf("Expected float 42.0, got %f", val)
}
// Test incorrect type handling
cfg.Set("string_val", "not a number")
if val := cfg.GetInt("string_val", 99); val != 99 {
t.Errorf("Expected default int 99 for string value, got %d", val)
}
if val := cfg.GetFloat("string_val", 99.9); val != 99.9 {
t.Errorf("Expected default float 99.9 for string value, got %f", val)
}
if val := cfg.GetBool("float_val", false); val != false {
t.Errorf("Expected default false for non-bool value, got true")
}
}
// Helper function to create a temporary Lua file with content
func createTempLuaFile(t *testing.T, content string) string {
t.Helper()
tempFile, err := os.CreateTemp("", "config-test-*.lua")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
if _, err := tempFile.WriteString(content); err != nil {
os.Remove(tempFile.Name())
t.Fatalf("Failed to write to temp file: %v", err)
}
if err := tempFile.Close(); err != nil {
os.Remove(tempFile.Name())
t.Fatalf("Failed to close temp file: %v", err)
}
return tempFile.Name()
}

View File

@ -170,32 +170,15 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
return
}
// Check for HTTPResponse type
if httpResp, ok := result.(*runner.HTTPResponse); ok {
// Set response headers
for name, value := range httpResp.Headers {
w.Header().Set(name, value)
}
// Set status code
w.WriteHeader(httpResp.Status)
// Process the body based on its type
if httpResp.Body == nil {
return
}
result = httpResp.Body // Set result to body for processing below
}
switch res := result.(type) {
case string:
// String result - plain text
setContentTypeIfMissing(w, contentTypePlain)
// String result
w.Header().Set("Content-Type", contentTypePlain)
w.Write([]byte(res))
default:
// All other types - convert to JSON
setContentTypeIfMissing(w, contentTypeJSON)
case map[string]any, []any:
// Table or array result - convert to JSON
w.Header().Set("Content-Type", contentTypeJSON)
data, err := json.Marshal(res)
if err != nil {
log.Error("Failed to marshal response: %v", err)
@ -203,11 +186,16 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
return
}
w.Write(data)
default:
// Other result types - convert to JSON
w.Header().Set("Content-Type", contentTypeJSON)
data, err := json.Marshal(result)
if err != nil {
log.Error("Failed to marshal response: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Write(data)
}
}
func setContentTypeIfMissing(w http.ResponseWriter, contentType string) {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", contentType)
}
}

214
core/logger/logger_test.go Normal file
View File

@ -0,0 +1,214 @@
package logger
import (
"bytes"
"strings"
"sync"
"testing"
"time"
)
func TestLoggerLevels(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelInfo, false)
logger.SetOutput(&buf)
// Debug should be below threshold
logger.Debug("This should not appear")
if buf.Len() > 0 {
t.Error("Debug message appeared when it should be filtered")
}
// Info and above should appear
logger.Info("Info message")
if !strings.Contains(buf.String(), "INFO") {
t.Errorf("Info message not logged, got: %q", buf.String())
}
buf.Reset()
logger.Warning("Warning message")
if !strings.Contains(buf.String(), "WARN") {
t.Errorf("Warning message not logged, got: %q", buf.String())
}
buf.Reset()
logger.Error("Error message")
if !strings.Contains(buf.String(), "ERROR") {
t.Errorf("Error message not logged, got: %q", buf.String())
}
buf.Reset()
// Test format strings
logger.Info("Count: %d", 42)
if !strings.Contains(buf.String(), "Count: 42") {
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
}
buf.Reset()
// Test changing level
logger.SetLevel(LevelError)
logger.Info("This should not appear")
logger.Warning("This should not appear")
if buf.Len() > 0 {
t.Error("Messages below threshold appeared")
}
logger.Error("Error should appear")
if !strings.Contains(buf.String(), "ERROR") {
t.Errorf("Error message not logged after level change, got: %q", buf.String())
}
}
func TestLoggerRateLimit(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Override max logs per second to something small for testing
logger.maxLogsPerSec = 5
logger.limitDuration = 1 * time.Second
// Send debug messages (should get limited)
for i := 0; i < 20; i++ {
logger.Debug("Debug message %d", i)
}
// Error messages should always go through
logger.Error("Error message should appear")
content := buf.String()
// We should see some debug messages, then a warning about rate limiting,
// and finally the error message
if !strings.Contains(content, "Debug message 0") {
t.Error("First debug message should appear")
}
if !strings.Contains(content, "Rate limiting logger") {
t.Error("Rate limiting message should appear")
}
if !strings.Contains(content, "ERROR") {
t.Error("Error message should always appear despite rate limiting")
}
}
func TestLoggerConcurrency(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Increase log threshold for this test
logger.maxLogsPerSec = 1000
// Log a bunch of messages concurrently
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
logger.Info("Concurrent message %d", n)
}(i)
}
wg.Wait()
// Check logs were processed
content := buf.String()
if !strings.Contains(content, "Concurrent message") {
t.Error("Concurrent messages should appear")
}
}
func TestLoggerColors(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelInfo, true)
logger.SetOutput(&buf)
// Test with color
logger.Info("Colored message")
content := buf.String()
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
if !strings.Contains(content, "\033[") {
t.Errorf("Color codes not present when enabled, got: %q", content)
}
buf.Reset()
logger.DisableColors()
logger.Info("Non-colored message")
content = buf.String()
if strings.Contains(content, "\033[") {
t.Errorf("Color codes present when disabled, got: %q", content)
}
}
func TestDefaultLogger(t *testing.T) {
var buf bytes.Buffer
SetOutput(&buf)
Info("Test default logger")
content := buf.String()
if !strings.Contains(content, "INFO") {
t.Errorf("Default logger not working, got: %q", content)
}
}
func BenchmarkLogger(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelInfo, false)
logger.SetOutput(&buf)
// Set very high threshold to avoid rate limiting during benchmark
logger.maxLogsPerSec = int64(b.N + 1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Info("Benchmark message %d", i)
}
}
func BenchmarkLoggerWithRateLimit(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Set threshold to allow about 10% of messages through
logger.maxLogsPerSec = int64(b.N / 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug("Benchmark message %d", i)
}
}
func BenchmarkLoggerParallel(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Set very high threshold to avoid rate limiting during benchmark
logger.maxLogsPerSec = int64(b.N + 1)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
logger.Debug("Parallel benchmark message %d", i)
i++
}
})
}
func BenchmarkProductionLevels(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelWarning, false) // Only log warnings and above
logger.SetOutput(&buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// This should be filtered out before any processing
logger.Debug("Debug message that won't be logged %d", i)
}
}

View File

@ -103,25 +103,18 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
return nil, ErrInitFailed
}
// Initialize HTTP module BEFORE sandbox setup
httpInit := HTTPModuleInitFunc()
if err := httpInit(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Set up sandbox AFTER HTTP module is initialized
if err := runner.sandbox.Setup(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Preload all modules into package.loaded
if err := runner.moduleLoader.PreloadAllModules(state); err != nil {
state.Close()
return nil, errors.New("failed to preload modules")
}
// Set up sandbox
if err := runner.sandbox.Setup(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Run init function if provided
if runner.initFunc != nil {
if err := runner.initFunc(state); err != nil {

View File

@ -0,0 +1,373 @@
package runner
import (
"context"
"testing"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Helper function to create bytecode for testing
func createTestBytecode(t *testing.T, code string) []byte {
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(code, "test")
if err != nil {
t.Fatalf("Failed to compile test bytecode: %v", err)
}
return bytecode
}
func TestRunnerBasic(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
bytecode := createTestBytecode(t, "return 42")
result, err := runner.Run(bytecode, nil, "")
if err != nil {
t.Fatalf("Failed to run script: %v", err)
}
num, ok := result.(float64)
if !ok {
t.Fatalf("Expected float64 result, got %T", result)
}
if num != 42 {
t.Errorf("Expected 42, got %f", num)
}
}
func TestRunnerWithContext(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
bytecode := createTestBytecode(t, `
return {
num = ctx.number,
str = ctx.text,
flag = ctx.enabled,
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
}
`)
execCtx := NewContext()
execCtx.Set("number", 42.5)
execCtx.Set("text", "hello")
execCtx.Set("enabled", true)
execCtx.Set("table", []float64{10, 20, 30})
result, err := runner.Run(bytecode, execCtx, "")
if err != nil {
t.Fatalf("Failed to run job: %v", err)
}
// Result should be a map
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Check values
if resultMap["num"] != 42.5 {
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
}
if resultMap["str"] != "hello" {
t.Errorf("Expected str=hello, got %v", resultMap["str"])
}
if resultMap["flag"] != true {
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
}
arr, ok := resultMap["list"].([]float64)
if !ok {
t.Fatalf("Expected []float64, got %T", resultMap["list"])
}
expected := []float64{10, 20, 30}
for i, v := range expected {
if arr[i] != v {
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
}
}
}
func TestRunnerWithTimeout(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
// Create bytecode that sleeps
bytecode := createTestBytecode(t, `
-- Sleep for 500ms
local start = os.time()
while os.difftime(os.time(), start) < 0.5 do end
return "done"
`)
// Test with timeout that should succeed
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
result, err := runner.RunWithContext(ctx, bytecode, nil, "")
if err != nil {
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
}
if result != "done" {
t.Errorf("Expected 'done', got %v", result)
}
// Test with timeout that should fail
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = runner.RunWithContext(ctx, bytecode, nil, "")
if err == nil {
t.Errorf("Expected timeout error, got nil")
}
}
func TestSandboxIsolation(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
// Create a script that tries to modify a global variable
bytecode1 := createTestBytecode(t, `
-- Set a "global" variable
my_global = "test value"
return true
`)
_, err = runner.Run(bytecode1, nil, "")
if err != nil {
t.Fatalf("Failed to execute first script: %v", err)
}
// Now try to access that variable from another script
bytecode2 := createTestBytecode(t, `
-- Try to access the previously set global
return my_global ~= nil
`)
result, err := runner.Run(bytecode2, nil, "")
if err != nil {
t.Fatalf("Failed to execute second script: %v", err)
}
// The variable should not be accessible (sandbox isolation)
if result.(bool) {
t.Errorf("Expected sandbox isolation, but global variable was accessible")
}
}
func TestRunnerWithInit(t *testing.T) {
// Define an init function that registers a simple "math" module
mathInit := func(state *luajit.State) error {
// Register the "add" function
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a + b)
return 1 // Return one result
})
if err != nil {
return err
}
// Register a whole module
mathFuncs := map[string]luajit.GoFunction{
"multiply": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a * b)
return 1
},
"subtract": func(s *luajit.State) int {
a := s.ToNumber(1)
b := s.ToNumber(2)
s.PushNumber(a - b)
return 1
},
}
return RegisterModule(state, "math2", mathFuncs)
}
// Create a runner with our init function
runner, err := NewRunner(WithInitFunc(mathInit))
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
// Test the add function
bytecode1 := createTestBytecode(t, "return add(5, 7)")
result1, err := runner.Run(bytecode1, nil, "")
if err != nil {
t.Fatalf("Failed to call add function: %v", err)
}
num1, ok := result1.(float64)
if !ok || num1 != 12 {
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
}
// Test the math2 module
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
result2, err := runner.Run(bytecode2, nil, "")
if err != nil {
t.Fatalf("Failed to call math2.multiply: %v", err)
}
num2, ok := result2.(float64)
if !ok || num2 != 48 {
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
}
}
func TestConcurrentExecution(t *testing.T) {
const jobs = 20
runner, err := NewRunner(WithBufferSize(20))
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
// Create bytecode that returns its input
bytecode := createTestBytecode(t, "return ctx.n")
// Run multiple jobs concurrently
results := make(chan int, jobs)
for i := 0; i < jobs; i++ {
i := i // Capture loop variable
go func() {
execCtx := NewContext()
execCtx.Set("n", float64(i))
result, err := runner.Run(bytecode, execCtx, "")
if err != nil {
t.Errorf("Job %d failed: %v", i, err)
results <- -1
return
}
num, ok := result.(float64)
if !ok {
t.Errorf("Job %d: expected float64, got %T", i, result)
results <- -1
return
}
results <- int(num)
}()
}
// Collect results
seen := make(map[int]bool)
for i := 0; i < jobs; i++ {
result := <-results
if result != -1 {
seen[result] = true
}
}
// Verify all jobs were processed
if len(seen) != jobs {
t.Errorf("Expected %d unique results, got %d", jobs, len(seen))
}
}
func TestRunnerClose(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
// Submit a job to verify runner works
bytecode := createTestBytecode(t, "return 42")
_, err = runner.Run(bytecode, nil, "")
if err != nil {
t.Fatalf("Failed to run job: %v", err)
}
// Close
if err := runner.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
// Run after close should fail
_, err = runner.Run(bytecode, nil, "")
if err != ErrRunnerClosed {
t.Errorf("Expected ErrRunnerClosed, got %v", err)
}
// Second close should return error
if err := runner.Close(); err != ErrRunnerClosed {
t.Errorf("Expected ErrRunnerClosed on second close, got %v", err)
}
}
func TestErrorHandling(t *testing.T) {
runner, err := NewRunner()
if err != nil {
t.Fatalf("Failed to create runner: %v", err)
}
defer runner.Close()
// Test invalid bytecode
_, err = runner.Run([]byte("not valid bytecode"), nil, "")
if err == nil {
t.Errorf("Expected error for invalid bytecode, got nil")
}
// Test Lua runtime error
bytecode := createTestBytecode(t, `
error("intentional error")
return true
`)
_, err = runner.Run(bytecode, nil, "")
if err == nil {
t.Errorf("Expected error from Lua error() call, got nil")
}
// Test with nil context
bytecode = createTestBytecode(t, "return ctx == nil")
result, err := runner.Run(bytecode, nil, "")
if err != nil {
t.Errorf("Unexpected error with nil context: %v", err)
}
if result.(bool) != true {
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
}
// Test invalid context value
execCtx := NewContext()
execCtx.Set("param", complex128(1+2i)) // Unsupported type
bytecode = createTestBytecode(t, "return ctx.param")
_, err = runner.Run(bytecode, execCtx, "")
if err == nil {
t.Errorf("Expected error for unsupported context value type, got nil")
}
}

281
core/runner/require_test.go Normal file
View File

@ -0,0 +1,281 @@
package runner_test
import (
"os"
"path/filepath"
"testing"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/runner"
)
func TestRequireFunctionality(t *testing.T) {
// Create temporary directories for test
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create script directory and lib directory
scriptDir := filepath.Join(tempDir, "scripts")
libDir := filepath.Join(tempDir, "libs")
if err := os.Mkdir(scriptDir, 0755); err != nil {
t.Fatalf("Failed to create script directory: %v", err)
}
if err := os.Mkdir(libDir, 0755); err != nil {
t.Fatalf("Failed to create lib directory: %v", err)
}
// Create a module in the lib directory
libModule := `
local lib = {}
function lib.add(a, b)
return a + b
end
function lib.mul(a, b)
return a * b
end
return lib
`
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
t.Fatalf("Failed to write lib module: %v", err)
}
// Create a helper module in the script directory
helperModule := `
local helper = {}
function helper.square(x)
return x * x
end
function helper.cube(x)
return x * x * x
end
return helper
`
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
t.Fatalf("Failed to write helper module: %v", err)
}
// Create main script that requires both modules
mainScript := `
-- Require from the same directory
local helper = require("helper")
-- Require from the lib directory
local mathlib = require("mathlib")
-- Use both modules
local result = {
add = mathlib.add(10, 5),
mul = mathlib.mul(10, 5),
square = helper.square(5),
cube = helper.cube(3)
}
return result
`
mainScriptPath := filepath.Join(scriptDir, "main.lua")
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
t.Fatalf("Failed to write main script: %v", err)
}
// Create LuaRunner
luaRunner, err := runner.NewRunner(
runner.WithScriptDir(scriptDir),
runner.WithLibDirs(libDir),
)
if err != nil {
t.Fatalf("Failed to create LuaRunner: %v", err)
}
defer luaRunner.Close()
// Compile the main script
state := luajit.New()
if state == nil {
t.Fatal("Failed to create Lua state")
}
defer state.Close()
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
if err != nil {
t.Fatalf("Failed to compile script: %v", err)
}
// Run the script
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
if err != nil {
t.Fatalf("Failed to run script: %v", err)
}
// Check result
resultMap, ok := result.(map[string]any)
if !ok {
t.Fatalf("Expected map result, got %T", result)
}
// Validate results
expectedResults := map[string]float64{
"add": 15, // 10 + 5
"mul": 50, // 10 * 5
"square": 25, // 5^2
"cube": 27, // 3^3
}
for key, expected := range expectedResults {
if val, ok := resultMap[key]; !ok {
t.Errorf("Missing result key: %s", key)
} else if val != expected {
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
}
}
}
func TestRequireSecurityBoundaries(t *testing.T) {
// Create temporary directories for test
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create script directory and lib directory
scriptDir := filepath.Join(tempDir, "scripts")
libDir := filepath.Join(tempDir, "libs")
secretDir := filepath.Join(tempDir, "secret")
err = os.MkdirAll(scriptDir, 0755)
if err != nil {
t.Fatalf("Failed to create script directory: %v", err)
}
err = os.MkdirAll(libDir, 0755)
if err != nil {
t.Fatalf("Failed to create lib directory: %v", err)
}
err = os.MkdirAll(secretDir, 0755)
if err != nil {
t.Fatalf("Failed to create secret directory: %v", err)
}
// Create a "secret" module that should not be accessible
secretModule := `
local secret = "TOP SECRET"
return secret
`
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
if err != nil {
t.Fatalf("Failed to write secret module: %v", err)
}
// Create a normal module in lib
normalModule := `return "normal module"`
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
if err != nil {
t.Fatalf("Failed to write normal module: %v", err)
}
// Create a compile-and-run function that takes care of both compilation and execution
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
// Compile
state := luajit.New()
if state == nil {
return nil, nil
}
defer state.Close()
bytecode, err := state.CompileBytecode(scriptText, scriptName)
if err != nil {
return nil, err
}
// Create and configure a new runner each time
r, err := runner.NewRunner(
runner.WithScriptDir(scriptDir),
runner.WithLibDirs(libDir),
)
if err != nil {
return nil, err
}
defer r.Close()
// Run
return r.Run(bytecode, nil, scriptPath)
}
// Test that normal require works
normalScript := `
local normal = require("normal")
return normal
`
normalPath := filepath.Join(scriptDir, "normal_test.lua")
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
if err != nil {
t.Fatalf("Failed to write normal script: %v", err)
}
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
if err != nil {
t.Fatalf("Failed to run normal script: %v", err)
}
if result != "normal module" {
t.Errorf("Expected 'normal module', got %v", result)
}
// Test path traversal attempts
pathTraversalTests := []struct {
name string
script string
}{
{
name: "Direct path traversal",
script: `
-- Try path traversal
local secret = require("../secret/secret")
return secret ~= nil
`,
},
{
name: "Double dot traversal",
script: `
local secret = require("..secret.secret")
return secret ~= nil
`,
},
{
name: "Absolute path traversal",
script: `
local secret = require("` + filepath.Join(secretDir, "secret") + `")
return secret ~= nil
`,
},
}
for _, tt := range pathTraversalTests {
t.Run(tt.name, func(t *testing.T) {
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
if err != nil {
t.Fatalf("Failed to write test script: %v", err)
}
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
// If there's an error, that's expected and good
if err != nil {
return
}
// If no error, then the script should have returned false (couldn't get the module)
if result == true {
t.Errorf("Security breach! Script was able to access restricted module")
}
})
}
}

View File

@ -12,12 +12,10 @@ type Sandbox struct {
// NewSandbox creates a new sandbox
func NewSandbox() *Sandbox {
s := &Sandbox{
return &Sandbox{
modules: make(map[string]any),
initialized: false,
}
return s
}
// AddModule adds a module to the sandbox environment
@ -34,69 +32,63 @@ func (s *Sandbox) Setup(state *luajit.State) error {
// Create high-performance persistent environment
return state.DoString(`
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Global shared environment (created once)
__env_system = __env_system or {
base_env = nil, -- Template environment
initialized = false, -- Initialization flag
env_pool = {}, -- Pre-allocated environment pool
pool_size = 0, -- Current pool size
max_pool_size = 8 -- Maximum pool size
}
-- Initialize base environment once
if not __env_system.initialized then
-- Create base environment with all standard libraries
local base = {}
-- Initialize base environment once
if not __env_system.initialized then
-- Create base environment with all standard libraries
local base = {}
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Safe standard libraries
base.string = string
base.table = table
base.math = math
base.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
-- Basic functions
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.unpack = unpack
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
-- Basic functions
base.tonumber = tonumber
base.tostring = tostring
base.type = type
base.pairs = pairs
base.ipairs = ipairs
base.next = next
base.select = select
base.unpack = unpack
base.pcall = pcall
base.xpcall = xpcall
base.error = error
base.assert = assert
-- Package system is shared for performance
base.package = {
loaded = package.loaded,
path = package.path,
preload = package.preload
}
-- Package system is shared for performance
base.package = {
loaded = package.loaded,
path = package.path,
preload = package.preload
}
-- Add HTTP module explicitly to the base environment
base.http = http
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Add registered custom modules
if __sandbox_modules then
for name, mod in pairs(__sandbox_modules) do
base[name] = mod
end
end
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Global variable for tracking current environment
__last_env = nil
-- Store base environment
__env_system.base_env = base
__env_system.initialized = true
end
-- Fast environment creation with pre-allocation
function __get_sandbox_env(ctx)
@ -109,8 +101,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
-- Clear any previous context
env.ctx = ctx or nil
-- Clear any previous response
env._response = nil
else
-- Create new environment with metatable inheritance
env = setmetatable({}, {
@ -128,9 +118,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
end
end
-- Store reference to current environment
__last_env = env
return env
end
@ -140,7 +127,6 @@ func (s *Sandbox) Setup(state *luajit.State) error {
if __env_system.pool_size < __env_system.max_pool_size then
-- Clear context reference to avoid memory leaks
env.ctx = nil
-- Don't clear response data - we need it for extraction
-- Add to pool
table.insert(__env_system.env_pool, env)
@ -253,14 +239,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
// Get result
result, err := state.ToValue(-1)
state.Pop(1) // Pop result
// Check if HTTP response was set
httpResponse, hasHTTPResponse := GetHTTPResponse(state)
if hasHTTPResponse {
httpResponse.Body = result
return httpResponse, err
}
state.Pop(1)
return result, err
}