Moonshark/runner/moduleLoader.go

293 lines
6.5 KiB
Go

package runner
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"Moonshark/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ModuleConfig lists the directories a ModuleLoader resolves Lua modules from.
type ModuleConfig struct {
// ScriptDir is placed first in package.path by SetupRequire. It is not
// scanned by PreloadModules, so its modules are found only via path search.
ScriptDir string
// LibDirs are scanned recursively by PreloadModules (each .lua file becomes a
// package.preload entry) and are appended to package.path after ScriptDir.
// Empty entries are skipped when building search paths and when scanning.
LibDirs []string
}
// ModuleLoader precompiles Lua modules into a luajit.State, sets package.path,
// supports hot-reloading individual modules, and defines the Lua global
// __setup_require(env), which installs a require function into env.
type ModuleLoader struct {
// config supplies the script and library directories. NewModuleLoader stores
// the caller's pointer, so SetScriptDir also mutates the caller's struct.
config *ModuleConfig
// pathCache maps module name -> absolute file path for modules found by the
// last PreloadModules scan. GetModuleByPath scans its values for reverse
// (path -> name) lookups.
pathCache map[string]string
// debug enables debugLog output; it is written by EnableDebug without l.mu.
debug bool
// mu guards pathCache (written by PreloadModules, read by GetModuleByPath)
// and config mutations made through SetScriptDir.
mu sync.RWMutex
}
// NewModuleLoader returns a loader bound to config. A nil config is replaced
// by an empty one (no script directory, no library directories). The pointer
// is retained, not copied, so later changes to *config are visible to the loader.
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
	cfg := config
	if cfg == nil {
		cfg = new(ModuleConfig)
	}

	loader := &ModuleLoader{config: cfg}
	loader.pathCache = map[string]string{}
	return loader
}
// EnableDebug turns on debugLog output for this loader.
// The flag is written without taking l.mu.
func (l *ModuleLoader) EnableDebug() {
	const on = true
	l.debug = on
}
// SetScriptDir replaces the directory that SetupRequire puts first in
// package.path. It writes through the shared *ModuleConfig under the write
// lock; an existing Lua state only sees the change after SetupRequire runs again.
func (l *ModuleLoader) SetScriptDir(dir string) {
	l.mu.Lock()
	defer l.mu.Unlock()

	cfg := l.config
	cfg.ScriptDir = dir
}
// debugLog forwards a "ModuleLoader "-prefixed message to logger.Debugf when
// debug output has been enabled; otherwise it is a no-op.
func (l *ModuleLoader) debugLog(format string, args ...any) {
	if !l.debug {
		return
	}
	logger.Debugf("ModuleLoader "+format, args...)
}
// SetupRequire assigns package.path in state from the loader's search paths:
// ScriptDir first, then each library directory, each as "<abs dir>/?.lua",
// joined with ";". It returns any error from executing the assignment.
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
	// SetScriptDir mutates config.ScriptDir under the write lock, so read the
	// configuration under the read lock. The lock is released before calling
	// into Lua.
	l.mu.RLock()
	paths := l.getSearchPaths()
	l.mu.RUnlock()

	pathStr := strings.Join(paths, ";")
	return state.DoString(`package.path = "` + escapeLuaString(pathStr) + `"`)
}
// getSearchPaths builds the package.path templates ("<abs dir>/?.lua"):
// ScriptDir first (if set), then each non-empty LibDir in order. Directories
// whose absolute path cannot be resolved, and repeats of an already-added
// absolute directory, are skipped. Returns nil when nothing is configured.
// It reads l.config without locking.
func (l *ModuleLoader) getSearchPaths() []string {
	var result []string
	visited := map[string]bool{}

	add := func(dir string) {
		abs, err := filepath.Abs(dir)
		if err != nil || visited[abs] {
			return
		}
		visited[abs] = true
		result = append(result, filepath.Join(abs, "?.lua"))
	}

	if l.config.ScriptDir != "" {
		add(l.config.ScriptDir)
	}
	for _, dir := range l.config.LibDirs {
		if dir != "" {
			add(dir)
		}
	}
	return result
}
// PreloadModules resets module state in the given Lua state and reloads it.
//
// Steps:
//   - clears pathCache;
//   - unloads every package.loaded entry except the standard and LuaJIT
//     built-in libraries, and replaces package.preload with an empty table;
//   - precompiles every .lua file under the configured library directories
//     into package.preload;
//   - defines the Lua global __setup_require(env), which installs env.require
//     and returns env.
//
// The installed require checks package.loaded, then package.preload, then
// package.path. Every loader it runs gets env as its environment via setfenv.
// A file that exists on the path but fails to compile raises that compile
// error instead of being reported as "not found".
func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	// Forget previously discovered modules; scanDirectory repopulates this.
	l.pathCache = make(map[string]string)

	// Unload non-core modules. bit, jit and jit.opt are LuaJIT built-ins that
	// scripts obtain via require("bit") etc.; unloading them would make that
	// require fail even though nothing can reload them.
	// NOTE(review): ffi is deliberately not kept here (wiping preload also drops
	// its loader); confirm whether blocking require("ffi") is intended.
	err := state.DoString(`
		local core = {
			string=1, table=1, math=1, os=1, package=1, io=1, coroutine=1,
			debug=1, _G=1, bit=1, jit=1, ["jit.opt"]=1,
		}
		for name in pairs(package.loaded) do
			if not core[name] then package.loaded[name] = nil end
		end
		package.preload = {}
	`)
	if err != nil {
		return err
	}

	for _, dir := range l.config.LibDirs {
		if err := l.scanDirectory(state, dir); err != nil {
			return err
		}
	}

	return state.DoString(`
		function __setup_require(env)
			env.require = function(modname)
				if package.loaded[modname] then
					return package.loaded[modname]
				end

				local loader = package.preload[modname]
				if loader then
					setfenv(loader, env)
					local result = loader() or true
					package.loaded[modname] = result
					return result
				end

				-- Standard path search. gsub returns (string, count); the outer
				-- parentheses keep only the string so the count is never passed
				-- on as another gsub's replacement limit.
				local relpath = (modname:gsub("%.", "/"))
				-- "%" is special in gsub replacement strings; escape it.
				local repl = (relpath:gsub("%%", "%%%%"))
				for path in package.path:gmatch("[^;]+") do
					local file = (path:gsub("%?", repl))
					local chunk, err = loadfile(file)
					if chunk then
						setfenv(chunk, env)
						local result = chunk() or true
						package.loaded[modname] = result
						return result
					elseif err and not err:find("^cannot open") then
						-- The file exists but could not be loaded (e.g. syntax
						-- error): report it rather than claiming it is missing.
						error("error loading module '" .. modname .. "' from file '" ..
							file .. "':\n\t" .. err, 2)
					end
				end

				error("module '" .. modname .. "' not found", 2)
			end
			return env
		end
	`)
}
// scanDirectory recursively finds .lua files under dir. It records each one in
// pathCache and stores its compiled chunk in package.preload. The module name
// is the file's path relative to dir with ".lua" removed and separators turned
// into dots ("a/b.lua" -> "a.b").
//
// Scanning is best-effort: an empty or unresolvable dir, unreadable entries,
// and files that fail to read or compile are skipped. Read and compile
// failures are reported via debugLog. The walk callback never returns an
// error, so in practice the result is nil.
//
// PreloadModules calls this while holding l.mu, which covers the pathCache
// writes.
func (l *ModuleLoader) scanDirectory(state *luajit.State, dir string) error {
	if dir == "" {
		return nil
	}
	absDir, err := filepath.Abs(dir)
	if err != nil {
		return nil
	}
	l.debugLog("Scanning directory: %s", absDir)

	sep := string(filepath.Separator)

	// WalkDir avoids the per-entry os.Lstat that filepath.Walk performs.
	// os.DirEntry is an alias of fs.DirEntry, so no extra import is needed.
	return filepath.WalkDir(absDir, func(path string, d os.DirEntry, walkErr error) error {
		// d can be nil when walkErr is set, so test the error first.
		if walkErr != nil || d.IsDir() || !strings.HasSuffix(path, ".lua") {
			return nil
		}

		relPath, err := filepath.Rel(absDir, path)
		// Only a real parent-directory component means "outside absDir"; a
		// file merely named "..x.lua" is legitimate.
		if err != nil || relPath == ".." || strings.HasPrefix(relPath, ".."+sep) {
			return nil
		}

		modName := strings.ReplaceAll(strings.TrimSuffix(relPath, ".lua"), sep, ".")
		if modName == "" {
			// A file named just ".lua" has no usable module name.
			return nil
		}
		l.debugLog("Found module: %s at %s", modName, path)
		l.pathCache[modName] = path

		content, err := os.ReadFile(path)
		if err != nil {
			l.debugLog("Failed to read %s: %v", path, err)
			return nil
		}
		if err := state.LoadString(string(content)); err != nil {
			l.debugLog("Failed to compile %s: %v", path, err)
			return nil
		}

		// Stack after LoadString: [chunk]
		state.GetGlobal("package")    // [chunk, package]
		state.GetField(-1, "preload") // [chunk, package, preload]
		state.PushString(modName)     // [chunk, package, preload, name]
		state.PushCopy(-4)            // [chunk, package, preload, name, chunk]
		state.SetTable(-3)            // preload[name] = chunk -> [chunk, package, preload]
		state.Pop(2)                  // [chunk]
		state.Pop(1)                  // []
		return nil
	})
}
// GetModuleByPath maps a file path to a module name.
//
// It first looks for an exact match among modules found by the last scan
// (pathCache). Otherwise it derives a dotted name from the first non-empty
// library directory that contains the .lua file. It returns ("", false) when
// neither applies. If the path cannot be made absolute, the cleaned input
// path is used instead.
func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
	l.mu.RLock()
	defer l.mu.RUnlock()

	absPath, err := filepath.Abs(path)
	if err != nil {
		absPath = filepath.Clean(path)
	}

	// Exact match against the last scan (a linear scan of name -> path).
	for modName, modPath := range l.pathCache {
		if modPath == absPath {
			return modName, true
		}
	}

	sep := string(filepath.Separator)
	for _, dir := range l.config.LibDirs {
		// Skip empty entries like getSearchPaths/scanDirectory do;
		// filepath.Abs("") would otherwise resolve to the working directory.
		if dir == "" {
			continue
		}
		absDir, err := filepath.Abs(dir)
		if err != nil {
			continue
		}
		relPath, err := filepath.Rel(absDir, absPath)
		if err != nil || relPath == ".." || strings.HasPrefix(relPath, ".."+sep) ||
			!strings.HasSuffix(relPath, ".lua") {
			continue
		}
		modName := strings.TrimSuffix(relPath, ".lua")
		return strings.ReplaceAll(modName, sep, "."), true
	}
	return "", false
}
// RefreshModule re-reads and recompiles a previously scanned module, replaces
// its package.preload entry, and clears package.loaded[moduleName]. The next
// require therefore runs the new code.
//
// It returns an error when the module is unknown (not in pathCache), when
// the file cannot be read or compiled, or when unloading the cached instance
// fails.
func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	path, exists := l.pathCache[moduleName]
	if !exists {
		return fmt.Errorf("module %s not found", moduleName)
	}
	l.debugLog("Refreshing module: %s", moduleName)

	content, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("failed to read module: %w", err)
	}
	if err := state.LoadString(string(content)); err != nil {
		return fmt.Errorf("failed to compile module: %w", err)
	}

	// Stack after LoadString: [chunk]
	state.GetGlobal("package")    // [chunk, package]
	state.GetField(-1, "preload") // [chunk, package, preload]
	state.PushString(moduleName)  // [chunk, package, preload, name]
	state.PushCopy(-4)            // [chunk, package, preload, name, chunk]
	state.SetTable(-3)            // preload[name] = chunk -> [chunk, package, preload]
	state.Pop(2)                  // [chunk]
	state.Pop(1)                  // []

	// Drop the cached instance so the next require runs the new loader.
	// This error used to be ignored, which reported success for a module that
	// was never unloaded.
	if err := state.DoString(`package.loaded["` + escapeLuaString(moduleName) + `"] = nil`); err != nil {
		return fmt.Errorf("failed to unload module %s: %w", moduleName, err)
	}

	l.debugLog("Successfully refreshed: %s", moduleName)
	return nil
}
// RefreshModuleByPath resolves filePath to a module name with GetModuleByPath
// and refreshes that module. It returns an error if no module corresponds to
// the path, or whatever RefreshModule reports.
func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error {
	if name, ok := l.GetModuleByPath(filePath); ok {
		return l.RefreshModule(state, name)
	}
	return fmt.Errorf("no module found for path: %s", filePath)
}
// escapeLuaString makes s safe to embed between double quotes in Lua source.
// It escapes backslash, double quote, newline, carriage return and tab; every
// other byte is copied unchanged.
func escapeLuaString(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '\\':
			b.WriteString(`\\`)
		case '"':
			b.WriteString(`\"`)
		case '\n':
			b.WriteString(`\n`)
		case '\r':
			b.WriteString(`\r`)
		case '\t':
			b.WriteString(`\t`)
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}