csrf 2 go 1

This commit is contained in:
Sky Johnson 2025-04-03 10:12:14 -05:00
parent eea5ba8c8a
commit 945886abe6
3 changed files with 238 additions and 59 deletions

View File

@ -11,16 +11,22 @@ import (
// CoreModuleRegistry manages the initialization and reloading of core modules
type CoreModuleRegistry struct {
modules map[string]StateInitFunc
mu sync.RWMutex
debug bool
modules map[string]StateInitFunc
initOrder []string // Explicit initialization order
dependencies map[string][]string // Module dependencies
initializedFlag map[string]bool // Track which modules are initialized
mu sync.RWMutex
debug bool
}
// NewCoreModuleRegistry creates a new core module registry
func NewCoreModuleRegistry() *CoreModuleRegistry {
return &CoreModuleRegistry{
modules: make(map[string]StateInitFunc),
debug: false,
modules: make(map[string]StateInitFunc),
initOrder: []string{},
dependencies: make(map[string][]string),
initializedFlag: make(map[string]bool),
debug: false,
}
}
@ -41,9 +47,77 @@ func (r *CoreModuleRegistry) Register(name string, initFunc StateInitFunc) {
r.mu.Lock()
defer r.mu.Unlock()
r.modules[name] = initFunc
// Add to initialization order if not already there
found := false
for _, n := range r.initOrder {
if n == name {
found = true
break
}
}
if !found {
r.initOrder = append(r.initOrder, name)
}
r.debugLog("Registered module: %s", name)
}
// RegisterWithDependencies registers a module with explicit dependencies
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc StateInitFunc, dependencies []string) {
r.mu.Lock()
defer r.mu.Unlock()
r.modules[name] = initFunc
r.dependencies[name] = dependencies
// Add to initialization order if not already there
found := false
for _, n := range r.initOrder {
if n == name {
found = true
break
}
}
if !found {
r.initOrder = append(r.initOrder, name)
}
r.debugLog("Registered module %s with dependencies: %v", name, dependencies)
}
// SetInitOrder sets explicit initialization order
func (r *CoreModuleRegistry) SetInitOrder(order []string) {
r.mu.Lock()
defer r.mu.Unlock()
// First add all known modules that are in the specified order
for _, name := range order {
if _, exists := r.modules[name]; exists {
r.initOrder = append(r.initOrder, name)
}
}
// Then add any modules not in the specified order
for name := range r.modules {
found := false
for _, n := range r.initOrder {
if n == name {
found = true
break
}
}
if !found {
r.initOrder = append(r.initOrder, name)
}
}
r.debugLog("Set initialization order: %v", r.initOrder)
}
// Initialize initializes all registered modules
func (r *CoreModuleRegistry) Initialize(state *luajit.State) error {
r.mu.RLock()
@ -51,30 +125,62 @@ func (r *CoreModuleRegistry) Initialize(state *luajit.State) error {
r.debugLog("Initializing all modules...")
// Get all module init functions
initFuncs := r.getInitFuncs()
// Clear initialization flags
r.initializedFlag = make(map[string]bool)
// Initialize modules one by one to better track issues
for name, initFunc := range initFuncs {
r.debugLog("Initializing module: %s", name)
if err := initFunc(state); err != nil {
r.debugLog("Failed to initialize module %s: %v", name, err)
return fmt.Errorf("failed to initialize module %s: %w", name, err)
// Initialize modules in order, respecting dependencies
for _, name := range r.initOrder {
if err := r.initializeModule(state, name, []string{}); err != nil {
return err
}
r.debugLog("Module %s initialized successfully", name)
}
r.debugLog("All modules initialized successfully")
return nil
}
// getInitFuncs returns all module init functions
func (r *CoreModuleRegistry) getInitFuncs() map[string]StateInitFunc {
funcs := make(map[string]StateInitFunc, len(r.modules))
for name, initFunc := range r.modules {
funcs[name] = initFunc
// initializeModule initializes a module and its dependencies
func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string, initStack []string) error {
// Check if already initialized
if r.initializedFlag[name] {
return nil
}
return funcs
// Check for circular dependencies
for _, n := range initStack {
if n == name {
return fmt.Errorf("circular dependency detected: %s -> %s",
strings.Join(initStack, " -> "), name)
}
}
// Get init function
initFunc, ok := r.modules[name]
if !ok {
return fmt.Errorf("module not found: %s", name)
}
// Initialize dependencies first
deps := r.dependencies[name]
for _, dep := range deps {
newStack := append(initStack, name)
if err := r.initializeModule(state, dep, newStack); err != nil {
return err
}
}
// Initialize this module
r.debugLog("Initializing module: %s", name)
if err := initFunc(state); err != nil {
r.debugLog("Failed to initialize module %s: %v", name, err)
return fmt.Errorf("failed to initialize module %s: %w", name, err)
}
// Mark as initialized
r.initializedFlag[name] = true
r.debugLog("Module %s initialized successfully", name)
return nil
}
// InitializeModule initializes a specific module
@ -82,14 +188,10 @@ func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string)
r.mu.RLock()
defer r.mu.RUnlock()
initFunc, ok := r.modules[name]
if !ok {
r.debugLog("Module not found: %s", name)
return nil // Module not found, no error
}
// Clear initialization flag for this module
r.initializedFlag[name] = false
r.debugLog("Reinitializing module: %s", name)
return initFunc(state)
return r.initializeModule(state, name, []string{})
}
// ModuleNames returns a list of all registered module names
@ -130,9 +232,27 @@ var GlobalRegistry = NewCoreModuleRegistry()
// Initialize global registry with core modules
func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default
// Register modules
GlobalRegistry.Register("go", GoModuleInitFunc())
// Register HTTP module (no dependencies)
GlobalRegistry.Register("http", HTTPModuleInitFunc())
GlobalRegistry.Register("cookie", CookieModuleInitFunc())
GlobalRegistry.Register("csrf", CSRFModuleInitFunc())
// Register cookie module (depends on http)
GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"})
// Register CSRF module (depends on go)
GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"go"})
// Set explicit initialization order
GlobalRegistry.SetInitOrder([]string{
"go", // First: core utilities
"http", // Second: HTTP functionality
"cookie", // Third: Cookie functionality (uses HTTP)
"csrf", // Fourth: CSRF protection (uses go and possibly session)
})
logger.Debug("[CoreModuleRegistry] Core modules registered in init()")
}
@ -141,3 +261,8 @@ func init() {
func RegisterCoreModule(name string, initFunc StateInitFunc) {
GlobalRegistry.Register(name, initFunc)
}
// RegisterCoreModuleWithDependencies registers a module with dependencies
func RegisterCoreModuleWithDependencies(name string, initFunc StateInitFunc, dependencies []string) {
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
}

View File

@ -13,7 +13,7 @@ const LuaCSRFModule = `
local csrf = {
-- Session key where the token is stored
TOKEN_KEY = "_csrf_token",
-- Default form field name
DEFAULT_FIELD = "csrf",
@ -21,83 +21,75 @@ local csrf = {
generate = function(length)
-- Default length is 32 characters
length = length or 32
if length < 16 then
-- Enforce minimum security
length = 16
end
-- Check if we have a session module
if not session then
error("CSRF protection requires the session module", 2)
end
-- Generate a secure random token using os.time and math.random
local token = ""
local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
-- Seed the random generator with current time
math.randomseed(os.time())
-- Generate random string
for i = 1, length do
local idx = math.random(1, #chars)
token = token .. chars:sub(idx, idx)
end
-- Use Go's secure token generation
local token = go.generate_token(length)
-- Store in session
session.set(csrf.TOKEN_KEY, token)
return token
end,
-- Get the current token or generate a new one
token = function()
-- Get from session if exists
local token = session.get(csrf.TOKEN_KEY)
-- Generate if needed
if not token then
token = csrf.generate()
end
return token
end,
-- Generate a hidden form field with the CSRF token
field = function(field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local token = csrf.token()
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
end,
-- Verify a given token against the session token
verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2)
local form = nil
if env.ctx and env.ctx.form then
form = env.ctx.form
else
return false
end
token = token or form[field_name]
if not token then
return false
end
local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then
return false
end
-- Constant-time comparison to prevent timing attacks
-- This is safe since Lua strings are immutable
if #token ~= #session_token then
return false
end
local result = true
for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then
@ -105,7 +97,7 @@ local csrf = {
-- Don't break early - continue to prevent timing attacks
end
end
return result
end
}

62
core/runner/Go.go Normal file
View File

@ -0,0 +1,62 @@
package runner
import (
"crypto/rand"
"encoding/base64"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"git.sharkk.net/Sky/Moonshark/core/logger"
)
// GenerateToken creates a cryptographically secure random token
func GenerateToken(s *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
length := 32
if s.GetTop() >= 1 && s.IsNumber(1) {
length = int(s.ToNumber(1))
}
// Enforce minimum length for security
if length < 16 {
length = 16
}
// Generate secure random bytes
tokenBytes := make([]byte, length)
if _, err := rand.Read(tokenBytes); err != nil {
s.PushString("")
logger.Error("Failed to generate secure token: %v", err)
return 1 // Return empty string on error
}
// Encode as base64
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
// Trim to requested length (base64 might be longer)
if len(token) > length {
token = token[:length]
}
// Push the token to the Lua stack
s.PushString(token)
return 1 // One return value
}
// GoModuleFunctions returns all functions for the go module
func GoModuleFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"generate_token": GenerateToken,
}
}
// GoModuleInitFunc returns an initializer for the go module
func GoModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return RegisterModule(state, "go", GoModuleFunctions())
}
}
// Initialize the core module during startup
func init() {
RegisterCoreModule("go", GoModuleInitFunc())
}