remove test files
This commit is contained in:
parent
7bc5194b10
commit
780533bd76
|
@ -1,351 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoad verifies we can successfully load configuration values from a Lua file.
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
content := `
|
||||
-- Basic configuration values
|
||||
host = "localhost"
|
||||
port = 8080
|
||||
debug = true
|
||||
pi = 3.14159
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify values were loaded correctly
|
||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||
}
|
||||
|
||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||
t.Errorf("Expected port to be 8080, got %d", port)
|
||||
}
|
||||
|
||||
if debug := cfg.GetBool("debug", false); !debug {
|
||||
t.Errorf("Expected debug to be true")
|
||||
}
|
||||
|
||||
if pi := cfg.GetFloat("pi", 0); pi != 3.14159 {
|
||||
t.Errorf("Expected pi to be 3.14159, got %f", pi)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadErrors ensures the package properly handles loading errors.
|
||||
func TestLoadErrors(t *testing.T) {
|
||||
// Test with non-existent file
|
||||
_, err := Load("nonexistent.lua")
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading non-existent file, got nil")
|
||||
}
|
||||
|
||||
// Test with invalid Lua
|
||||
content := `
|
||||
-- This is invalid Lua
|
||||
host = "localhost
|
||||
port = 8080)
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
_, err = Load(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading invalid Lua, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalVsGlobal verifies only global variables are exported, not locals.
|
||||
func TestLocalVsGlobal(t *testing.T) {
|
||||
// Create a temporary config file with both local and global variables
|
||||
content := `
|
||||
-- Local variables should not be exported
|
||||
local local_var = "hidden"
|
||||
|
||||
-- Global variables should be exported
|
||||
global_var = "visible"
|
||||
|
||||
-- A function that uses both
|
||||
function test_func()
|
||||
return local_var .. " " .. global_var
|
||||
end
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Check that global_var exists
|
||||
if globalVar := cfg.GetString("global_var", ""); globalVar != "visible" {
|
||||
t.Errorf("Expected global_var to be 'visible', got '%s'", globalVar)
|
||||
}
|
||||
|
||||
// Check that local_var does not exist
|
||||
if localVar := cfg.GetString("local_var", "default"); localVar != "default" {
|
||||
t.Errorf("Expected local_var to use default, got '%s'", localVar)
|
||||
}
|
||||
|
||||
// Check that functions are not exported
|
||||
if val := cfg.Get("test_func"); val != nil {
|
||||
t.Errorf("Expected function to not be exported, got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// TestArrayHandling verifies correct handling of Lua arrays.
|
||||
func TestArrayHandling(t *testing.T) {
|
||||
// Create a temporary config file with arrays
|
||||
content := `
|
||||
-- Numeric array
|
||||
numbers = {10, 20, 30, 40, 50}
|
||||
|
||||
-- String array
|
||||
strings = {"apple", "banana", "cherry"}
|
||||
|
||||
-- Mixed array
|
||||
mixed = {1, "two", true, 4.5}
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test GetIntArray
|
||||
intArray := cfg.GetIntArray("numbers")
|
||||
expectedInts := []int{10, 20, 30, 40, 50}
|
||||
if !reflect.DeepEqual(intArray, expectedInts) {
|
||||
t.Errorf("Expected int array %v, got %v", expectedInts, intArray)
|
||||
}
|
||||
|
||||
// Test GetStringArray
|
||||
strArray := cfg.GetStringArray("strings")
|
||||
expectedStrs := []string{"apple", "banana", "cherry"}
|
||||
if !reflect.DeepEqual(strArray, expectedStrs) {
|
||||
t.Errorf("Expected string array %v, got %v", expectedStrs, strArray)
|
||||
}
|
||||
|
||||
// Test GetArray with mixed types
|
||||
mixedArray := cfg.GetArray("mixed")
|
||||
if len(mixedArray) != 4 {
|
||||
t.Errorf("Expected mixed array length 4, got %d", len(mixedArray))
|
||||
// Skip further tests if array is empty to avoid panic
|
||||
return
|
||||
}
|
||||
|
||||
// Check types - carefully to avoid panics
|
||||
if len(mixedArray) > 0 {
|
||||
if num, ok := mixedArray[0].(float64); !ok || num != 1 {
|
||||
t.Errorf("Expected first element to be 1, got %v", mixedArray[0])
|
||||
}
|
||||
}
|
||||
|
||||
if len(mixedArray) > 1 {
|
||||
if str, ok := mixedArray[1].(string); !ok || str != "two" {
|
||||
t.Errorf("Expected second element to be 'two', got %v", mixedArray[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestComplexTable tests handling of complex nested tables.
|
||||
func TestComplexTable(t *testing.T) {
|
||||
// Create a temporary config file with complex tables
|
||||
content := `
|
||||
-- Nested table structure
|
||||
server = {
|
||||
host = "localhost",
|
||||
port = 8080,
|
||||
settings = {
|
||||
timeout = 30,
|
||||
retries = 3
|
||||
}
|
||||
}
|
||||
|
||||
-- Table with mixed array and map elements
|
||||
mixed_table = {
|
||||
list = {1, 2, 3},
|
||||
mapping = {
|
||||
a = "apple",
|
||||
b = "banana"
|
||||
}
|
||||
}
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test getting nested values
|
||||
serverMap := cfg.GetMap("server")
|
||||
if serverMap == nil {
|
||||
t.Fatal("Expected server map to exist")
|
||||
}
|
||||
|
||||
// Check first level values
|
||||
if host, ok := serverMap["host"].(string); !ok || host != "localhost" {
|
||||
t.Errorf("Expected server.host to be 'localhost', got %v", serverMap["host"])
|
||||
}
|
||||
|
||||
if port, ok := serverMap["port"].(float64); !ok || port != 8080 {
|
||||
t.Errorf("Expected server.port to be 8080, got %v", serverMap["port"])
|
||||
}
|
||||
|
||||
// Check nested settings
|
||||
settings, ok := serverMap["settings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected server.settings to be a map")
|
||||
}
|
||||
|
||||
if timeout, ok := settings["timeout"].(float64); !ok || timeout != 30 {
|
||||
t.Errorf("Expected server.settings.timeout to be 30, got %v", settings["timeout"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultValues verifies default values work correctly when keys don't exist.
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
content := `
|
||||
-- Just one value
|
||||
existing = "value"
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test defaults for non-existent keys
|
||||
if val := cfg.GetString("nonexistent", "default"); val != "default" {
|
||||
t.Errorf("Expected default string, got '%s'", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetInt("nonexistent", 42); val != 42 {
|
||||
t.Errorf("Expected default int 42, got %d", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetFloat("nonexistent", 3.14); val != 3.14 {
|
||||
t.Errorf("Expected default float 3.14, got %f", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetBool("nonexistent", true); !val {
|
||||
t.Errorf("Expected default bool true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestModifyConfig tests the ability to modify configuration values.
|
||||
func TestModifyConfig(t *testing.T) {
|
||||
// Create a config manually
|
||||
cfg := New()
|
||||
|
||||
// Set some values
|
||||
cfg.Set("host", "localhost")
|
||||
cfg.Set("port", 8080)
|
||||
|
||||
// Verify the values were set
|
||||
if host := cfg.GetString("host", ""); host != "localhost" {
|
||||
t.Errorf("Expected host to be 'localhost', got '%s'", host)
|
||||
}
|
||||
|
||||
if port := cfg.GetInt("port", 0); port != 8080 {
|
||||
t.Errorf("Expected port to be 8080, got %d", port)
|
||||
}
|
||||
|
||||
// Modify a value
|
||||
cfg.Set("host", "127.0.0.1")
|
||||
|
||||
// Verify the change
|
||||
if host := cfg.GetString("host", ""); host != "127.0.0.1" {
|
||||
t.Errorf("Expected modified host to be '127.0.0.1', got '%s'", host)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTypeConversion tests the type conversion in getter methods.
|
||||
func TestTypeConversion(t *testing.T) {
|
||||
// Create a temporary config file with values that need conversion
|
||||
content := `
|
||||
-- Numbers that can be integers
|
||||
int_as_float = 42.0
|
||||
|
||||
-- Floats that should remain floats
|
||||
float_val = 3.14159
|
||||
`
|
||||
configFile := createTempLuaFile(t, content)
|
||||
defer os.Remove(configFile)
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test GetInt with a float value
|
||||
if val := cfg.GetInt("int_as_float", 0); val != 42 {
|
||||
t.Errorf("Expected int 42, got %d", val)
|
||||
}
|
||||
|
||||
// Test GetFloat with an int value
|
||||
if val := cfg.GetFloat("int_as_float", 0); val != 42.0 {
|
||||
t.Errorf("Expected float 42.0, got %f", val)
|
||||
}
|
||||
|
||||
// Test incorrect type handling
|
||||
cfg.Set("string_val", "not a number")
|
||||
|
||||
if val := cfg.GetInt("string_val", 99); val != 99 {
|
||||
t.Errorf("Expected default int 99 for string value, got %d", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetFloat("string_val", 99.9); val != 99.9 {
|
||||
t.Errorf("Expected default float 99.9 for string value, got %f", val)
|
||||
}
|
||||
|
||||
if val := cfg.GetBool("float_val", false); val != false {
|
||||
t.Errorf("Expected default false for non-bool value, got true")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a temporary Lua file with content
|
||||
func createTempLuaFile(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
|
||||
tempFile, err := os.CreateTemp("", "config-test-*.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
if _, err := tempFile.WriteString(content); err != nil {
|
||||
os.Remove(tempFile.Name())
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
os.Remove(tempFile.Name())
|
||||
t.Fatalf("Failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
return tempFile.Name()
|
||||
}
|
|
@ -1,214 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoggerLevels(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Debug should be below threshold
|
||||
logger.Debug("This should not appear")
|
||||
if buf.Len() > 0 {
|
||||
t.Error("Debug message appeared when it should be filtered")
|
||||
}
|
||||
|
||||
// Info and above should appear
|
||||
logger.Info("Info message")
|
||||
if !strings.Contains(buf.String(), "INFO") {
|
||||
t.Errorf("Info message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
logger.Warning("Warning message")
|
||||
if !strings.Contains(buf.String(), "WARN") {
|
||||
t.Errorf("Warning message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
logger.Error("Error message")
|
||||
if !strings.Contains(buf.String(), "ERROR") {
|
||||
t.Errorf("Error message not logged, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test format strings
|
||||
logger.Info("Count: %d", 42)
|
||||
if !strings.Contains(buf.String(), "Count: 42") {
|
||||
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test changing level
|
||||
logger.SetLevel(LevelError)
|
||||
logger.Info("This should not appear")
|
||||
logger.Warning("This should not appear")
|
||||
if buf.Len() > 0 {
|
||||
t.Error("Messages below threshold appeared")
|
||||
}
|
||||
|
||||
logger.Error("Error should appear")
|
||||
if !strings.Contains(buf.String(), "ERROR") {
|
||||
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerRateLimit(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelDebug, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Override max logs per second to something small for testing
|
||||
logger.maxLogsPerSec = 5
|
||||
logger.limitDuration = 1 * time.Second
|
||||
|
||||
// Send debug messages (should get limited)
|
||||
for i := 0; i < 20; i++ {
|
||||
logger.Debug("Debug message %d", i)
|
||||
}
|
||||
|
||||
// Error messages should always go through
|
||||
logger.Error("Error message should appear")
|
||||
|
||||
content := buf.String()
|
||||
|
||||
// We should see some debug messages, then a warning about rate limiting,
|
||||
// and finally the error message
|
||||
if !strings.Contains(content, "Debug message 0") {
|
||||
t.Error("First debug message should appear")
|
||||
}
|
||||
|
||||
if !strings.Contains(content, "Rate limiting logger") {
|
||||
t.Error("Rate limiting message should appear")
|
||||
}
|
||||
|
||||
if !strings.Contains(content, "ERROR") {
|
||||
t.Error("Error message should always appear despite rate limiting")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerConcurrency(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelDebug, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Increase log threshold for this test
|
||||
logger.maxLogsPerSec = 1000
|
||||
|
||||
// Log a bunch of messages concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
logger.Info("Concurrent message %d", n)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Check logs were processed
|
||||
content := buf.String()
|
||||
if !strings.Contains(content, "Concurrent message") {
|
||||
t.Error("Concurrent messages should appear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerColors(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, true)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Test with color
|
||||
logger.Info("Colored message")
|
||||
|
||||
content := buf.String()
|
||||
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
||||
if !strings.Contains(content, "\033[") {
|
||||
t.Errorf("Color codes not present when enabled, got: %q", content)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
logger.DisableColors()
|
||||
logger.Info("Non-colored message")
|
||||
|
||||
content = buf.String()
|
||||
if strings.Contains(content, "\033[") {
|
||||
t.Errorf("Color codes present when disabled, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
SetOutput(&buf)
|
||||
|
||||
Info("Test default logger")
|
||||
|
||||
content := buf.String()
|
||||
if !strings.Contains(content, "INFO") {
|
||||
t.Errorf("Default logger not working, got: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelInfo, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Set very high threshold to avoid rate limiting during benchmark
|
||||
logger.maxLogsPerSec = int64(b.N + 1)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Info("Benchmark message %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoggerWithRateLimit(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelDebug, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Set threshold to allow about 10% of messages through
|
||||
logger.maxLogsPerSec = int64(b.N / 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Debug("Benchmark message %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoggerParallel(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelDebug, false)
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
// Set very high threshold to avoid rate limiting during benchmark
|
||||
logger.maxLogsPerSec = int64(b.N + 1)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
logger.Debug("Parallel benchmark message %d", i)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkProductionLevels(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
logger := New(LevelWarning, false) // Only log warnings and above
|
||||
logger.SetOutput(&buf)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// This should be filtered out before any processing
|
||||
logger.Debug("Debug message that won't be logged %d", i)
|
||||
}
|
||||
}
|
|
@ -1,373 +0,0 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Helper function to create bytecode for testing
|
||||
func createTestBytecode(t *testing.T, code string) []byte {
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(code, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile test bytecode: %v", err)
|
||||
}
|
||||
|
||||
return bytecode
|
||||
}
|
||||
|
||||
func TestRunnerBasic(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
|
||||
result, err := runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run script: %v", err)
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 result, got %T", result)
|
||||
}
|
||||
|
||||
if num != 42 {
|
||||
t.Errorf("Expected 42, got %f", num)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerWithContext(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
bytecode := createTestBytecode(t, `
|
||||
return {
|
||||
num = ctx.number,
|
||||
str = ctx.text,
|
||||
flag = ctx.enabled,
|
||||
list = {ctx.table[1], ctx.table[2], ctx.table[3]},
|
||||
}
|
||||
`)
|
||||
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("number", 42.5)
|
||||
execCtx.Set("text", "hello")
|
||||
execCtx.Set("enabled", true)
|
||||
execCtx.Set("table", []float64{10, 20, 30})
|
||||
|
||||
result, err := runner.Run(bytecode, execCtx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run job: %v", err)
|
||||
}
|
||||
|
||||
// Result should be a map
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Check values
|
||||
if resultMap["num"] != 42.5 {
|
||||
t.Errorf("Expected num=42.5, got %v", resultMap["num"])
|
||||
}
|
||||
if resultMap["str"] != "hello" {
|
||||
t.Errorf("Expected str=hello, got %v", resultMap["str"])
|
||||
}
|
||||
if resultMap["flag"] != true {
|
||||
t.Errorf("Expected flag=true, got %v", resultMap["flag"])
|
||||
}
|
||||
|
||||
arr, ok := resultMap["list"].([]float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []float64, got %T", resultMap["list"])
|
||||
}
|
||||
|
||||
expected := []float64{10, 20, 30}
|
||||
for i, v := range expected {
|
||||
if arr[i] != v {
|
||||
t.Errorf("Expected list[%d]=%f, got %f", i, v, arr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerWithTimeout(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
// Create bytecode that sleeps
|
||||
bytecode := createTestBytecode(t, `
|
||||
-- Sleep for 500ms
|
||||
local start = os.time()
|
||||
while os.difftime(os.time(), start) < 0.5 do end
|
||||
return "done"
|
||||
`)
|
||||
|
||||
// Test with timeout that should succeed
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := runner.RunWithContext(ctx, bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
||||
}
|
||||
if result != "done" {
|
||||
t.Errorf("Expected 'done', got %v", result)
|
||||
}
|
||||
|
||||
// Test with timeout that should fail
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = runner.RunWithContext(ctx, bytecode, nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxIsolation(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
// Create a script that tries to modify a global variable
|
||||
bytecode1 := createTestBytecode(t, `
|
||||
-- Set a "global" variable
|
||||
my_global = "test value"
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = runner.Run(bytecode1, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute first script: %v", err)
|
||||
}
|
||||
|
||||
// Now try to access that variable from another script
|
||||
bytecode2 := createTestBytecode(t, `
|
||||
-- Try to access the previously set global
|
||||
return my_global ~= nil
|
||||
`)
|
||||
|
||||
result, err := runner.Run(bytecode2, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute second script: %v", err)
|
||||
}
|
||||
|
||||
// The variable should not be accessible (sandbox isolation)
|
||||
if result.(bool) {
|
||||
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerWithInit(t *testing.T) {
|
||||
// Define an init function that registers a simple "math" module
|
||||
mathInit := func(state *luajit.State) error {
|
||||
// Register the "add" function
|
||||
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a + b)
|
||||
return 1 // Return one result
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a whole module
|
||||
mathFuncs := map[string]luajit.GoFunction{
|
||||
"multiply": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a * b)
|
||||
return 1
|
||||
},
|
||||
"subtract": func(s *luajit.State) int {
|
||||
a := s.ToNumber(1)
|
||||
b := s.ToNumber(2)
|
||||
s.PushNumber(a - b)
|
||||
return 1
|
||||
},
|
||||
}
|
||||
|
||||
return RegisterModule(state, "math2", mathFuncs)
|
||||
}
|
||||
|
||||
// Create a runner with our init function
|
||||
runner, err := NewRunner(WithInitFunc(mathInit))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
// Test the add function
|
||||
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||
result1, err := runner.Run(bytecode1, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call add function: %v", err)
|
||||
}
|
||||
|
||||
num1, ok := result1.(float64)
|
||||
if !ok || num1 != 12 {
|
||||
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
||||
}
|
||||
|
||||
// Test the math2 module
|
||||
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||
result2, err := runner.Run(bytecode2, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||
}
|
||||
|
||||
num2, ok := result2.(float64)
|
||||
if !ok || num2 != 48 {
|
||||
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentExecution(t *testing.T) {
|
||||
const jobs = 20
|
||||
|
||||
runner, err := NewRunner(WithBufferSize(20))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
// Create bytecode that returns its input
|
||||
bytecode := createTestBytecode(t, "return ctx.n")
|
||||
|
||||
// Run multiple jobs concurrently
|
||||
results := make(chan int, jobs)
|
||||
for i := 0; i < jobs; i++ {
|
||||
i := i // Capture loop variable
|
||||
go func() {
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("n", float64(i))
|
||||
|
||||
result, err := runner.Run(bytecode, execCtx, "")
|
||||
if err != nil {
|
||||
t.Errorf("Job %d failed: %v", i, err)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
num, ok := result.(float64)
|
||||
if !ok {
|
||||
t.Errorf("Job %d: expected float64, got %T", i, result)
|
||||
results <- -1
|
||||
return
|
||||
}
|
||||
|
||||
results <- int(num)
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
seen := make(map[int]bool)
|
||||
for i := 0; i < jobs; i++ {
|
||||
result := <-results
|
||||
if result != -1 {
|
||||
seen[result] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all jobs were processed
|
||||
if len(seen) != jobs {
|
||||
t.Errorf("Expected %d unique results, got %d", jobs, len(seen))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerClose(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
|
||||
// Submit a job to verify runner works
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run job: %v", err)
|
||||
}
|
||||
|
||||
// Close
|
||||
if err := runner.Close(); err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Run after close should fail
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err != ErrRunnerClosed {
|
||||
t.Errorf("Expected ErrRunnerClosed, got %v", err)
|
||||
}
|
||||
|
||||
// Second close should return error
|
||||
if err := runner.Close(); err != ErrRunnerClosed {
|
||||
t.Errorf("Expected ErrRunnerClosed on second close, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
runner, err := NewRunner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create runner: %v", err)
|
||||
}
|
||||
defer runner.Close()
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = runner.Run([]byte("not valid bytecode"), nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
|
||||
// Test Lua runtime error
|
||||
bytecode := createTestBytecode(t, `
|
||||
error("intentional error")
|
||||
return true
|
||||
`)
|
||||
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Lua error() call, got nil")
|
||||
}
|
||||
|
||||
// Test with nil context
|
||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||
result, err := runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error with nil context: %v", err)
|
||||
}
|
||||
if result.(bool) != true {
|
||||
t.Errorf("Expected ctx to be nil in Lua, but it wasn't")
|
||||
}
|
||||
|
||||
// Test invalid context value
|
||||
execCtx := NewContext()
|
||||
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
||||
|
||||
bytecode = createTestBytecode(t, "return ctx.param")
|
||||
_, err = runner.Run(bytecode, execCtx, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for unsupported context value type, got nil")
|
||||
}
|
||||
}
|
|
@ -1,281 +0,0 @@
|
|||
package runner_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/runner"
|
||||
)
|
||||
|
||||
func TestRequireFunctionality(t *testing.T) {
|
||||
// Create temporary directories for test
|
||||
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create script directory and lib directory
|
||||
scriptDir := filepath.Join(tempDir, "scripts")
|
||||
libDir := filepath.Join(tempDir, "libs")
|
||||
|
||||
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create script directory: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(libDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create lib directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a module in the lib directory
|
||||
libModule := `
|
||||
local lib = {}
|
||||
|
||||
function lib.add(a, b)
|
||||
return a + b
|
||||
end
|
||||
|
||||
function lib.mul(a, b)
|
||||
return a * b
|
||||
end
|
||||
|
||||
return lib
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write lib module: %v", err)
|
||||
}
|
||||
|
||||
// Create a helper module in the script directory
|
||||
helperModule := `
|
||||
local helper = {}
|
||||
|
||||
function helper.square(x)
|
||||
return x * x
|
||||
end
|
||||
|
||||
function helper.cube(x)
|
||||
return x * x * x
|
||||
end
|
||||
|
||||
return helper
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write helper module: %v", err)
|
||||
}
|
||||
|
||||
// Create main script that requires both modules
|
||||
mainScript := `
|
||||
-- Require from the same directory
|
||||
local helper = require("helper")
|
||||
|
||||
-- Require from the lib directory
|
||||
local mathlib = require("mathlib")
|
||||
|
||||
-- Use both modules
|
||||
local result = {
|
||||
add = mathlib.add(10, 5),
|
||||
mul = mathlib.mul(10, 5),
|
||||
square = helper.square(5),
|
||||
cube = helper.cube(3)
|
||||
}
|
||||
|
||||
return result
|
||||
`
|
||||
mainScriptPath := filepath.Join(scriptDir, "main.lua")
|
||||
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
|
||||
t.Fatalf("Failed to write main script: %v", err)
|
||||
}
|
||||
|
||||
// Create LuaRunner
|
||||
luaRunner, err := runner.NewRunner(
|
||||
runner.WithScriptDir(scriptDir),
|
||||
runner.WithLibDirs(libDir),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create LuaRunner: %v", err)
|
||||
}
|
||||
defer luaRunner.Close()
|
||||
|
||||
// Compile the main script
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile script: %v", err)
|
||||
}
|
||||
|
||||
// Run the script
|
||||
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run script: %v", err)
|
||||
}
|
||||
|
||||
// Check result
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Validate results
|
||||
expectedResults := map[string]float64{
|
||||
"add": 15, // 10 + 5
|
||||
"mul": 50, // 10 * 5
|
||||
"square": 25, // 5^2
|
||||
"cube": 27, // 3^3
|
||||
}
|
||||
|
||||
for key, expected := range expectedResults {
|
||||
if val, ok := resultMap[key]; !ok {
|
||||
t.Errorf("Missing result key: %s", key)
|
||||
} else if val != expected {
|
||||
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSecurityBoundaries(t *testing.T) {
|
||||
// Create temporary directories for test
|
||||
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create script directory and lib directory
|
||||
scriptDir := filepath.Join(tempDir, "scripts")
|
||||
libDir := filepath.Join(tempDir, "libs")
|
||||
secretDir := filepath.Join(tempDir, "secret")
|
||||
|
||||
err = os.MkdirAll(scriptDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create script directory: %v", err)
|
||||
}
|
||||
err = os.MkdirAll(libDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create lib directory: %v", err)
|
||||
}
|
||||
err = os.MkdirAll(secretDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create secret directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a "secret" module that should not be accessible
|
||||
secretModule := `
|
||||
local secret = "TOP SECRET"
|
||||
return secret
|
||||
`
|
||||
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write secret module: %v", err)
|
||||
}
|
||||
|
||||
// Create a normal module in lib
|
||||
normalModule := `return "normal module"`
|
||||
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write normal module: %v", err)
|
||||
}
|
||||
|
||||
// Create a compile-and-run function that takes care of both compilation and execution
|
||||
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
|
||||
// Compile
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(scriptText, scriptName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create and configure a new runner each time
|
||||
r, err := runner.NewRunner(
|
||||
runner.WithScriptDir(scriptDir),
|
||||
runner.WithLibDirs(libDir),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// Run
|
||||
return r.Run(bytecode, nil, scriptPath)
|
||||
}
|
||||
|
||||
// Test that normal require works
|
||||
normalScript := `
|
||||
local normal = require("normal")
|
||||
return normal
|
||||
`
|
||||
normalPath := filepath.Join(scriptDir, "normal_test.lua")
|
||||
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write normal script: %v", err)
|
||||
}
|
||||
|
||||
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run normal script: %v", err)
|
||||
}
|
||||
|
||||
if result != "normal module" {
|
||||
t.Errorf("Expected 'normal module', got %v", result)
|
||||
}
|
||||
|
||||
// Test path traversal attempts
|
||||
pathTraversalTests := []struct {
|
||||
name string
|
||||
script string
|
||||
}{
|
||||
{
|
||||
name: "Direct path traversal",
|
||||
script: `
|
||||
-- Try path traversal
|
||||
local secret = require("../secret/secret")
|
||||
return secret ~= nil
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Double dot traversal",
|
||||
script: `
|
||||
local secret = require("..secret.secret")
|
||||
return secret ~= nil
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Absolute path traversal",
|
||||
script: `
|
||||
local secret = require("` + filepath.Join(secretDir, "secret") + `")
|
||||
return secret ~= nil
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range pathTraversalTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
||||
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test script: %v", err)
|
||||
}
|
||||
|
||||
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
|
||||
// If there's an error, that's expected and good
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// If no error, then the script should have returned false (couldn't get the module)
|
||||
if result == true {
|
||||
t.Errorf("Security breach! Script was able to access restricted module")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user