Compare commits

..

2 Commits

Author SHA1 Message Date
3e26f348b4 sandbox split 2025-03-20 14:46:11 -05:00
87acbb402a optimizations 1 2025-03-20 14:12:03 -05:00
6 changed files with 520 additions and 269 deletions

View File

@ -28,29 +28,15 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
staticRouter: staticRouter, staticRouter: staticRouter,
luaRunner: runner, luaRunner: runner,
logger: log, logger: log,
httpServer: &http.Server{ httpServer: &http.Server{},
// Connection timeouts
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
ReadHeaderTimeout: 10 * time.Second,
// Improved connection handling
MaxHeaderBytes: 1 << 16, // 64KB
},
} }
server.httpServer.Handler = server server.httpServer.Handler = server
// Set TCP keep-alive settings for the underlying TCP connections // Set TCP keep-alive for connections
server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) { server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
if state == http.StateNew { if state == http.StateNew {
if tcpConn, ok := conn.(*net.TCPConn); ok { if tcpConn, ok := conn.(*net.TCPConn); ok {
// Enable TCP keep-alive
tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
// Set TCP_NODELAY (disable Nagle's algorithm)
tcpConn.SetNoDelay(true)
} }
} }
} }
@ -124,7 +110,17 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx.Set("method", r.Method) ctx.Set("method", r.Method)
ctx.Set("path", r.URL.Path) ctx.Set("path", r.URL.Path)
ctx.Set("host", r.Host) ctx.Set("host", r.Host)
ctx.Set("headers", makeHeaderMap(r.Header))
// Inline the header conversion (previously makeHeaderMap)
headerMap := make(map[string]any, len(r.Header))
for name, values := range r.Header {
if len(values) == 1 {
headerMap[name] = values[0]
} else {
headerMap[name] = values
}
}
ctx.Set("headers", headerMap)
// Add URL parameters // Add URL parameters
if params.Count > 0 { if params.Count > 0 {
@ -135,12 +131,11 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
ctx.Set("params", paramMap) ctx.Set("params", paramMap)
} }
// Add query parameters // Query parameters will be parsed lazily via metatable in Lua
if queryParams := QueryToLua(r); queryParams != nil { // Instead of parsing for every request, we'll pass the raw URL
ctx.Set("query", queryParams) ctx.Set("rawQuery", r.URL.RawQuery)
}
// Add form data // Add form data for POST/PUT/PATCH only when needed
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch { if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
if formData, err := ParseForm(r); err == nil && len(formData) > 0 { if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
ctx.Set("form", formData) ctx.Set("form", formData)
@ -158,18 +153,11 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
writeResponse(w, result, s.logger) writeResponse(w, result, s.logger)
} }
// makeHeaderMap converts HTTP headers to a map // Content types for responses
func makeHeaderMap(header http.Header) map[string]any { const (
result := make(map[string]any, len(header)) contentTypeJSON = "application/json"
for name, values := range header { contentTypePlain = "text/plain"
if len(values) == 1 { )
result[name] = values[0]
} else {
result[name] = values
}
}
return result
}
// writeResponse writes the Lua result to the HTTP response // writeResponse writes the Lua result to the HTTP response
func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) { func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
@ -181,12 +169,12 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
switch res := result.(type) { switch res := result.(type) {
case string: case string:
// String result // String result
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", contentTypePlain)
w.Write([]byte(res)) w.Write([]byte(res))
case map[string]any: case map[string]any, []any:
// Table result - convert to JSON // Table or array result - convert to JSON
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", contentTypeJSON)
data, err := json.Marshal(res) data, err := json.Marshal(res)
if err != nil { if err != nil {
log.Error("Failed to marshal response: %v", err) log.Error("Failed to marshal response: %v", err)
@ -197,7 +185,7 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
default: default:
// Other result types - convert to JSON // Other result types - convert to JSON
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", contentTypeJSON)
data, err := json.Marshal(result) data, err := json.Marshal(result)
if err != nil { if err != nil {
log.Error("Failed to marshal response: %v", err) log.Error("Failed to marshal response: %v", err)

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"os" "os"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@ -47,6 +48,12 @@ var levelProps = map[int]struct {
// Time format for log messages // Time format for log messages
const timeFormat = "15:04:05" const timeFormat = "15:04:05"
// Default rate limiting settings
const (
defaultMaxLogs = 1000 // Max logs per second before rate limiting
defaultRateLimitTime = 10 * time.Second // How long to pause during rate limiting
)
// Logger handles logging operations // Logger handles logging operations
type Logger struct { type Logger struct {
writer io.Writer writer io.Writer
@ -54,16 +61,39 @@ type Logger struct {
useColors bool useColors bool
timeFormat string timeFormat string
mu sync.Mutex // Mutex for thread-safe writing mu sync.Mutex // Mutex for thread-safe writing
// Simple rate limiting
logCount atomic.Int64 // Number of logs in current window
logCountStart atomic.Int64 // Start time of current counting window
rateLimited atomic.Bool // Whether we're currently rate limited
rateLimitUntil atomic.Int64 // Timestamp when rate limiting ends
maxLogsPerSec int64 // Maximum logs per second before limiting
limitDuration time.Duration // How long to pause logging when rate limited
} }
// New creates a new logger // New creates a new logger
func New(minLevel int, useColors bool) *Logger { func New(minLevel int, useColors bool) *Logger {
return &Logger{ logger := &Logger{
writer: os.Stdout, writer: os.Stdout,
level: minLevel, level: minLevel,
useColors: useColors, useColors: useColors,
timeFormat: timeFormat, timeFormat: timeFormat,
maxLogsPerSec: defaultMaxLogs,
limitDuration: defaultRateLimitTime,
} }
// Initialize counters
logger.resetCounters()
return logger
}
// resetCounters resets the rate limiting counters
func (l *Logger) resetCounters() {
l.logCount.Store(0)
l.logCountStart.Store(time.Now().Unix())
l.rateLimited.Store(false)
l.rateLimitUntil.Store(0)
} }
// SetOutput changes the output destination // SetOutput changes the output destination
@ -119,29 +149,85 @@ func (l *Logger) writeMessage(level int, message string, rawMode bool) {
} }
} }
// Asynchronously write the log message // Synchronously write the log message
go func(w io.Writer, data string) {
l.mu.Lock() l.mu.Lock()
_, _ = fmt.Fprint(w, data) _, _ = fmt.Fprint(l.writer, logLine)
l.mu.Unlock()
}(l.writer, logLine)
// For fatal errors, ensure we sync immediately in the current goroutine // For fatal errors, ensure we sync immediately
if level == LevelFatal { if level == LevelFatal {
l.mu.Lock()
if f, ok := l.writer.(*os.File); ok { if f, ok := l.writer.(*os.File); ok {
_ = f.Sync() _ = f.Sync()
} }
l.mu.Unlock()
} }
l.mu.Unlock()
}
// checkRateLimit checks if we should rate limit logging
// Returns true if the message should be logged, false if it should be dropped
func (l *Logger) checkRateLimit(level int) bool {
// High priority messages are never rate limited
if level >= LevelWarning {
return true
}
// Check if we're currently in a rate-limited period
if l.rateLimited.Load() {
now := time.Now().Unix()
limitUntil := l.rateLimitUntil.Load()
if now >= limitUntil {
// Rate limiting period is over
l.rateLimited.Store(false)
l.resetCounters()
} else {
// Still in rate limiting period, drop the message
return false
}
}
// If not rate limited, check if we should start rate limiting
count := l.logCount.Add(1)
// Check if we need to reset the counter for a new second
now := time.Now().Unix()
start := l.logCountStart.Load()
if now > start {
// New second, reset counter
l.logCount.Store(1) // Count this message
l.logCountStart.Store(now)
return true
}
// Check if we've exceeded our threshold
if count > l.maxLogsPerSec {
// Start rate limiting
l.rateLimited.Store(true)
l.rateLimitUntil.Store(now + int64(l.limitDuration.Seconds()))
// Log a warning about rate limiting
l.writeMessage(LevelServer,
fmt.Sprintf("Rate limiting logger temporarily due to high demand (%d logs/sec exceeded)", count),
false)
return false
}
return true
} }
// log handles the core logging logic with level filtering // log handles the core logging logic with level filtering
func (l *Logger) log(level int, format string, args ...any) { func (l *Logger) log(level int, format string, args ...any) {
// First check normal level filtering
if level < l.level { if level < l.level {
return return
} }
// Check rate limiting - always log high priority messages
if !l.checkRateLimit(level) {
return
}
// Format message
var message string var message string
if len(args) > 0 { if len(args) > 0 {
message = fmt.Sprintf(format, args...) message = fmt.Sprintf(format, args...)
@ -164,6 +250,11 @@ func (l *Logger) LogRaw(format string, args ...any) {
return return
} }
// Check rate limiting
if !l.checkRateLimit(LevelInfo) {
return
}
var message string var message string
if len(args) > 0 { if len(args) > 0 {
message = fmt.Sprintf(format, args...) message = fmt.Sprintf(format, args...)

View File

@ -2,7 +2,6 @@ package logger
import ( import (
"bytes" "bytes"
"strconv"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -16,36 +15,31 @@ func TestLoggerLevels(t *testing.T) {
// Debug should be below threshold // Debug should be below threshold
logger.Debug("This should not appear") logger.Debug("This should not appear")
time.Sleep(10 * time.Millisecond) // Wait for processing
if buf.Len() > 0 { if buf.Len() > 0 {
t.Error("Debug message appeared when it should be filtered") t.Error("Debug message appeared when it should be filtered")
} }
// Info and above should appear // Info and above should appear
logger.Info("Info message") logger.Info("Info message")
time.Sleep(10 * time.Millisecond) // Wait for processing if !strings.Contains(buf.String(), "INFO") {
if !strings.Contains(buf.String(), "[INF]") {
t.Errorf("Info message not logged, got: %q", buf.String()) t.Errorf("Info message not logged, got: %q", buf.String())
} }
buf.Reset() buf.Reset()
logger.Warning("Warning message") logger.Warning("Warning message")
time.Sleep(10 * time.Millisecond) // Wait for processing if !strings.Contains(buf.String(), "WARN") {
if !strings.Contains(buf.String(), "[WRN]") {
t.Errorf("Warning message not logged, got: %q", buf.String()) t.Errorf("Warning message not logged, got: %q", buf.String())
} }
buf.Reset() buf.Reset()
logger.Error("Error message") logger.Error("Error message")
time.Sleep(10 * time.Millisecond) // Wait for processing if !strings.Contains(buf.String(), "ERROR") {
if !strings.Contains(buf.String(), "[ERR]") {
t.Errorf("Error message not logged, got: %q", buf.String()) t.Errorf("Error message not logged, got: %q", buf.String())
} }
buf.Reset() buf.Reset()
// Test format strings // Test format strings
logger.Info("Count: %d", 42) logger.Info("Count: %d", 42)
time.Sleep(10 * time.Millisecond) // Wait for processing
if !strings.Contains(buf.String(), "Count: 42") { if !strings.Contains(buf.String(), "Count: 42") {
t.Errorf("Formatted message not logged correctly, got: %q", buf.String()) t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
} }
@ -60,17 +54,53 @@ func TestLoggerLevels(t *testing.T) {
} }
logger.Error("Error should appear") logger.Error("Error should appear")
time.Sleep(10 * time.Millisecond) // Wait for processing if !strings.Contains(buf.String(), "ERROR") {
if !strings.Contains(buf.String(), "[ERR]") {
t.Errorf("Error message not logged after level change, got: %q", buf.String()) t.Errorf("Error message not logged after level change, got: %q", buf.String())
} }
} }
func TestLoggerRateLimit(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Override max logs per second to something small for testing
logger.maxLogsPerSec = 5
logger.limitDuration = 1 * time.Second
// Send debug messages (should get limited)
for i := 0; i < 20; i++ {
logger.Debug("Debug message %d", i)
}
// Error messages should always go through
logger.Error("Error message should appear")
content := buf.String()
// We should see some debug messages, then a warning about rate limiting,
// and finally the error message
if !strings.Contains(content, "Debug message 0") {
t.Error("First debug message should appear")
}
if !strings.Contains(content, "Rate limiting logger") {
t.Error("Rate limiting message should appear")
}
if !strings.Contains(content, "ERROR") {
t.Error("Error message should always appear despite rate limiting")
}
}
func TestLoggerConcurrency(t *testing.T) { func TestLoggerConcurrency(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger := New(LevelDebug, false) logger := New(LevelDebug, false)
logger.SetOutput(&buf) logger.SetOutput(&buf)
// Increase log threshold for this test
logger.maxLogsPerSec = 1000
// Log a bunch of messages concurrently // Log a bunch of messages concurrently
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@ -82,17 +112,10 @@ func TestLoggerConcurrency(t *testing.T) {
} }
wg.Wait() wg.Wait()
// Wait for processing // Check logs were processed
time.Sleep(10 * time.Millisecond)
// Check all messages were logged
content := buf.String() content := buf.String()
for i := 0; i < 100; i++ { if !strings.Contains(content, "Concurrent message") {
msg := "Concurrent message " + strconv.Itoa(i) t.Error("Concurrent messages should appear")
if !strings.Contains(content, msg) && !strings.Contains(content, "Concurrent message") {
t.Errorf("Missing concurrent messages")
break
}
} }
} }
@ -103,7 +126,6 @@ func TestLoggerColors(t *testing.T) {
// Test with color // Test with color
logger.Info("Colored message") logger.Info("Colored message")
time.Sleep(10 * time.Millisecond) // Wait for processing
content := buf.String() content := buf.String()
t.Logf("Colored output: %q", content) // Print actual output for diagnosis t.Logf("Colored output: %q", content) // Print actual output for diagnosis
@ -114,7 +136,6 @@ func TestLoggerColors(t *testing.T) {
buf.Reset() buf.Reset()
logger.DisableColors() logger.DisableColors()
logger.Info("Non-colored message") logger.Info("Non-colored message")
time.Sleep(10 * time.Millisecond) // Wait for processing
content = buf.String() content = buf.String()
if strings.Contains(content, "\033[") { if strings.Contains(content, "\033[") {
@ -127,10 +148,9 @@ func TestDefaultLogger(t *testing.T) {
SetOutput(&buf) SetOutput(&buf)
Info("Test default logger") Info("Test default logger")
time.Sleep(10 * time.Millisecond) // Wait for processing
content := buf.String() content := buf.String()
if !strings.Contains(content, "[INF]") { if !strings.Contains(content, "INFO") {
t.Errorf("Default logger not working, got: %q", content) t.Errorf("Default logger not working, got: %q", content)
} }
} }
@ -140,23 +160,55 @@ func BenchmarkLogger(b *testing.B) {
logger := New(LevelInfo, false) logger := New(LevelInfo, false)
logger.SetOutput(&buf) logger.SetOutput(&buf)
// Set very high threshold to avoid rate limiting during benchmark
logger.maxLogsPerSec = int64(b.N + 1)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
logger.Info("Benchmark message %d", i) logger.Info("Benchmark message %d", i)
} }
} }
func BenchmarkLoggerWithRateLimit(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelDebug, false)
logger.SetOutput(&buf)
// Set threshold to allow about 10% of messages through
logger.maxLogsPerSec = int64(b.N / 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug("Benchmark message %d", i)
}
}
func BenchmarkLoggerParallel(b *testing.B) { func BenchmarkLoggerParallel(b *testing.B) {
var buf bytes.Buffer var buf bytes.Buffer
logger := New(LevelInfo, false) logger := New(LevelDebug, false)
logger.SetOutput(&buf) logger.SetOutput(&buf)
// Set very high threshold to avoid rate limiting during benchmark
logger.maxLogsPerSec = int64(b.N + 1)
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
logger.Info("Parallel benchmark message %d", i) logger.Debug("Parallel benchmark message %d", i)
i++ i++
} }
}) })
} }
func BenchmarkProductionLevels(b *testing.B) {
var buf bytes.Buffer
logger := New(LevelWarning, false) // Only log warnings and above
logger.SetOutput(&buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// This should be filtered out before any processing
logger.Debug("Debug message that won't be logged %d", i)
}
}

View File

@ -29,9 +29,8 @@ type LuaRunner struct {
bufferSize int // Size of the job queue buffer bufferSize int // Size of the job queue buffer
requireCache *RequireCache // Cache for required modules requireCache *RequireCache // Cache for required modules
requireCfg *RequireConfig // Configuration for require paths requireCfg *RequireConfig // Configuration for require paths
scriptDir string // Base directory for scripts moduleLoader luajit.GoFunction // Keep reference to prevent GC
libDirs []string // Additional library directories sandbox *Sandbox // The sandbox environment
loaderFunc luajit.GoFunction // Keep reference to prevent GC
} }
// NewRunner creates a new LuaRunner // NewRunner creates a new LuaRunner
@ -43,6 +42,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
requireCfg: &RequireConfig{ requireCfg: &RequireConfig{
LibDirs: []string{}, LibDirs: []string{},
}, },
sandbox: NewSandbox(),
} }
// Apply options // Apply options
@ -63,12 +63,11 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
// Create a shared config pointer that will be updated per request // Create a shared config pointer that will be updated per request
runner.requireCfg = &RequireConfig{ runner.requireCfg = &RequireConfig{
ScriptDir: runner.scriptDir, ScriptDir: runner.scriptDir(),
LibDirs: runner.libDirs, LibDirs: runner.libDirs(),
} }
// Set up require functionality ONCE // Set up require functionality
// Create and register the module loader function
moduleLoader := func(s *luajit.State) int { moduleLoader := func(s *luajit.State) int {
// Get module name // Get module name
modName := s.ToString(1) modName := s.ToString(1)
@ -99,7 +98,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
} }
// Store reference to prevent garbage collection // Store reference to prevent garbage collection
runner.loaderFunc = moduleLoader runner.moduleLoader = moduleLoader
// Register with Lua state // Register with Lua state
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil { if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
@ -108,8 +107,35 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
} }
// Set up the require mechanism // Set up the require mechanism
setupRequireScript := ` if err := setupRequireFunction(state); err != nil {
-- Create a secure require function for sandboxed environments state.Close()
return nil, ErrInitFailed
}
// Set up sandbox
if err := runner.sandbox.Setup(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Run init function if provided
if runner.initFunc != nil {
if err := runner.initFunc(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
}
// Start the event loop
runner.wg.Add(1)
go runner.processJobs()
return runner, nil
}
// setupRequireFunction adds the secure require implementation
func setupRequireFunction(state *luajit.State) error {
return state.DoString(`
function __setup_secure_require(env) function __setup_secure_require(env)
-- Replace env.require with our secure version -- Replace env.require with our secure version
env.require = function(modname) env.require = function(modname)
@ -146,32 +172,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
return env return env
end end
` `)
if err := state.DoString(setupRequireScript); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Set up sandbox
if err := runner.setupSandbox(); err != nil {
state.Close()
return nil, ErrInitFailed
}
// Run init function if provided
if runner.initFunc != nil {
if err := runner.initFunc(state); err != nil {
state.Close()
return nil, ErrInitFailed
}
}
// Start the event loop
runner.wg.Add(1)
go runner.eventLoop()
return runner, nil
} }
// RunnerOption defines a functional option for configuring the LuaRunner // RunnerOption defines a functional option for configuring the LuaRunner
@ -196,7 +197,6 @@ func WithInitFunc(initFunc StateInitFunc) RunnerOption {
// WithScriptDir sets the base directory for scripts // WithScriptDir sets the base directory for scripts
func WithScriptDir(dir string) RunnerOption { func WithScriptDir(dir string) RunnerOption {
return func(r *LuaRunner) { return func(r *LuaRunner) {
r.scriptDir = dir
r.requireCfg.ScriptDir = dir r.requireCfg.ScriptDir = dir
} }
} }
@ -204,103 +204,31 @@ func WithScriptDir(dir string) RunnerOption {
// WithLibDirs sets additional library directories // WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption { func WithLibDirs(dirs ...string) RunnerOption {
return func(r *LuaRunner) { return func(r *LuaRunner) {
r.libDirs = dirs
r.requireCfg.LibDirs = dirs r.requireCfg.LibDirs = dirs
} }
} }
// setupSandbox initializes the sandbox environment // scriptDir returns the current script directory
func (r *LuaRunner) setupSandbox() error { func (r *LuaRunner) scriptDir() string {
// This is the Lua script that creates our sandbox function if r.requireCfg != nil {
setupScript := ` return r.requireCfg.ScriptDir
-- Create a function to run code in a sandbox environment
function __create_sandbox()
-- Create new environment table
local env = {}
-- Add standard library modules (can be restricted as needed)
env.string = string
env.table = table
env.math = math
env.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
} }
env.tonumber = tonumber return ""
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Set up the standard library package table
env.package = {
loaded = {} -- Table to store loaded modules
}
-- Explicitly expose the module loader function
env.__go_load_module = __go_load_module
-- Set up secure require function
env = __setup_secure_require(env)
-- Create metatable to restrict access to _G
local mt = {
__index = function(t, k)
-- First check in env table
local v = rawget(env, k)
if v ~= nil then return v end
-- If not found, check for registered modules/functions
local moduleValue = _G[k]
if type(moduleValue) == "table" or
type(moduleValue) == "function" then
return moduleValue
end
return nil
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end
}
setmetatable(env, mt)
return env
end
-- Create function to execute code with a sandbox
function __run_sandboxed(f, ctx)
local env = __create_sandbox()
-- Add context to the environment if provided
if ctx then
env.ctx = ctx
end
-- Set the environment and run the function
setfenv(f, env)
return f()
end
`
return r.state.DoString(setupScript)
} }
// eventLoop processes jobs from the queue // libDirs returns the current library directories
func (r *LuaRunner) eventLoop() { func (r *LuaRunner) libDirs() []string {
if r.requireCfg != nil {
return r.requireCfg.LibDirs
}
return nil
}
// processJobs handles the job queue
func (r *LuaRunner) processJobs() {
defer r.wg.Done() defer r.wg.Done()
defer r.state.Close() defer r.state.Close()
// Process jobs until closure
for job := range r.jobQueue { for job := range r.jobQueue {
// Execute the job and send result // Execute the job and send result
result := r.executeJob(job) result := r.executeJob(job)
@ -329,55 +257,14 @@ func (r *LuaRunner) executeJob(j job) JobResult {
} }
} }
// Set up context if provided // Convert context for sandbox
var ctx map[string]any
if j.Context != nil { if j.Context != nil {
// Push context table ctx = j.Context.Values
r.state.NewTable()
// Add values to context table
for key, value := range j.Context.Values {
// Push key
r.state.PushString(key)
// Push value
if err := r.state.PushValue(value); err != nil {
return JobResult{nil, err}
} }
// Set table[key] = value // Execute in sandbox
r.state.SetTable(-3) value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx)
}
} else {
// Push nil if no context
r.state.PushNil()
}
// Load bytecode
if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil {
r.state.Pop(1) // Pop context
return JobResult{nil, err}
}
// Get the sandbox runner function
r.state.GetGlobal("__run_sandboxed")
// Push loaded function and context as arguments
r.state.PushCopy(-2) // Copy the loaded function
r.state.PushCopy(-4) // Copy the context table or nil
// Remove the original function and context
r.state.Remove(-5) // Remove original context
r.state.Remove(-4) // Remove original function
// Call the sandbox runner with 2 args (function and context), expecting 1 result
if err := r.state.Call(2, 1); err != nil {
return JobResult{nil, err}
}
// Get result
value, err := r.state.ToValue(-1)
r.state.Pop(1) // Pop result
return JobResult{value, err} return JobResult{value, err}
} }
@ -442,3 +329,8 @@ func (r *LuaRunner) Close() error {
func (r *LuaRunner) ClearRequireCache() { func (r *LuaRunner) ClearRequireCache() {
r.requireCache = NewRequireCache() r.requireCache = NewRequireCache()
} }
// AddModule adds a module to the sandbox environment
func (r *LuaRunner) AddModule(name string, module any) {
r.sandbox.AddModule(name, module)
}

228
core/runner/sandbox.go Normal file
View File

@ -0,0 +1,228 @@
package runner
import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Sandbox manages a sandboxed Lua environment
type Sandbox struct {
modules map[string]any // Custom modules for environment
initialized bool // Whether base environment is initialized
}
// NewSandbox creates a new sandbox
func NewSandbox() *Sandbox {
return &Sandbox{
modules: make(map[string]any),
initialized: false,
}
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) {
s.modules[name] = module
}
// Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State) error {
// Register modules
if err := s.registerModules(state); err != nil {
return err
}
// Setup the sandbox creation logic with base environment reuse
return state.DoString(`
-- Create the base environment once (static parts)
local __base_env = nil
-- Create function to initialize base environment
function __init_base_env()
if __base_env then return end
local env = {}
-- Add standard library modules (restricted)
env.string = string
env.table = table
env.math = math
env.os = {
time = os.time,
date = os.date,
difftime = os.difftime,
clock = os.clock
}
env.tonumber = tonumber
env.tostring = tostring
env.type = type
env.pairs = pairs
env.ipairs = ipairs
env.next = next
env.select = select
env.unpack = unpack
env.pcall = pcall
env.xpcall = xpcall
env.error = error
env.assert = assert
-- Add module loader
env.__go_load_module = __go_load_module
-- Add custom modules from sandbox registry
if __sandbox_modules then
for name, module in pairs(__sandbox_modules) do
env[name] = module
end
end
-- Copy custom global functions
for k, v in pairs(_G) do
if (type(v) == "function" or type(v) == "table") and
k ~= "__sandbox_modules" and
k ~= "__base_env" and
k ~= "__init_base_env" and
k ~= "__create_sandbox_env" and
k ~= "__run_sandboxed" and
k ~= "__setup_secure_require" and
k ~= "__go_load_module" and
k ~= "string" and k ~= "table" and k ~= "math" and
k ~= "os" and k ~= "io" and k ~= "debug" and
k ~= "package" and k ~= "bit" and k ~= "jit" and
k ~= "coroutine" and k ~= "_G" and k ~= "_VERSION" then
env[k] = v
end
end
__base_env = env
end
-- Create function that builds sandbox from base env
function __create_sandbox_env(ctx)
-- Initialize base env if needed
__init_base_env()
-- Create new environment using base as prototype
local env = {}
-- Copy from base environment
for k, v in pairs(__base_env) do
env[k] = v
end
-- Add isolated package.loaded table
env.package = {
loaded = {}
}
-- Add context if provided
if ctx then
env.ctx = ctx
end
-- Setup require function
env = __setup_secure_require(env)
-- Create metatable for isolation
local mt = {
__index = function(t, k)
return rawget(env, k)
end,
__newindex = function(t, k, v)
rawset(env, k, v)
end
}
setmetatable(env, mt)
return env
end
-- Function to run code in sandbox
function __run_sandboxed(bytecode, ctx)
-- Create environment for this request
local env = __create_sandbox_env(ctx)
-- Set environment and execute
setfenv(bytecode, env)
return bytecode()
end
`)
}
// registerModules registers custom modules in the Lua state
func (s *Sandbox) registerModules(state *luajit.State) error {
// Create or get module registry table
state.GetGlobal("__sandbox_modules")
if state.IsNil(-1) {
// Table doesn't exist, create it
state.Pop(1)
state.NewTable()
state.SetGlobal("__sandbox_modules")
state.GetGlobal("__sandbox_modules")
}
// Add modules to registry
for name, module := range s.modules {
state.PushString(name)
if err := state.PushValue(module); err != nil {
state.Pop(2)
return err
}
state.SetTable(-3)
}
// Pop module table
state.Pop(1)
return nil
}
// Execute runs bytecode in the sandbox
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
// Update modules if needed
if !s.initialized {
if err := s.registerModules(state); err != nil {
return nil, err
}
s.initialized = true
}
// Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil {
return nil, err
}
// Create context table if provided
if len(ctx) > 0 {
state.NewTable()
for k, v := range ctx {
state.PushString(k)
if err := state.PushValue(v); err != nil {
state.Pop(3)
return nil, err
}
state.SetTable(-3)
}
} else {
state.PushNil() // No context
}
// Get sandbox function
state.GetGlobal("__run_sandboxed")
// Setup call with correct argument order
state.PushCopy(-3) // Copy bytecode function
state.PushCopy(-3) // Copy context
// Clean up stack
state.Remove(-5) // Remove original bytecode
state.Remove(-4) // Remove original context
// Call sandbox function
if err := state.Call(2, 1); err != nil {
return nil, err
}
// Get result
result, err := state.ToValue(-1)
state.Pop(1)
return result, err
}

View File

@ -85,7 +85,7 @@ func main() {
log.Fatal("Router initialization failed: %v", err) log.Fatal("Router initialization failed: %v", err)
} }
if cfg.GetBool("watchers", false) { if cfg.GetBool("watchers", true) {
// Set up file watchers for automatic reloading // Set up file watchers for automatic reloading
luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log) luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
if err != nil { if err != nil {
@ -104,10 +104,10 @@ func main() {
} }
} }
// Get buffer size from config or use default (used to be worker pool size) // Get buffer size from config or use default
bufferSize := cfg.GetInt("buffer_size", 20) bufferSize := cfg.GetInt("buffer_size", 20)
// Initialize Lua runner (replacing worker pool) // Initialize Lua runner
runner, err := runner.NewRunner( runner, err := runner.NewRunner(
runner.WithBufferSize(bufferSize), runner.WithBufferSize(bufferSize),
runner.WithLibDirs("./libs"), runner.WithLibDirs("./libs"),