213 lines
4.9 KiB
Go
213 lines
4.9 KiB
Go
package runner
|
|
|
|
import (
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// Sentinel errors for Lua module resolution.
var (
	// ErrModuleNotFound is returned by findAndCompileModule when none of the
	// candidate files for a module name exists inside the allowed directories.
	ErrModuleNotFound = errors.New("module not found")

	// ErrPathTraversal describes a module path that escapes the allowed
	// directories. NOTE(review): not returned by the resolution code in this
	// file section (it skips such paths instead); confirm it is used elsewhere.
	ErrPathTraversal = errors.New("path traversal not allowed")
)
|
|
|
|
// RequireConfig tells the secure require implementation which directories it
// may load Lua modules from.
type RequireConfig struct {
	// ScriptDir is the base directory of the script being executed. When
	// non-empty it is the first location searched for a module.
	ScriptDir string

	// LibDirs lists extra library directories, searched in order after
	// ScriptDir. Empty entries are ignored by the module resolver.
	LibDirs []string
}
|
|
|
|
// RequireCache memoizes compiled Lua module bytecode. It is backed by a
// sync.Map, so it may be shared between goroutines.
type RequireCache struct {
	// modules maps a cleaned module file path (string) to its compiled
	// bytecode ([]byte).
	modules sync.Map
}

// NewRequireCache returns a pointer to a fresh, empty RequireCache.
func NewRequireCache() *RequireCache {
	// The zero value of sync.Map is an empty map ready for use, so no explicit
	// field initialization is needed.
	fresh := new(RequireCache)
	return fresh
}
|
|
|
|
// SetupRequire configures the Lua state with a secure require function
|
|
func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig) error {
|
|
// Register the loader function
|
|
err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int {
|
|
// Get module name
|
|
modName := s.ToString(1)
|
|
if modName == "" {
|
|
s.PushString("module name required")
|
|
return -1 // Return error
|
|
}
|
|
|
|
// Try to load the module
|
|
bytecode, err := findAndCompileModule(s, cache, config, modName)
|
|
if err != nil {
|
|
if err == ErrModuleNotFound {
|
|
s.PushString("module '" + modName + "' not found")
|
|
} else {
|
|
s.PushString("error loading module: " + err.Error())
|
|
}
|
|
return -1 // Return error
|
|
}
|
|
|
|
// Load the bytecode
|
|
if err := s.LoadBytecode(bytecode, modName); err != nil {
|
|
s.PushString("error loading bytecode: " + err.Error())
|
|
return -1 // Return error
|
|
}
|
|
|
|
// Return the loaded function
|
|
return 1
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set up the secure require implementation
|
|
setupScript := `
|
|
-- Create a secure require function for sandboxed environments
|
|
function __setup_secure_require(env)
|
|
-- Replace env.require with our secure version
|
|
env.require = function(modname)
|
|
-- Check if already loaded in package.loaded
|
|
if package.loaded[modname] then
|
|
return package.loaded[modname]
|
|
end
|
|
|
|
-- Try to load the module using our Go loader
|
|
local loader = __go_load_module
|
|
|
|
-- Load the module
|
|
local f, err = loader(modname)
|
|
if not f then
|
|
error(err or "failed to load module: " .. modname)
|
|
end
|
|
|
|
-- Set the environment for the module
|
|
setfenv(f, env)
|
|
|
|
-- Execute the module
|
|
local result = f()
|
|
|
|
-- If module didn't return a value, use true
|
|
if result == nil then
|
|
result = true
|
|
end
|
|
|
|
-- Cache the result
|
|
package.loaded[modname] = result
|
|
|
|
return result
|
|
end
|
|
|
|
return env
|
|
end
|
|
`
|
|
|
|
return state.DoString(setupScript)
|
|
}
|
|
|
|
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode
|
|
func findAndCompileModule(
|
|
state *luajit.State,
|
|
cache *RequireCache,
|
|
config RequireConfig,
|
|
modName string,
|
|
) ([]byte, error) {
|
|
// Convert module name to relative path
|
|
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
|
|
|
|
// List of paths to check
|
|
paths := []string{}
|
|
|
|
// 1. Check adjacent to script directory first
|
|
if config.ScriptDir != "" {
|
|
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
|
|
}
|
|
|
|
// 2. Check in lib directories
|
|
for _, libDir := range config.LibDirs {
|
|
if libDir != "" {
|
|
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
|
|
}
|
|
}
|
|
|
|
// Try each path
|
|
for _, path := range paths {
|
|
// Clean the path to handle .. and such (security)
|
|
cleanPath := filepath.Clean(path)
|
|
|
|
// Check for path traversal (extra safety)
|
|
if !isSubPath(config.ScriptDir, cleanPath) {
|
|
isValidLib := false
|
|
for _, libDir := range config.LibDirs {
|
|
if isSubPath(libDir, cleanPath) {
|
|
isValidLib = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isValidLib {
|
|
continue // Skip paths outside allowed directories
|
|
}
|
|
}
|
|
|
|
// Check if already in cache
|
|
if bytecode, ok := cache.modules.Load(cleanPath); ok {
|
|
return bytecode.([]byte), nil
|
|
}
|
|
|
|
// Check if file exists
|
|
_, err := os.Stat(cleanPath)
|
|
if os.IsNotExist(err) {
|
|
continue
|
|
}
|
|
|
|
// Read and compile the file
|
|
content, err := os.ReadFile(cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Compile to bytecode
|
|
bytecode, err := state.CompileBytecode(string(content), cleanPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store in cache
|
|
cache.modules.Store(cleanPath, bytecode)
|
|
|
|
return bytecode, nil
|
|
}
|
|
|
|
return nil, ErrModuleNotFound
|
|
}
|
|
|
|
// isSubPath reports whether path is baseDir itself or lies beneath it.
//
// The comparison is purely lexical: both arguments are cleaned, and symlinks
// are not resolved. An empty baseDir contains nothing. If filepath.Rel cannot
// relate the two paths, path is treated as outside; one absolute and one
// relative path is an example of that case.
func isSubPath(baseDir, path string) bool {
	if len(baseDir) == 0 {
		return false
	}

	rel, relErr := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	if relErr != nil {
		return false
	}

	// Rel yields a cleaned result, so escaping baseDir appears as a leading
	// ".." element: either exactly ".." or "../<more>".
	const parent = ".."
	escapes := rel == parent || strings.HasPrefix(rel, parent+string(filepath.Separator))
	return !escapes
}
|