Moonshark/core/runner/require.go
2025-03-19 16:50:39 -05:00

213 lines
4.9 KiB
Go

package runner
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Sentinel errors for module resolution.
var (
	// ErrModuleNotFound reports that no allowed search directory holds a file
	// for the requested module name.
	ErrModuleNotFound = errors.New("module not found")

	// ErrPathTraversal describes a module path that escapes the allowed
	// directories.
	// NOTE(review): findAndCompileModule skips such candidates silently rather
	// than returning this error — confirm whether other callers use it.
	ErrPathTraversal = errors.New("path traversal not allowed")
)
// RequireConfig tells the Lua require implementation where it may look for
// module files.
type RequireConfig struct {
	// ScriptDir is the base directory of the script being executed. When
	// non-empty it is searched before any library directory.
	ScriptDir string

	// LibDirs lists additional library directories, searched in order after
	// ScriptDir. Empty entries are ignored by the module search.
	LibDirs []string
}
// RequireCache holds compiled Lua module bytecode keyed by cleaned file path.
// Its storage is a sync.Map, so a single cache may be shared by goroutines.
type RequireCache struct {
	// modules maps a cleaned module file path (string) to its compiled
	// bytecode ([]byte).
	modules sync.Map
}

// NewRequireCache returns a fresh, empty cache.
func NewRequireCache() *RequireCache {
	// A zero sync.Map is an empty map that is ready for use, so the zero
	// RequireCache needs no further initialization.
	return new(RequireCache)
}
// SetupRequire installs a sandbox-aware module loader into state.
//
// It first registers the Go function __go_load_module. That function:
//   - takes a module name as its first argument;
//   - resolves it with findAndCompileModule, which searches config.ScriptDir
//     and then config.LibDirs and caches compiled bytecode in cache;
//   - loads the bytecode with LoadBytecode and returns one value to Lua.
//
// On any failure it pushes an error message string and returns -1.
// NOTE(review): presumably LuaJIT-to-Go turns a negative return into a raised
// Lua error — confirm against the library.
//
// It then executes a Lua snippet that defines the global function
// __setup_secure_require(env). That function replaces env.require with a
// version that:
//   - returns package.loaded[modname] if present;
//   - otherwise loads the module through __go_load_module;
//   - runs it with env as its environment (setfenv);
//   - stores the result in package.loaded, using true when the module
//     returned nil.
//
// package.loaded is resolved from the setup snippet's globals, not from env,
// so loaded modules are shared by every env passed to __setup_secure_require.
//
// The returned error comes from registering the Go function or from running
// the setup snippet.
func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig) error {
	err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int {
		modName := s.ToString(1)
		if modName == "" {
			s.PushString("module name required")
			return -1
		}

		bytecode, err := findAndCompileModule(s, cache, config, modName)
		if err != nil {
			// errors.Is keeps the "not found" message correct even if the
			// sentinel is ever wrapped with additional context.
			if errors.Is(err, ErrModuleNotFound) {
				s.PushString("module '" + modName + "' not found")
			} else {
				s.PushString("error loading module: " + err.Error())
			}
			return -1
		}

		if err := s.LoadBytecode(bytecode, modName); err != nil {
			s.PushString("error loading bytecode: " + err.Error())
			return -1
		}

		// LoadBytecode is expected to leave the compiled chunk on the stack;
		// hand it back to Lua as the single return value.
		return 1
	})
	if err != nil {
		return err
	}

	// Lua side: defines __setup_secure_require, which swaps env.require for a
	// loader backed by __go_load_module.
	setupScript := `
-- Create a secure require function for sandboxed environments
function __setup_secure_require(env)
-- Replace env.require with our secure version
env.require = function(modname)
-- Check if already loaded in package.loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Try to load the module using our Go loader
local loader = __go_load_module
-- Load the module
local f, err = loader(modname)
if not f then
error(err or "failed to load module: " .. modname)
end
-- Set the environment for the module
setfenv(f, env)
-- Execute the module
local result = f()
-- If module didn't return a value, use true
if result == nil then
result = true
end
-- Cache the result
package.loaded[modname] = result
return result
end
return env
end
`
	return state.DoString(setupScript)
}
// findAndCompileModule resolves modName against the allowed search directories
// and returns the module's compiled bytecode.
//
// Dots in modName act as directory separators, so "a.b" is looked up as
// "a/b.lua" (using the OS separator). Candidates are tried in this order:
//   - relative to config.ScriptDir (if non-empty);
//   - relative to each non-empty entry of config.LibDirs.
//
// A candidate is skipped when:
//   - its cleaned path does not lie inside ScriptDir or one of the LibDirs;
//   - it does not exist;
//   - it is a directory.
//
// Compiled bytecode is cached in cache by cleaned path. Entries are never
// invalidated here, so later edits to a module file are not picked up.
//
// It returns ErrModuleNotFound when no candidate matches. Otherwise it returns
// the stat, read or compile error of the first candidate that exists but
// cannot be loaded.
func findAndCompileModule(
	state *luajit.State,
	cache *RequireCache,
	config RequireConfig,
	modName string,
) ([]byte, error) {
	modFile := strings.ReplaceAll(modName, ".", string(filepath.Separator)) + ".lua"

	// Build the ordered candidate list: the script directory first, then libs.
	paths := make([]string, 0, len(config.LibDirs)+1)
	if config.ScriptDir != "" {
		paths = append(paths, filepath.Join(config.ScriptDir, modFile))
	}
	for _, libDir := range config.LibDirs {
		if libDir != "" {
			paths = append(paths, filepath.Join(libDir, modFile))
		}
	}

	for _, path := range paths {
		cleanPath := filepath.Clean(path)

		// Defense in depth: only accept candidates that stay inside an allowed
		// directory, whatever the module name contained.
		allowed := isSubPath(config.ScriptDir, cleanPath)
		for i := 0; !allowed && i < len(config.LibDirs); i++ {
			allowed = isSubPath(config.LibDirs[i], cleanPath)
		}
		if !allowed {
			continue
		}

		if cached, ok := cache.modules.Load(cleanPath); ok {
			return cached.([]byte), nil
		}

		info, err := os.Stat(cleanPath)
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				continue
			}
			return nil, err
		}
		if info.IsDir() {
			// A directory that happens to be named like a module file is not
			// a module; keep searching the remaining directories.
			continue
		}

		content, err := os.ReadFile(cleanPath)
		if err != nil {
			return nil, err
		}

		bytecode, err := state.CompileBytecode(string(content), cleanPath)
		if err != nil {
			return nil, err
		}

		cache.modules.Store(cleanPath, bytecode)
		return bytecode, nil
	}

	return nil, ErrModuleNotFound
}
// isSubPath reports whether path lies inside baseDir, or is baseDir itself.
//
// An empty baseDir contains nothing. The two paths are related with
// filepath.Rel after cleaning. If they cannot be related (for example, one is
// relative and the other absolute), the path is treated as outside.
func isSubPath(baseDir, path string) bool {
	if len(baseDir) == 0 {
		return false
	}

	rel, err := filepath.Rel(filepath.Clean(baseDir), filepath.Clean(path))
	switch {
	case err != nil:
		return false
	case rel == "..":
		return false
	case strings.HasPrefix(rel, ".."+string(filepath.Separator)):
		// Climbs above baseDir. Names that merely start with ".." (such as
		// "..hidden.lua") are still inside and fall through to the default.
		return false
	default:
		return true
	}
}