282 lines
6.8 KiB
Go
282 lines
6.8 KiB
Go
package runner_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
"git.sharkk.net/Sky/Moonshark/core/runner"
|
|
)
|
|
|
|
func TestRequireFunctionality(t *testing.T) {
|
|
// Create temporary directories for test
|
|
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp directory: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create script directory and lib directory
|
|
scriptDir := filepath.Join(tempDir, "scripts")
|
|
libDir := filepath.Join(tempDir, "libs")
|
|
|
|
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
|
t.Fatalf("Failed to create script directory: %v", err)
|
|
}
|
|
if err := os.Mkdir(libDir, 0755); err != nil {
|
|
t.Fatalf("Failed to create lib directory: %v", err)
|
|
}
|
|
|
|
// Create a module in the lib directory
|
|
libModule := `
|
|
local lib = {}
|
|
|
|
function lib.add(a, b)
|
|
return a + b
|
|
end
|
|
|
|
function lib.mul(a, b)
|
|
return a * b
|
|
end
|
|
|
|
return lib
|
|
`
|
|
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
|
|
t.Fatalf("Failed to write lib module: %v", err)
|
|
}
|
|
|
|
// Create a helper module in the script directory
|
|
helperModule := `
|
|
local helper = {}
|
|
|
|
function helper.square(x)
|
|
return x * x
|
|
end
|
|
|
|
function helper.cube(x)
|
|
return x * x * x
|
|
end
|
|
|
|
return helper
|
|
`
|
|
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
|
|
t.Fatalf("Failed to write helper module: %v", err)
|
|
}
|
|
|
|
// Create main script that requires both modules
|
|
mainScript := `
|
|
-- Require from the same directory
|
|
local helper = require("helper")
|
|
|
|
-- Require from the lib directory
|
|
local mathlib = require("mathlib")
|
|
|
|
-- Use both modules
|
|
local result = {
|
|
add = mathlib.add(10, 5),
|
|
mul = mathlib.mul(10, 5),
|
|
square = helper.square(5),
|
|
cube = helper.cube(3)
|
|
}
|
|
|
|
return result
|
|
`
|
|
mainScriptPath := filepath.Join(scriptDir, "main.lua")
|
|
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
|
|
t.Fatalf("Failed to write main script: %v", err)
|
|
}
|
|
|
|
// Create LuaRunner
|
|
luaRunner, err := runner.NewRunner(
|
|
runner.WithScriptDir(scriptDir),
|
|
runner.WithLibDirs(libDir),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create LuaRunner: %v", err)
|
|
}
|
|
defer luaRunner.Close()
|
|
|
|
// Compile the main script
|
|
state := luajit.New()
|
|
if state == nil {
|
|
t.Fatal("Failed to create Lua state")
|
|
}
|
|
defer state.Close()
|
|
|
|
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
|
|
if err != nil {
|
|
t.Fatalf("Failed to compile script: %v", err)
|
|
}
|
|
|
|
// Run the script
|
|
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to run script: %v", err)
|
|
}
|
|
|
|
// Check result
|
|
resultMap, ok := result.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("Expected map result, got %T", result)
|
|
}
|
|
|
|
// Validate results
|
|
expectedResults := map[string]float64{
|
|
"add": 15, // 10 + 5
|
|
"mul": 50, // 10 * 5
|
|
"square": 25, // 5^2
|
|
"cube": 27, // 3^3
|
|
}
|
|
|
|
for key, expected := range expectedResults {
|
|
if val, ok := resultMap[key]; !ok {
|
|
t.Errorf("Missing result key: %s", key)
|
|
} else if val != expected {
|
|
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRequireSecurityBoundaries(t *testing.T) {
|
|
// Create temporary directories for test
|
|
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp directory: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create script directory and lib directory
|
|
scriptDir := filepath.Join(tempDir, "scripts")
|
|
libDir := filepath.Join(tempDir, "libs")
|
|
secretDir := filepath.Join(tempDir, "secret")
|
|
|
|
err = os.MkdirAll(scriptDir, 0755)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create script directory: %v", err)
|
|
}
|
|
err = os.MkdirAll(libDir, 0755)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create lib directory: %v", err)
|
|
}
|
|
err = os.MkdirAll(secretDir, 0755)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create secret directory: %v", err)
|
|
}
|
|
|
|
// Create a "secret" module that should not be accessible
|
|
secretModule := `
|
|
local secret = "TOP SECRET"
|
|
return secret
|
|
`
|
|
err = os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write secret module: %v", err)
|
|
}
|
|
|
|
// Create a normal module in lib
|
|
normalModule := `return "normal module"`
|
|
err = os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write normal module: %v", err)
|
|
}
|
|
|
|
// Create a compile-and-run function that takes care of both compilation and execution
|
|
compileAndRun := func(scriptText, scriptName, scriptPath string) (interface{}, error) {
|
|
// Compile
|
|
state := luajit.New()
|
|
if state == nil {
|
|
return nil, nil
|
|
}
|
|
defer state.Close()
|
|
|
|
bytecode, err := state.CompileBytecode(scriptText, scriptName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create and configure a new runner each time
|
|
r, err := runner.NewRunner(
|
|
runner.WithScriptDir(scriptDir),
|
|
runner.WithLibDirs(libDir),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer r.Close()
|
|
|
|
// Run
|
|
return r.Run(bytecode, nil, scriptPath)
|
|
}
|
|
|
|
// Test that normal require works
|
|
normalScript := `
|
|
local normal = require("normal")
|
|
return normal
|
|
`
|
|
normalPath := filepath.Join(scriptDir, "normal_test.lua")
|
|
err = os.WriteFile(normalPath, []byte(normalScript), 0644)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write normal script: %v", err)
|
|
}
|
|
|
|
result, err := compileAndRun(normalScript, "normal_test.lua", normalPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to run normal script: %v", err)
|
|
}
|
|
|
|
if result != "normal module" {
|
|
t.Errorf("Expected 'normal module', got %v", result)
|
|
}
|
|
|
|
// Test path traversal attempts
|
|
pathTraversalTests := []struct {
|
|
name string
|
|
script string
|
|
}{
|
|
{
|
|
name: "Direct path traversal",
|
|
script: `
|
|
-- Try path traversal
|
|
local secret = require("../secret/secret")
|
|
return secret ~= nil
|
|
`,
|
|
},
|
|
{
|
|
name: "Double dot traversal",
|
|
script: `
|
|
local secret = require("..secret.secret")
|
|
return secret ~= nil
|
|
`,
|
|
},
|
|
{
|
|
name: "Absolute path traversal",
|
|
script: `
|
|
local secret = require("` + filepath.Join(secretDir, "secret") + `")
|
|
return secret ~= nil
|
|
`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range pathTraversalTests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
|
err := os.WriteFile(scriptPath, []byte(tt.script), 0644)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write test script: %v", err)
|
|
}
|
|
|
|
result, err := compileAndRun(tt.script, tt.name+".lua", scriptPath)
|
|
// If there's an error, that's expected and good
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// If no error, then the script should have returned false (couldn't get the module)
|
|
if result == true {
|
|
t.Errorf("Security breach! Script was able to access restricted module")
|
|
}
|
|
})
|
|
}
|
|
}
|