diff --git a/http/server.go b/http/server.go index cc9d84c..96ba000 100644 --- a/http/server.go +++ b/http/server.go @@ -186,6 +186,10 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx := runner.NewHTTPContext(ctx) defer luaCtx.Release() + if runner.GetGlobalEnvManager() != nil { + luaCtx.Set("env", runner.GetGlobalEnvManager().GetAll()) + } + sessionMap := s.ctxPool.Get().(map[string]any) defer func() { for k := range sessionMap { diff --git a/main.go b/main.go index 5f6e18e..7a356f0 100644 --- a/main.go +++ b/main.go @@ -252,6 +252,10 @@ func (s *Moonshark) Shutdown() error { s.LuaRunner.Close() } + if err := runner.CleanupEnv(); err != nil { + logger.Warning("Environment cleanup failed: %v", err) + } + logger.Info("Shutdown complete") return nil } @@ -302,6 +306,11 @@ func initRunner(s *Moonshark) error { } } + // Initialize environment manager + if err := runner.InitEnv(s.Config.Dirs.Data); err != nil { + logger.Warning("Environment initialization failed: %v", err) + } + sessionManager := sessions.GlobalSessionManager sessionManager.SetCookieOptions( "MoonsharkSID", diff --git a/runner/context.go b/runner/context.go index 5324645..382e682 100644 --- a/runner/context.go +++ b/runner/context.go @@ -30,7 +30,8 @@ var contextPool = sync.Pool{ // NewContext creates a new context, potentially reusing one from the pool func NewContext() *Context { - return contextPool.Get().(*Context) + ctx := contextPool.Get().(*Context) + return ctx } // NewHTTPContext creates a new context from a fasthttp RequestCtx diff --git a/runner/embed.go b/runner/embed.go index 74e4d85..da9a2f0 100644 --- a/runner/embed.go +++ b/runner/embed.go @@ -40,6 +40,9 @@ var timeLuaCode string //go:embed lua/math.lua var mathLuaCode string +//go:embed lua/env.lua +var envLuaCode string + // ModuleInfo holds information about an embeddable Lua module type ModuleInfo struct { Name string // Module name @@ -61,6 +64,7 @@ var ( {Name: "crypto", Code: cryptoLuaCode, DefinesGlobal: true}, {Name: "time", Code: timeLuaCode}, {Name: "math", Code: mathLuaCode}, + {Name: "env", Code: envLuaCode, DefinesGlobal: true}, } ) diff --git a/runner/env.go b/runner/env.go new file mode 100644 index 0000000..03c1432 --- /dev/null +++ b/runner/env.go @@ -0,0 +1,278 @@ +package runner + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "Moonshark/utils/color" + "Moonshark/utils/logger" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// EnvManager handles loading, storing, and saving environment variables +type EnvManager struct { + envPath string // Path to .env file + vars map[string]any // Environment variables in memory + mu sync.RWMutex // Thread-safe access +} + +// Global environment manager instance +var globalEnvManager *EnvManager + +// InitEnv initializes the environment manager with the given data directory +func InitEnv(dataDir string) error { + if dataDir == "" { + return fmt.Errorf("data directory cannot be empty") + } + + // Create data directory if it doesn't exist + if err := os.MkdirAll(dataDir, 0755); err != nil { + return fmt.Errorf("failed to create data directory: %w", err) + } + + envPath := filepath.Join(dataDir, ".env") + + globalEnvManager = &EnvManager{ + envPath: envPath, + vars: make(map[string]any), + } + + // Load existing .env file if it exists + if err := globalEnvManager.load(); err != nil { + logger.Warning("Failed to load .env file: %v", err) + } + + count := len(globalEnvManager.vars) + if count > 0 { + logger.Info("Environment loaded: %s vars from %s", + color.Apply(fmt.Sprintf("%d", count), color.Yellow), + color.Apply(envPath, color.Yellow)) + } else { + logger.Info("Environment initialized: %s", color.Apply(envPath, color.Yellow)) + } + + return nil +} + +// GetGlobalEnvManager returns the global environment manager instance +func GetGlobalEnvManager() *EnvManager { + return globalEnvManager +} + +// load reads the .env file and populates the vars map +func (e *EnvManager) load() error { + file, err := os.Open(e.envPath) + if os.IsNotExist(err) { + // File doesn't exist, start with empty env + return nil + } + if err != nil { + return fmt.Errorf("failed to open .env file: %w", err) + } + defer file.Close() + + e.mu.Lock() + defer e.mu.Unlock() + + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse key=value + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + logger.Warning("Invalid .env line %d: %s", lineNum, line) + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Remove quotes if present + if len(value) >= 2 { + if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || + (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) { + value = value[1 : len(value)-1] + } + } + + e.vars[key] = value + } + + return scanner.Err() +} + +// Save writes the current environment variables to the .env file +func (e *EnvManager) Save() error { + if e == nil { + return nil // No env manager initialized + } + + e.mu.RLock() + defer e.mu.RUnlock() + + file, err := os.Create(e.envPath) + if err != nil { + return fmt.Errorf("failed to create .env file: %w", err) + } + defer file.Close() + + // Sort keys for consistent output + keys := make([]string, 0, len(e.vars)) + for key := range e.vars { + keys = append(keys, key) + } + sort.Strings(keys) + + // Write header comment + fmt.Fprintln(file, "# Environment variables for Moonshark") + fmt.Fprintln(file, "# Generated automatically - you can edit this file") + fmt.Fprintln(file) + + // Write each variable + for _, key := range keys { + value := e.vars[key] + + // Convert value to string + var strValue string + switch v := value.(type) { + case string: + strValue = v + case nil: + continue // Skip nil values + default: + strValue = fmt.Sprintf("%v", v) + } + + // Quote values that contain spaces or special characters + if strings.ContainsAny(strValue, " \t\n\r\"'\\") { + strValue = fmt.Sprintf("\"%s\"", strings.ReplaceAll(strValue, "\"", "\\\"")) + } + + fmt.Fprintf(file, "%s=%s\n", key, strValue) + } + + logger.Debug("Environment saved: %d vars to %s", len(e.vars), e.envPath) + return nil +} + +// Get retrieves an environment variable +func (e *EnvManager) Get(key string) (any, bool) { + if e == nil { + return nil, false + } + + e.mu.RLock() + defer e.mu.RUnlock() + + value, exists := e.vars[key] + return value, exists +} + +// Set stores an environment variable +func (e *EnvManager) Set(key string, value any) { + if e == nil { + return + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.vars[key] = value +} + +// GetAll returns a copy of all environment variables +func (e *EnvManager) GetAll() map[string]any { + if e == nil { + return make(map[string]any) + } + + e.mu.RLock() + defer e.mu.RUnlock() + + result := make(map[string]any, len(e.vars)) + for k, v := range e.vars { + result[k] = v + } + return result +} + +// CleanupEnv saves the environment and cleans up resources +func CleanupEnv() error { + if globalEnvManager != nil { + return globalEnvManager.Save() + } + return nil +} + +// envGet Lua function to get an environment variable +func envGet(state *luajit.State) int { + if !state.IsString(1) { + state.PushNil() + return 1 + } + + key := state.ToString(1) + if value, exists := globalEnvManager.Get(key); exists { + if err := state.PushValue(value); err != nil { + state.PushNil() + } + } else { + state.PushNil() + } + return 1 +} + +// envSet Lua function to set an environment variable +func envSet(state *luajit.State) int { + if !state.IsString(1) || !state.IsString(2) { + state.PushBoolean(false) + return 1 + } + + key := state.ToString(1) + value := state.ToString(2) + + globalEnvManager.Set(key, value) + state.PushBoolean(true) + return 1 +} + +// envGetAll Lua function to get all environment variables +func envGetAll(state *luajit.State) int { + vars := globalEnvManager.GetAll() + + if err := state.PushTable(vars); err != nil { + state.PushNil() + } + + return 1 +} + +// RegisterEnvFunctions registers environment functions with the Lua state +func RegisterEnvFunctions(state *luajit.State) error { + if err := state.RegisterGoFunction("__env_get", envGet); err != nil { + return err + } + if err := state.RegisterGoFunction("__env_set", envSet); err != nil { + return err + } + if err := state.RegisterGoFunction("__env_get_all", envGetAll); err != nil { + return err + } + return nil +} diff --git a/runner/lua/env.lua b/runner/lua/env.lua new file mode 100644 index 0000000..b3b77a5 --- /dev/null +++ b/runner/lua/env.lua @@ -0,0 +1,93 @@ +-- Environment variable module for Moonshark +-- Provides access to persistent environment variables stored in .env file + +-- Get an environment variable with a default value +-- Returns the value if it exists, default_value otherwise +function env_get(key, default_value) + if type(key) ~= "string" then + error("env_get: key must be a string") + end + + -- First check context for environment variables (no Go call needed) + if _env and _env[key] ~= nil then + return _env[key] + end + + return default_value +end + +-- Set an environment variable +-- Returns true on success, false on failure +function env_set(key, value) + if type(key) ~= "string" then + error("env_set: key must be a string") + end + + -- Update context immediately for future reads + if not _env then + _env = {} + end + _env[key] = value + + -- Persist to Go backend + return __env_set(key, value) +end + +-- Get all environment variables as a table +-- Returns a table with all key-value pairs +function env_get_all() + -- Return context table directly if available + if _env then + local copy = {} + for k, v in pairs(_env) do + copy[k] = v + end + return copy + end + + -- Fallback to Go call + return __env_get_all() +end + +-- Check if an environment variable exists +-- Returns true if the variable exists, false otherwise +function env_exists(key) + if type(key) ~= "string" then + error("env_exists: key must be a string") + end + + -- Check context first + if _env then + return _env[key] ~= nil + end + + return false +end + +-- Set multiple environment variables from a table +-- Returns true on success, false if any setting failed +function env_set_many(vars) + if type(vars) ~= "table" then + error("env_set_many: vars must be a table") + end + + if not _env then + _env = {} + end + + local success = true + for key, value in pairs(vars) do + if type(key) == "string" and type(value) == "string" then + -- Update context + _env[key] = value + -- Persist to Go + if not __env_set(key, value) then + success = false + end + else + error("env_set_many: all keys and values must be strings") + end + end + + return success +end diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index b205a51..d0783d3 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -22,6 +22,10 @@ function __create_env(ctx) if ctx then env.ctx = ctx + + if ctx._env then + env._env = ctx._env + end end if __setup_require then @@ -457,10 +461,10 @@ _G.render = function(template_str, env) end local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$") - + -- Check if it's a simple variable name for escaped output local is_simple_var = tag_type == "=" and code:match("^[%w_]+$") - + table.insert(chunks, {tag_type, code, pos, is_simple_var}) pos = close_stop + 1 end @@ -660,4 +664,4 @@ end function send_binary(content, mime_type) http_set_content_type(mime_type or "application/octet-stream") return content -end \ No newline at end of file +end diff --git a/runner/moduleLoader.go b/runner/moduleLoader.go index e1ef921..72bccb9 100644 --- a/runner/moduleLoader.go +++ b/runner/moduleLoader.go @@ -1,6 +1,7 @@ package runner import ( + "fmt" "os" "path/filepath" "strings" @@ -71,15 +72,11 @@ func (l *ModuleLoader) SetupRequire(state *luajit.State) error { err := state.DoString(` -- Initialize global module registry __module_paths = {} - - -- Setup fast module loading system __module_bytecode = {} + __ready_modules = {} -- Create module preload table package.preload = package.preload or {} - - -- Setup module state registry - __ready_modules = {} `) if err != nil { @@ -209,7 +206,7 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error { // Cache bytecode l.bytecodeCache[modName] = bytecode - // Register in Lua + // Register in Lua - store path info escapedPath := escapeLuaString(path) escapedName := escapeLuaString(modName) @@ -217,25 +214,26 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error { return nil } - // Load bytecode into Lua state + // Load bytecode and register in package.preload properly if err := state.LoadBytecode(bytecode, path); err != nil { return nil } - // Add to package.preload - luaCode := ` - local modname = "` + escapedName + `" - local chunk = ... - package.preload[modname] = chunk - __ready_modules[modname] = true - ` + // Store the function in package.preload - the function is on the stack + state.GetGlobal("package") + state.GetField(-1, "preload") + state.PushString(modName) + state.PushCopy(-4) // Copy the compiled function + state.SetTable(-3) // preload[modName] = function + state.Pop(2) // Pop package and preload tables - if err := state.DoString(luaCode); err != nil { - state.Pop(1) // Remove chunk from stack + // Mark as ready + if err := state.DoString(`__ready_modules["` + escapedName + `"] = true`); err != nil { + state.Pop(1) // Remove the function from stack return nil } - state.Pop(1) // Remove chunk from stack + state.Pop(1) // Remove the function from stack return nil }) @@ -318,24 +316,28 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { l.mu.RLock() defer l.mu.RUnlock() - // Clean path for proper comparison - path = filepath.Clean(path) + // Convert to absolute path for consistent comparison + absPath, err := filepath.Abs(path) + if err != nil { + absPath = filepath.Clean(path) + } - // Try direct lookup from cache + // Try direct lookup from cache with absolute path for modName, modPath := range l.pathCache { - if modPath == path { + if modPath == absPath { return modName, true } } - // Try to find by relative path from lib dirs + // Try to construct module name from lib dirs for _, dir := range l.config.LibDirs { absDir, err := filepath.Abs(dir) if err != nil { continue } - relPath, err := filepath.Rel(absDir, path) + // Check if the file is under this lib directory + relPath, err := filepath.Rel(absDir, absPath) if err != nil || strings.HasPrefix(relPath, "..") { continue } @@ -343,13 +345,78 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { if strings.HasSuffix(relPath, ".lua") { modName := strings.TrimSuffix(relPath, ".lua") modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") + + l.debugLog("Found module %s for path %s", modName, path) return modName, true } } + l.debugLog("No module found for path %s", path) return "", false } +// RefreshModule recompiles and updates a specific module +func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error { + l.mu.Lock() + defer l.mu.Unlock() + + // Get module path + path, exists := l.pathCache[moduleName] + if !exists { + l.debugLog("Module not found in cache: %s", moduleName) + return fmt.Errorf("module %s not found", moduleName) + } + + l.debugLog("Refreshing module: %s at %s", moduleName, path) + + // Read updated file content + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read module file: %w", err) + } + + // Recompile to bytecode + bytecode, err := state.CompileBytecode(string(content), path) + if err != nil { + return fmt.Errorf("failed to compile module: %w", err) + } + + // Update bytecode cache + l.bytecodeCache[moduleName] = bytecode + + // Load new bytecode + if err := state.LoadBytecode(bytecode, path); err != nil { + return fmt.Errorf("failed to load bytecode: %w", err) + } + + // Update package.preload with new function (function is on stack) + state.GetGlobal("package") + state.GetField(-1, "preload") + state.PushString(moduleName) + state.PushCopy(-4) // Copy the new compiled function + state.SetTable(-3) // preload[moduleName] = new_function + state.Pop(2) // Pop package and preload tables + state.Pop(1) // Pop the function + + // Clear from package.loaded so it gets reloaded + escapedName := escapeLuaString(moduleName) + if err := state.DoString(`package.loaded["` + escapedName + `"] = nil`); err != nil { + return fmt.Errorf("failed to clear loaded module: %w", err) + } + + l.debugLog("Successfully refreshed module: %s", moduleName) + return nil +} + +// RefreshModuleByPath refreshes a module by its file path +func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error { + moduleName, exists := l.GetModuleByPath(filePath) + if !exists { + return fmt.Errorf("no module found for path: %s", filePath) + } + return l.RefreshModule(state, moduleName) +} + // escapeLuaString escapes special characters in a string for Lua func escapeLuaString(s string) string { replacer := strings.NewReplacer( diff --git a/runner/runner.go b/runner/runner.go index 8c9267a..02759ee 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -384,8 +384,7 @@ waitForInUse: // NotifyFileChanged alerts the runner about file changes func (r *Runner) NotifyFileChanged(filePath string) bool { - logger.Debug("Runner has been notified of a file change...") - logger.Debug("%s", filePath) + logger.Debug("Runner notified of file change: %s", filePath) module, isModule := r.moduleLoader.GetModuleByPath(filePath) if isModule { @@ -393,7 +392,7 @@ func (r *Runner) NotifyFileChanged(filePath string) bool { return r.RefreshModule(module) } - logger.Debug("File change noted but no state refresh needed: %s", filePath) + logger.Debug("File change noted but no refresh needed: %s", filePath) return true } @@ -414,10 +413,41 @@ func (r *Runner) RefreshModule(moduleName string) bool { continue } - // Invalidate module in Lua - if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil { + // Use the enhanced module loader refresh + if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil { success = false - logger.Debug("Failed to invalidate module %s: %v", moduleName, err) + logger.Debug("Failed to refresh module %s in state %d: %v", moduleName, state.index, err) + } + } + + if success { + logger.Debug("Successfully refreshed module: %s", moduleName) + } + + return success +} + +// RefreshModuleByPath refreshes a module by its file path +func (r *Runner) RefreshModuleByPath(filePath string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.isRunning.Load() { + return false + } + + logger.Debug("Refreshing module by path: %s", filePath) + + success := true + for _, state := range r.states { + if state == nil || state.inUse.Load() { + continue + } + + // Use the enhanced module loader refresh by path + if err := r.moduleLoader.RefreshModuleByPath(state.L, filePath); err != nil { + success = false + logger.Debug("Failed to refresh module at %s in state %d: %v", filePath, state.index, err) } } diff --git a/runner/sandbox.go b/runner/sandbox.go index f8f69e2..6fe6c5f 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -121,6 +121,10 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { return err } + if err := RegisterEnvFunctions(state); err != nil { + return err + } + return nil }