From 945886abe639e321aa9325f462c1441f131a496e Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 3 Apr 2025 10:12:14 -0500 Subject: [PATCH] csrf 2 go 1 --- core/runner/CoreModules.go | 183 +++++++++++++++++++++++++++++++------ core/runner/Csrf.go | 52 +++++------ core/runner/Go.go | 62 +++++++++++++ 3 files changed, 238 insertions(+), 59 deletions(-) create mode 100644 core/runner/Go.go diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go index fc98e33..52ebb29 100644 --- a/core/runner/CoreModules.go +++ b/core/runner/CoreModules.go @@ -11,16 +11,22 @@ import ( // CoreModuleRegistry manages the initialization and reloading of core modules type CoreModuleRegistry struct { - modules map[string]StateInitFunc - mu sync.RWMutex - debug bool + modules map[string]StateInitFunc + initOrder []string // Explicit initialization order + dependencies map[string][]string // Module dependencies + initializedFlag map[string]bool // Track which modules are initialized + mu sync.RWMutex + debug bool } // NewCoreModuleRegistry creates a new core module registry func NewCoreModuleRegistry() *CoreModuleRegistry { return &CoreModuleRegistry{ - modules: make(map[string]StateInitFunc), - debug: false, + modules: make(map[string]StateInitFunc), + initOrder: []string{}, + dependencies: make(map[string][]string), + initializedFlag: make(map[string]bool), + debug: false, } } @@ -41,9 +47,77 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) { r.mu.Lock() defer r.mu.Unlock() r.modules[name] = initFunc + + // Add to initialization order if not already there + found := false + for _, n := range r.initOrder { + if n == name { + found = true + break + } + } + + if !found { + r.initOrder = append(r.initOrder, name) + } + r.debugLog("Registered module: %s", name) } +// RegisterWithDependencies registers a module with explicit dependencies +func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc StateInitFunc, dependencies []string) { + r.mu.Lock() + defer r.mu.Unlock() + + r.modules[name] = initFunc + r.dependencies[name] = dependencies + + // Add to initialization order if not already there + found := false + for _, n := range r.initOrder { + if n == name { + found = true + break + } + } + + if !found { + r.initOrder = append(r.initOrder, name) + } + + r.debugLog("Registered module %s with dependencies: %v", name, dependencies) +} + +// SetInitOrder sets explicit initialization order +func (r *CoreModuleRegistry) SetInitOrder(order []string) { + r.mu.Lock() + defer r.mu.Unlock() + + // First add all known modules that are in the specified order + for _, name := range order { + if _, exists := r.modules[name]; exists { + r.initOrder = append(r.initOrder, name) + } + } + + // Then add any modules not in the specified order + for name := range r.modules { + found := false + for _, n := range r.initOrder { + if n == name { + found = true + break + } + } + + if !found { + r.initOrder = append(r.initOrder, name) + } + } + + r.debugLog("Set initialization order: %v", r.initOrder) +} + // Initialize initializes all registered modules func (r *CoreModuleRegistry) Initialize(state *luajit.State) error { r.mu.RLock() @@ -51,30 +125,62 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State) error { r.debugLog("Initializing all modules...") - // Get all module init functions - initFuncs := r.getInitFuncs() + // Clear initialization flags + r.initializedFlag = make(map[string]bool) - // Initialize modules one by one to better track issues - for name, initFunc := range initFuncs { - r.debugLog("Initializing module: %s", name) - if err := initFunc(state); err != nil { - r.debugLog("Failed to initialize module %s: %v", name, err) - return fmt.Errorf("failed to initialize module %s: %w", name, err) + // Initialize modules in order, respecting dependencies + for _, name := range r.initOrder { + if err := r.initializeModule(state, name, []string{}); err != nil { + return err } - r.debugLog("Module %s initialized successfully", name) } r.debugLog("All modules initialized successfully") return nil } -// getInitFuncs returns all module init functions -func (r *CoreModuleRegistry) getInitFuncs() map[string]StateInitFunc { - funcs := make(map[string]StateInitFunc, len(r.modules)) - for name, initFunc := range r.modules { - funcs[name] = initFunc +// initializeModule initializes a module and its dependencies +func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string, initStack []string) error { + // Check if already initialized + if r.initializedFlag[name] { + return nil } - return funcs + + // Check for circular dependencies + for _, n := range initStack { + if n == name { + return fmt.Errorf("circular dependency detected: %s -> %s", + strings.Join(initStack, " -> "), name) + } + } + + // Get init function + initFunc, ok := r.modules[name] + if !ok { + return fmt.Errorf("module not found: %s", name) + } + + // Initialize dependencies first + deps := r.dependencies[name] + for _, dep := range deps { + newStack := append(initStack, name) + if err := r.initializeModule(state, dep, newStack); err != nil { + return err + } + } + + // Initialize this module + r.debugLog("Initializing module: %s", name) + if err := initFunc(state); err != nil { + r.debugLog("Failed to initialize module %s: %v", name, err) + return fmt.Errorf("failed to initialize module %s: %w", name, err) + } + + // Mark as initialized + r.initializedFlag[name] = true + r.debugLog("Module %s initialized successfully", name) + + return nil } // InitializeModule initializes a specific module @@ -82,14 +188,10 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) r.mu.RLock() defer r.mu.RUnlock() - initFunc, ok := r.modules[name] - if !ok { - r.debugLog("Module not found: %s", name) - return nil // Module not found, no error - } + // Clear initialization flag for this module + r.initializedFlag[name] = false - r.debugLog("Reinitializing module: %s", name) - return initFunc(state) + return r.initializeModule(state, name, []string{}) } // ModuleNames returns a list of all registered module names @@ -130,9 +232,27 @@ var GlobalRegistry = NewCoreModuleRegistry() // Initialize global registry with core modules func init() { GlobalRegistry.EnableDebug() // Enable debugging by default + + // Register modules + GlobalRegistry.Register("go", GoModuleInitFunc()) + + // Register HTTP module (no dependencies) GlobalRegistry.Register("http", HTTPModuleInitFunc()) - GlobalRegistry.Register("cookie", CookieModuleInitFunc()) - GlobalRegistry.Register("csrf", CSRFModuleInitFunc()) + + // Register cookie module (depends on http) + GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"}) + + // Register CSRF module (depends on go) + GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"go"}) + + // Set explicit initialization order + GlobalRegistry.SetInitOrder([]string{ + "go", // First: core utilities + "http", // Second: HTTP functionality + "cookie", // Third: Cookie functionality (uses HTTP) + "csrf", // Fourth: CSRF protection (uses go and possibly session) + }) + logger.Debug("[CoreModuleRegistry] Core modules registered in init()") } @@ -141,3 +261,8 @@ func init() { func RegisterCoreModule(name string, initFunc StateInitFunc) { GlobalRegistry.Register(name, initFunc) } + +// RegisterCoreModuleWithDependencies registers a module with dependencies +func RegisterCoreModuleWithDependencies(name string, initFunc StateInitFunc, dependencies []string) { + GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies) +} diff --git a/core/runner/Csrf.go b/core/runner/Csrf.go index 8c0bd04..f0812e9 100644 --- a/core/runner/Csrf.go +++ b/core/runner/Csrf.go @@ -13,7 +13,7 @@ const LuaCSRFModule = ` local csrf = { -- Session key where the token is stored TOKEN_KEY = "_csrf_token", - + -- Default form field name DEFAULT_FIELD = "csrf", @@ -21,83 +21,75 @@ local csrf = { generate = function(length) -- Default length is 32 characters length = length or 32 - + if length < 16 then -- Enforce minimum security length = 16 end - + -- Check if we have a session module if not session then error("CSRF protection requires the session module", 2) end - - -- Generate a secure random token using os.time and math.random - local token = "" - local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - - -- Seed the random generator with current time - math.randomseed(os.time()) - - -- Generate random string - for i = 1, length do - local idx = math.random(1, #chars) - token = token .. chars:sub(idx, idx) - end - + + -- Use Go's secure token generation + local token = go.generate_token(length) + -- Store in session session.set(csrf.TOKEN_KEY, token) - + return token end, - + -- Get the current token or generate a new one token = function() -- Get from session if exists local token = session.get(csrf.TOKEN_KEY) - + -- Generate if needed if not token then token = csrf.generate() end - + return token end, - + -- Generate a hidden form field with the CSRF token field = function(field_name) field_name = field_name or csrf.DEFAULT_FIELD local token = csrf.token() return string.format('', field_name, token) end, - + -- Verify a given token against the session token verify = function(token, field_name) field_name = field_name or csrf.DEFAULT_FIELD - + local env = getfenv(2) - + local form = nil if env.ctx and env.ctx.form then form = env.ctx.form else return false end - + token = token or form[field_name] if not token then return false end - + local session_token = session.get(csrf.TOKEN_KEY) if not session_token then return false end - + + -- Constant-time comparison to prevent timing attacks + -- This is safe since Lua strings are immutable if #token ~= #session_token then return false end - + local result = true for i = 1, #token do if token:sub(i, i) ~= session_token:sub(i, i) then @@ -105,7 +97,7 @@ local csrf = { -- Don't break early - continue to prevent timing attacks end end - + return result end } diff --git a/core/runner/Go.go b/core/runner/Go.go new file mode 100644 index 0000000..ccd2a64 --- /dev/null +++ b/core/runner/Go.go @@ -0,0 +1,62 @@ +package runner + +import ( + "crypto/rand" + "encoding/base64" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/logger" +) + +// GenerateToken creates a cryptographically secure random token +func GenerateToken(s *luajit.State) int { + // Get the length from the Lua arguments (default to 32) + length := 32 + if s.GetTop() >= 1 && s.IsNumber(1) { + length = int(s.ToNumber(1)) + } + + // Enforce minimum length for security + if length < 16 { + length = 16 + } + + // Generate secure random bytes + tokenBytes := make([]byte, length) + if _, err := rand.Read(tokenBytes); err != nil { + s.PushString("") + logger.Error("Failed to generate secure token: %v", err) + return 1 // Return empty string on error + } + + // Encode as base64 + token := base64.RawURLEncoding.EncodeToString(tokenBytes) + + // Trim to requested length (base64 might be longer) + if len(token) > length { + token = token[:length] + } + + // Push the token to the Lua stack + s.PushString(token) + return 1 // One return value +} + +// GoModuleFunctions returns all functions for the go module +func GoModuleFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "generate_token": GenerateToken, + } +} + +// GoModuleInitFunc returns an initializer for the go module +func GoModuleInitFunc() StateInitFunc { + return func(state *luajit.State) error { + return RegisterModule(state, "go", GoModuleFunctions()) + } +} + +// Initialize the core module during startup +func init() { + RegisterCoreModule("go", GoModuleInitFunc()) +}