Compare commits
No commits in common. "3e26f348b4fb7a8f1315b90d1994fde01a005ea5" and "fc57a03a8e361338c7b263f5e5c191ef33468370" have entirely different histories.
3e26f348b4
...
fc57a03a8e
|
@ -28,15 +28,29 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
||||||
staticRouter: staticRouter,
|
staticRouter: staticRouter,
|
||||||
luaRunner: runner,
|
luaRunner: runner,
|
||||||
logger: log,
|
logger: log,
|
||||||
httpServer: &http.Server{},
|
httpServer: &http.Server{
|
||||||
|
// Connection timeouts
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
|
||||||
|
// Improved connection handling
|
||||||
|
MaxHeaderBytes: 1 << 16, // 64KB
|
||||||
|
},
|
||||||
}
|
}
|
||||||
server.httpServer.Handler = server
|
server.httpServer.Handler = server
|
||||||
|
|
||||||
// Set TCP keep-alive for connections
|
// Set TCP keep-alive settings for the underlying TCP connections
|
||||||
server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
server.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||||
if state == http.StateNew {
|
if state == http.StateNew {
|
||||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||||
|
// Enable TCP keep-alive
|
||||||
tcpConn.SetKeepAlive(true)
|
tcpConn.SetKeepAlive(true)
|
||||||
|
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||||
|
|
||||||
|
// Set TCP_NODELAY (disable Nagle's algorithm)
|
||||||
|
tcpConn.SetNoDelay(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,17 +124,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
||||||
ctx.Set("method", r.Method)
|
ctx.Set("method", r.Method)
|
||||||
ctx.Set("path", r.URL.Path)
|
ctx.Set("path", r.URL.Path)
|
||||||
ctx.Set("host", r.Host)
|
ctx.Set("host", r.Host)
|
||||||
|
ctx.Set("headers", makeHeaderMap(r.Header))
|
||||||
// Inline the header conversion (previously makeHeaderMap)
|
|
||||||
headerMap := make(map[string]any, len(r.Header))
|
|
||||||
for name, values := range r.Header {
|
|
||||||
if len(values) == 1 {
|
|
||||||
headerMap[name] = values[0]
|
|
||||||
} else {
|
|
||||||
headerMap[name] = values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx.Set("headers", headerMap)
|
|
||||||
|
|
||||||
// Add URL parameters
|
// Add URL parameters
|
||||||
if params.Count > 0 {
|
if params.Count > 0 {
|
||||||
|
@ -131,11 +135,12 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
||||||
ctx.Set("params", paramMap)
|
ctx.Set("params", paramMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query parameters will be parsed lazily via metatable in Lua
|
// Add query parameters
|
||||||
// Instead of parsing for every request, we'll pass the raw URL
|
if queryParams := QueryToLua(r); queryParams != nil {
|
||||||
ctx.Set("rawQuery", r.URL.RawQuery)
|
ctx.Set("query", queryParams)
|
||||||
|
}
|
||||||
|
|
||||||
// Add form data for POST/PUT/PATCH only when needed
|
// Add form data
|
||||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||||
if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
|
if formData, err := ParseForm(r); err == nil && len(formData) > 0 {
|
||||||
ctx.Set("form", formData)
|
ctx.Set("form", formData)
|
||||||
|
@ -153,11 +158,18 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
||||||
writeResponse(w, result, s.logger)
|
writeResponse(w, result, s.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Content types for responses
|
// makeHeaderMap converts HTTP headers to a map
|
||||||
const (
|
func makeHeaderMap(header http.Header) map[string]any {
|
||||||
contentTypeJSON = "application/json"
|
result := make(map[string]any, len(header))
|
||||||
contentTypePlain = "text/plain"
|
for name, values := range header {
|
||||||
)
|
if len(values) == 1 {
|
||||||
|
result[name] = values[0]
|
||||||
|
} else {
|
||||||
|
result[name] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// writeResponse writes the Lua result to the HTTP response
|
// writeResponse writes the Lua result to the HTTP response
|
||||||
func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
|
@ -169,12 +181,12 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
switch res := result.(type) {
|
switch res := result.(type) {
|
||||||
case string:
|
case string:
|
||||||
// String result
|
// String result
|
||||||
w.Header().Set("Content-Type", contentTypePlain)
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.Write([]byte(res))
|
w.Write([]byte(res))
|
||||||
|
|
||||||
case map[string]any, []any:
|
case map[string]any:
|
||||||
// Table or array result - convert to JSON
|
// Table result - convert to JSON
|
||||||
w.Header().Set("Content-Type", contentTypeJSON)
|
w.Header().Set("Content-Type", "application/json")
|
||||||
data, err := json.Marshal(res)
|
data, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Failed to marshal response: %v", err)
|
log.Error("Failed to marshal response: %v", err)
|
||||||
|
@ -185,7 +197,7 @@ func writeResponse(w http.ResponseWriter, result any, log *logger.Logger) {
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Other result types - convert to JSON
|
// Other result types - convert to JSON
|
||||||
w.Header().Set("Content-Type", contentTypeJSON)
|
w.Header().Set("Content-Type", "application/json")
|
||||||
data, err := json.Marshal(result)
|
data, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Failed to marshal response: %v", err)
|
log.Error("Failed to marshal response: %v", err)
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,12 +47,6 @@ var levelProps = map[int]struct {
|
||||||
// Time format for log messages
|
// Time format for log messages
|
||||||
const timeFormat = "15:04:05"
|
const timeFormat = "15:04:05"
|
||||||
|
|
||||||
// Default rate limiting settings
|
|
||||||
const (
|
|
||||||
defaultMaxLogs = 1000 // Max logs per second before rate limiting
|
|
||||||
defaultRateLimitTime = 10 * time.Second // How long to pause during rate limiting
|
|
||||||
)
|
|
||||||
|
|
||||||
// Logger handles logging operations
|
// Logger handles logging operations
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
|
@ -61,39 +54,16 @@ type Logger struct {
|
||||||
useColors bool
|
useColors bool
|
||||||
timeFormat string
|
timeFormat string
|
||||||
mu sync.Mutex // Mutex for thread-safe writing
|
mu sync.Mutex // Mutex for thread-safe writing
|
||||||
|
|
||||||
// Simple rate limiting
|
|
||||||
logCount atomic.Int64 // Number of logs in current window
|
|
||||||
logCountStart atomic.Int64 // Start time of current counting window
|
|
||||||
rateLimited atomic.Bool // Whether we're currently rate limited
|
|
||||||
rateLimitUntil atomic.Int64 // Timestamp when rate limiting ends
|
|
||||||
maxLogsPerSec int64 // Maximum logs per second before limiting
|
|
||||||
limitDuration time.Duration // How long to pause logging when rate limited
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new logger
|
// New creates a new logger
|
||||||
func New(minLevel int, useColors bool) *Logger {
|
func New(minLevel int, useColors bool) *Logger {
|
||||||
logger := &Logger{
|
return &Logger{
|
||||||
writer: os.Stdout,
|
writer: os.Stdout,
|
||||||
level: minLevel,
|
level: minLevel,
|
||||||
useColors: useColors,
|
useColors: useColors,
|
||||||
timeFormat: timeFormat,
|
timeFormat: timeFormat,
|
||||||
maxLogsPerSec: defaultMaxLogs,
|
|
||||||
limitDuration: defaultRateLimitTime,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize counters
|
|
||||||
logger.resetCounters()
|
|
||||||
|
|
||||||
return logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetCounters resets the rate limiting counters
|
|
||||||
func (l *Logger) resetCounters() {
|
|
||||||
l.logCount.Store(0)
|
|
||||||
l.logCountStart.Store(time.Now().Unix())
|
|
||||||
l.rateLimited.Store(false)
|
|
||||||
l.rateLimitUntil.Store(0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOutput changes the output destination
|
// SetOutput changes the output destination
|
||||||
|
@ -149,85 +119,29 @@ func (l *Logger) writeMessage(level int, message string, rawMode bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Synchronously write the log message
|
// Asynchronously write the log message
|
||||||
|
go func(w io.Writer, data string) {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
_, _ = fmt.Fprint(l.writer, logLine)
|
_, _ = fmt.Fprint(w, data)
|
||||||
|
l.mu.Unlock()
|
||||||
|
}(l.writer, logLine)
|
||||||
|
|
||||||
// For fatal errors, ensure we sync immediately
|
// For fatal errors, ensure we sync immediately in the current goroutine
|
||||||
if level == LevelFatal {
|
if level == LevelFatal {
|
||||||
|
l.mu.Lock()
|
||||||
if f, ok := l.writer.(*os.File); ok {
|
if f, ok := l.writer.(*os.File); ok {
|
||||||
_ = f.Sync()
|
_ = f.Sync()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
l.mu.Unlock()
|
l.mu.Unlock()
|
||||||
}
|
|
||||||
|
|
||||||
// checkRateLimit checks if we should rate limit logging
|
|
||||||
// Returns true if the message should be logged, false if it should be dropped
|
|
||||||
func (l *Logger) checkRateLimit(level int) bool {
|
|
||||||
// High priority messages are never rate limited
|
|
||||||
if level >= LevelWarning {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we're currently in a rate-limited period
|
|
||||||
if l.rateLimited.Load() {
|
|
||||||
now := time.Now().Unix()
|
|
||||||
limitUntil := l.rateLimitUntil.Load()
|
|
||||||
|
|
||||||
if now >= limitUntil {
|
|
||||||
// Rate limiting period is over
|
|
||||||
l.rateLimited.Store(false)
|
|
||||||
l.resetCounters()
|
|
||||||
} else {
|
|
||||||
// Still in rate limiting period, drop the message
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not rate limited, check if we should start rate limiting
|
|
||||||
count := l.logCount.Add(1)
|
|
||||||
|
|
||||||
// Check if we need to reset the counter for a new second
|
|
||||||
now := time.Now().Unix()
|
|
||||||
start := l.logCountStart.Load()
|
|
||||||
if now > start {
|
|
||||||
// New second, reset counter
|
|
||||||
l.logCount.Store(1) // Count this message
|
|
||||||
l.logCountStart.Store(now)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we've exceeded our threshold
|
|
||||||
if count > l.maxLogsPerSec {
|
|
||||||
// Start rate limiting
|
|
||||||
l.rateLimited.Store(true)
|
|
||||||
l.rateLimitUntil.Store(now + int64(l.limitDuration.Seconds()))
|
|
||||||
|
|
||||||
// Log a warning about rate limiting
|
|
||||||
l.writeMessage(LevelServer,
|
|
||||||
fmt.Sprintf("Rate limiting logger temporarily due to high demand (%d logs/sec exceeded)", count),
|
|
||||||
false)
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// log handles the core logging logic with level filtering
|
// log handles the core logging logic with level filtering
|
||||||
func (l *Logger) log(level int, format string, args ...any) {
|
func (l *Logger) log(level int, format string, args ...any) {
|
||||||
// First check normal level filtering
|
|
||||||
if level < l.level {
|
if level < l.level {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check rate limiting - always log high priority messages
|
|
||||||
if !l.checkRateLimit(level) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format message
|
|
||||||
var message string
|
var message string
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
message = fmt.Sprintf(format, args...)
|
message = fmt.Sprintf(format, args...)
|
||||||
|
@ -250,11 +164,6 @@ func (l *Logger) LogRaw(format string, args ...any) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check rate limiting
|
|
||||||
if !l.checkRateLimit(LevelInfo) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var message string
|
var message string
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
message = fmt.Sprintf(format, args...)
|
message = fmt.Sprintf(format, args...)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -15,31 +16,36 @@ func TestLoggerLevels(t *testing.T) {
|
||||||
|
|
||||||
// Debug should be below threshold
|
// Debug should be below threshold
|
||||||
logger.Debug("This should not appear")
|
logger.Debug("This should not appear")
|
||||||
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
if buf.Len() > 0 {
|
if buf.Len() > 0 {
|
||||||
t.Error("Debug message appeared when it should be filtered")
|
t.Error("Debug message appeared when it should be filtered")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info and above should appear
|
// Info and above should appear
|
||||||
logger.Info("Info message")
|
logger.Info("Info message")
|
||||||
if !strings.Contains(buf.String(), "INFO") {
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
if !strings.Contains(buf.String(), "[INF]") {
|
||||||
t.Errorf("Info message not logged, got: %q", buf.String())
|
t.Errorf("Info message not logged, got: %q", buf.String())
|
||||||
}
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
logger.Warning("Warning message")
|
logger.Warning("Warning message")
|
||||||
if !strings.Contains(buf.String(), "WARN") {
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
if !strings.Contains(buf.String(), "[WRN]") {
|
||||||
t.Errorf("Warning message not logged, got: %q", buf.String())
|
t.Errorf("Warning message not logged, got: %q", buf.String())
|
||||||
}
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
logger.Error("Error message")
|
logger.Error("Error message")
|
||||||
if !strings.Contains(buf.String(), "ERROR") {
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
if !strings.Contains(buf.String(), "[ERR]") {
|
||||||
t.Errorf("Error message not logged, got: %q", buf.String())
|
t.Errorf("Error message not logged, got: %q", buf.String())
|
||||||
}
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
|
|
||||||
// Test format strings
|
// Test format strings
|
||||||
logger.Info("Count: %d", 42)
|
logger.Info("Count: %d", 42)
|
||||||
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
if !strings.Contains(buf.String(), "Count: 42") {
|
if !strings.Contains(buf.String(), "Count: 42") {
|
||||||
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
t.Errorf("Formatted message not logged correctly, got: %q", buf.String())
|
||||||
}
|
}
|
||||||
|
@ -54,53 +60,17 @@ func TestLoggerLevels(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Error("Error should appear")
|
logger.Error("Error should appear")
|
||||||
if !strings.Contains(buf.String(), "ERROR") {
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
if !strings.Contains(buf.String(), "[ERR]") {
|
||||||
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
t.Errorf("Error message not logged after level change, got: %q", buf.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoggerRateLimit(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Override max logs per second to something small for testing
|
|
||||||
logger.maxLogsPerSec = 5
|
|
||||||
logger.limitDuration = 1 * time.Second
|
|
||||||
|
|
||||||
// Send debug messages (should get limited)
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
logger.Debug("Debug message %d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error messages should always go through
|
|
||||||
logger.Error("Error message should appear")
|
|
||||||
|
|
||||||
content := buf.String()
|
|
||||||
|
|
||||||
// We should see some debug messages, then a warning about rate limiting,
|
|
||||||
// and finally the error message
|
|
||||||
if !strings.Contains(content, "Debug message 0") {
|
|
||||||
t.Error("First debug message should appear")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(content, "Rate limiting logger") {
|
|
||||||
t.Error("Rate limiting message should appear")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(content, "ERROR") {
|
|
||||||
t.Error("Error message should always appear despite rate limiting")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerConcurrency(t *testing.T) {
|
func TestLoggerConcurrency(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
logger := New(LevelDebug, false)
|
logger := New(LevelDebug, false)
|
||||||
logger.SetOutput(&buf)
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
// Increase log threshold for this test
|
|
||||||
logger.maxLogsPerSec = 1000
|
|
||||||
|
|
||||||
// Log a bunch of messages concurrently
|
// Log a bunch of messages concurrently
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
|
@ -112,10 +82,17 @@ func TestLoggerConcurrency(t *testing.T) {
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Check logs were processed
|
// Wait for processing
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check all messages were logged
|
||||||
content := buf.String()
|
content := buf.String()
|
||||||
if !strings.Contains(content, "Concurrent message") {
|
for i := 0; i < 100; i++ {
|
||||||
t.Error("Concurrent messages should appear")
|
msg := "Concurrent message " + strconv.Itoa(i)
|
||||||
|
if !strings.Contains(content, msg) && !strings.Contains(content, "Concurrent message") {
|
||||||
|
t.Errorf("Missing concurrent messages")
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,6 +103,7 @@ func TestLoggerColors(t *testing.T) {
|
||||||
|
|
||||||
// Test with color
|
// Test with color
|
||||||
logger.Info("Colored message")
|
logger.Info("Colored message")
|
||||||
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
|
||||||
content := buf.String()
|
content := buf.String()
|
||||||
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
t.Logf("Colored output: %q", content) // Print actual output for diagnosis
|
||||||
|
@ -136,6 +114,7 @@ func TestLoggerColors(t *testing.T) {
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
logger.DisableColors()
|
logger.DisableColors()
|
||||||
logger.Info("Non-colored message")
|
logger.Info("Non-colored message")
|
||||||
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
|
||||||
content = buf.String()
|
content = buf.String()
|
||||||
if strings.Contains(content, "\033[") {
|
if strings.Contains(content, "\033[") {
|
||||||
|
@ -148,9 +127,10 @@ func TestDefaultLogger(t *testing.T) {
|
||||||
SetOutput(&buf)
|
SetOutput(&buf)
|
||||||
|
|
||||||
Info("Test default logger")
|
Info("Test default logger")
|
||||||
|
time.Sleep(10 * time.Millisecond) // Wait for processing
|
||||||
|
|
||||||
content := buf.String()
|
content := buf.String()
|
||||||
if !strings.Contains(content, "INFO") {
|
if !strings.Contains(content, "[INF]") {
|
||||||
t.Errorf("Default logger not working, got: %q", content)
|
t.Errorf("Default logger not working, got: %q", content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,55 +140,23 @@ func BenchmarkLogger(b *testing.B) {
|
||||||
logger := New(LevelInfo, false)
|
logger := New(LevelInfo, false)
|
||||||
logger.SetOutput(&buf)
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
// Set very high threshold to avoid rate limiting during benchmark
|
|
||||||
logger.maxLogsPerSec = int64(b.N + 1)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
logger.Info("Benchmark message %d", i)
|
logger.Info("Benchmark message %d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLoggerWithRateLimit(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelDebug, false)
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
// Set threshold to allow about 10% of messages through
|
|
||||||
logger.maxLogsPerSec = int64(b.N / 10)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.Debug("Benchmark message %d", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLoggerParallel(b *testing.B) {
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
logger := New(LevelDebug, false)
|
logger := New(LevelInfo, false)
|
||||||
logger.SetOutput(&buf)
|
logger.SetOutput(&buf)
|
||||||
|
|
||||||
// Set very high threshold to avoid rate limiting during benchmark
|
|
||||||
logger.maxLogsPerSec = int64(b.N + 1)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
i := 0
|
i := 0
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
logger.Debug("Parallel benchmark message %d", i)
|
logger.Info("Parallel benchmark message %d", i)
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkProductionLevels(b *testing.B) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := New(LevelWarning, false) // Only log warnings and above
|
|
||||||
logger.SetOutput(&buf)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// This should be filtered out before any processing
|
|
||||||
logger.Debug("Debug message that won't be logged %d", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -29,8 +29,9 @@ type LuaRunner struct {
|
||||||
bufferSize int // Size of the job queue buffer
|
bufferSize int // Size of the job queue buffer
|
||||||
requireCache *RequireCache // Cache for required modules
|
requireCache *RequireCache // Cache for required modules
|
||||||
requireCfg *RequireConfig // Configuration for require paths
|
requireCfg *RequireConfig // Configuration for require paths
|
||||||
moduleLoader luajit.GoFunction // Keep reference to prevent GC
|
scriptDir string // Base directory for scripts
|
||||||
sandbox *Sandbox // The sandbox environment
|
libDirs []string // Additional library directories
|
||||||
|
loaderFunc luajit.GoFunction // Keep reference to prevent GC
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRunner creates a new LuaRunner
|
// NewRunner creates a new LuaRunner
|
||||||
|
@ -42,7 +43,6 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
requireCfg: &RequireConfig{
|
requireCfg: &RequireConfig{
|
||||||
LibDirs: []string{},
|
LibDirs: []string{},
|
||||||
},
|
},
|
||||||
sandbox: NewSandbox(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options
|
// Apply options
|
||||||
|
@ -63,11 +63,12 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
|
|
||||||
// Create a shared config pointer that will be updated per request
|
// Create a shared config pointer that will be updated per request
|
||||||
runner.requireCfg = &RequireConfig{
|
runner.requireCfg = &RequireConfig{
|
||||||
ScriptDir: runner.scriptDir(),
|
ScriptDir: runner.scriptDir,
|
||||||
LibDirs: runner.libDirs(),
|
LibDirs: runner.libDirs,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up require functionality
|
// Set up require functionality ONCE
|
||||||
|
// Create and register the module loader function
|
||||||
moduleLoader := func(s *luajit.State) int {
|
moduleLoader := func(s *luajit.State) int {
|
||||||
// Get module name
|
// Get module name
|
||||||
modName := s.ToString(1)
|
modName := s.ToString(1)
|
||||||
|
@ -98,7 +99,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store reference to prevent garbage collection
|
// Store reference to prevent garbage collection
|
||||||
runner.moduleLoader = moduleLoader
|
runner.loaderFunc = moduleLoader
|
||||||
|
|
||||||
// Register with Lua state
|
// Register with Lua state
|
||||||
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
|
if err := state.RegisterGoFunction("__go_load_module", moduleLoader); err != nil {
|
||||||
|
@ -107,35 +108,8 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the require mechanism
|
// Set up the require mechanism
|
||||||
if err := setupRequireFunction(state); err != nil {
|
setupRequireScript := `
|
||||||
state.Close()
|
-- Create a secure require function for sandboxed environments
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up sandbox
|
|
||||||
if err := runner.sandbox.Setup(state); err != nil {
|
|
||||||
state.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run init function if provided
|
|
||||||
if runner.initFunc != nil {
|
|
||||||
if err := runner.initFunc(state); err != nil {
|
|
||||||
state.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the event loop
|
|
||||||
runner.wg.Add(1)
|
|
||||||
go runner.processJobs()
|
|
||||||
|
|
||||||
return runner, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupRequireFunction adds the secure require implementation
|
|
||||||
func setupRequireFunction(state *luajit.State) error {
|
|
||||||
return state.DoString(`
|
|
||||||
function __setup_secure_require(env)
|
function __setup_secure_require(env)
|
||||||
-- Replace env.require with our secure version
|
-- Replace env.require with our secure version
|
||||||
env.require = function(modname)
|
env.require = function(modname)
|
||||||
|
@ -172,7 +146,32 @@ func setupRequireFunction(state *luajit.State) error {
|
||||||
|
|
||||||
return env
|
return env
|
||||||
end
|
end
|
||||||
`)
|
`
|
||||||
|
|
||||||
|
if err := state.DoString(setupRequireScript); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up sandbox
|
||||||
|
if err := runner.setupSandbox(); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run init function if provided
|
||||||
|
if runner.initFunc != nil {
|
||||||
|
if err := runner.initFunc(state); err != nil {
|
||||||
|
state.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the event loop
|
||||||
|
runner.wg.Add(1)
|
||||||
|
go runner.eventLoop()
|
||||||
|
|
||||||
|
return runner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunnerOption defines a functional option for configuring the LuaRunner
|
// RunnerOption defines a functional option for configuring the LuaRunner
|
||||||
|
@ -197,6 +196,7 @@ func WithInitFunc(initFunc StateInitFunc) RunnerOption {
|
||||||
// WithScriptDir sets the base directory for scripts
|
// WithScriptDir sets the base directory for scripts
|
||||||
func WithScriptDir(dir string) RunnerOption {
|
func WithScriptDir(dir string) RunnerOption {
|
||||||
return func(r *LuaRunner) {
|
return func(r *LuaRunner) {
|
||||||
|
r.scriptDir = dir
|
||||||
r.requireCfg.ScriptDir = dir
|
r.requireCfg.ScriptDir = dir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -204,31 +204,103 @@ func WithScriptDir(dir string) RunnerOption {
|
||||||
// WithLibDirs sets additional library directories
|
// WithLibDirs sets additional library directories
|
||||||
func WithLibDirs(dirs ...string) RunnerOption {
|
func WithLibDirs(dirs ...string) RunnerOption {
|
||||||
return func(r *LuaRunner) {
|
return func(r *LuaRunner) {
|
||||||
|
r.libDirs = dirs
|
||||||
r.requireCfg.LibDirs = dirs
|
r.requireCfg.LibDirs = dirs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// scriptDir returns the current script directory
|
// setupSandbox initializes the sandbox environment
|
||||||
func (r *LuaRunner) scriptDir() string {
|
func (r *LuaRunner) setupSandbox() error {
|
||||||
if r.requireCfg != nil {
|
// This is the Lua script that creates our sandbox function
|
||||||
return r.requireCfg.ScriptDir
|
setupScript := `
|
||||||
}
|
-- Create a function to run code in a sandbox environment
|
||||||
return ""
|
function __create_sandbox()
|
||||||
}
|
-- Create new environment table
|
||||||
|
local env = {}
|
||||||
|
|
||||||
// libDirs returns the current library directories
|
-- Add standard library modules (can be restricted as needed)
|
||||||
func (r *LuaRunner) libDirs() []string {
|
env.string = string
|
||||||
if r.requireCfg != nil {
|
env.table = table
|
||||||
return r.requireCfg.LibDirs
|
env.math = math
|
||||||
|
env.os = {
|
||||||
|
time = os.time,
|
||||||
|
date = os.date,
|
||||||
|
difftime = os.difftime,
|
||||||
|
clock = os.clock
|
||||||
}
|
}
|
||||||
|
env.tonumber = tonumber
|
||||||
|
env.tostring = tostring
|
||||||
|
env.type = type
|
||||||
|
env.pairs = pairs
|
||||||
|
env.ipairs = ipairs
|
||||||
|
env.next = next
|
||||||
|
env.select = select
|
||||||
|
env.unpack = unpack
|
||||||
|
env.pcall = pcall
|
||||||
|
env.xpcall = xpcall
|
||||||
|
env.error = error
|
||||||
|
env.assert = assert
|
||||||
|
|
||||||
|
-- Set up the standard library package table
|
||||||
|
env.package = {
|
||||||
|
loaded = {} -- Table to store loaded modules
|
||||||
|
}
|
||||||
|
|
||||||
|
-- Explicitly expose the module loader function
|
||||||
|
env.__go_load_module = __go_load_module
|
||||||
|
|
||||||
|
-- Set up secure require function
|
||||||
|
env = __setup_secure_require(env)
|
||||||
|
|
||||||
|
-- Create metatable to restrict access to _G
|
||||||
|
local mt = {
|
||||||
|
__index = function(t, k)
|
||||||
|
-- First check in env table
|
||||||
|
local v = rawget(env, k)
|
||||||
|
if v ~= nil then return v end
|
||||||
|
|
||||||
|
-- If not found, check for registered modules/functions
|
||||||
|
local moduleValue = _G[k]
|
||||||
|
if type(moduleValue) == "table" or
|
||||||
|
type(moduleValue) == "function" then
|
||||||
|
return moduleValue
|
||||||
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
end,
|
||||||
|
__newindex = function(t, k, v)
|
||||||
|
rawset(env, k, v)
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
setmetatable(env, mt)
|
||||||
|
return env
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Create function to execute code with a sandbox
|
||||||
|
function __run_sandboxed(f, ctx)
|
||||||
|
local env = __create_sandbox()
|
||||||
|
|
||||||
|
-- Add context to the environment if provided
|
||||||
|
if ctx then
|
||||||
|
env.ctx = ctx
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Set the environment and run the function
|
||||||
|
setfenv(f, env)
|
||||||
|
return f()
|
||||||
|
end
|
||||||
|
`
|
||||||
|
|
||||||
|
return r.state.DoString(setupScript)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processJobs handles the job queue
|
// eventLoop processes jobs from the queue
|
||||||
func (r *LuaRunner) processJobs() {
|
func (r *LuaRunner) eventLoop() {
|
||||||
defer r.wg.Done()
|
defer r.wg.Done()
|
||||||
defer r.state.Close()
|
defer r.state.Close()
|
||||||
|
|
||||||
|
// Process jobs until closure
|
||||||
for job := range r.jobQueue {
|
for job := range r.jobQueue {
|
||||||
// Execute the job and send result
|
// Execute the job and send result
|
||||||
result := r.executeJob(job)
|
result := r.executeJob(job)
|
||||||
|
@ -257,14 +329,55 @@ func (r *LuaRunner) executeJob(j job) JobResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert context for sandbox
|
// Set up context if provided
|
||||||
var ctx map[string]any
|
|
||||||
if j.Context != nil {
|
if j.Context != nil {
|
||||||
ctx = j.Context.Values
|
// Push context table
|
||||||
|
r.state.NewTable()
|
||||||
|
|
||||||
|
// Add values to context table
|
||||||
|
for key, value := range j.Context.Values {
|
||||||
|
// Push key
|
||||||
|
r.state.PushString(key)
|
||||||
|
|
||||||
|
// Push value
|
||||||
|
if err := r.state.PushValue(value); err != nil {
|
||||||
|
return JobResult{nil, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute in sandbox
|
// Set table[key] = value
|
||||||
value, err := r.sandbox.Execute(r.state, j.Bytecode, ctx)
|
r.state.SetTable(-3)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Push nil if no context
|
||||||
|
r.state.PushNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load bytecode
|
||||||
|
if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil {
|
||||||
|
r.state.Pop(1) // Pop context
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the sandbox runner function
|
||||||
|
r.state.GetGlobal("__run_sandboxed")
|
||||||
|
|
||||||
|
// Push loaded function and context as arguments
|
||||||
|
r.state.PushCopy(-2) // Copy the loaded function
|
||||||
|
r.state.PushCopy(-4) // Copy the context table or nil
|
||||||
|
|
||||||
|
// Remove the original function and context
|
||||||
|
r.state.Remove(-5) // Remove original context
|
||||||
|
r.state.Remove(-4) // Remove original function
|
||||||
|
|
||||||
|
// Call the sandbox runner with 2 args (function and context), expecting 1 result
|
||||||
|
if err := r.state.Call(2, 1); err != nil {
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get result
|
||||||
|
value, err := r.state.ToValue(-1)
|
||||||
|
r.state.Pop(1) // Pop result
|
||||||
|
|
||||||
return JobResult{value, err}
|
return JobResult{value, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,8 +442,3 @@ func (r *LuaRunner) Close() error {
|
||||||
func (r *LuaRunner) ClearRequireCache() {
|
func (r *LuaRunner) ClearRequireCache() {
|
||||||
r.requireCache = NewRequireCache()
|
r.requireCache = NewRequireCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
|
||||||
func (r *LuaRunner) AddModule(name string, module any) {
|
|
||||||
r.sandbox.AddModule(name, module)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,228 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sandbox manages a sandboxed Lua environment
|
|
||||||
type Sandbox struct {
|
|
||||||
modules map[string]any // Custom modules for environment
|
|
||||||
initialized bool // Whether base environment is initialized
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox
|
|
||||||
func NewSandbox() *Sandbox {
|
|
||||||
return &Sandbox{
|
|
||||||
modules: make(map[string]any),
|
|
||||||
initialized: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
|
||||||
func (s *Sandbox) AddModule(name string, module any) {
|
|
||||||
s.modules[name] = module
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup initializes the sandbox in a Lua state
|
|
||||||
func (s *Sandbox) Setup(state *luajit.State) error {
|
|
||||||
// Register modules
|
|
||||||
if err := s.registerModules(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup the sandbox creation logic with base environment reuse
|
|
||||||
return state.DoString(`
|
|
||||||
-- Create the base environment once (static parts)
|
|
||||||
local __base_env = nil
|
|
||||||
|
|
||||||
-- Create function to initialize base environment
|
|
||||||
function __init_base_env()
|
|
||||||
if __base_env then return end
|
|
||||||
|
|
||||||
local env = {}
|
|
||||||
|
|
||||||
-- Add standard library modules (restricted)
|
|
||||||
env.string = string
|
|
||||||
env.table = table
|
|
||||||
env.math = math
|
|
||||||
env.os = {
|
|
||||||
time = os.time,
|
|
||||||
date = os.date,
|
|
||||||
difftime = os.difftime,
|
|
||||||
clock = os.clock
|
|
||||||
}
|
|
||||||
env.tonumber = tonumber
|
|
||||||
env.tostring = tostring
|
|
||||||
env.type = type
|
|
||||||
env.pairs = pairs
|
|
||||||
env.ipairs = ipairs
|
|
||||||
env.next = next
|
|
||||||
env.select = select
|
|
||||||
env.unpack = unpack
|
|
||||||
env.pcall = pcall
|
|
||||||
env.xpcall = xpcall
|
|
||||||
env.error = error
|
|
||||||
env.assert = assert
|
|
||||||
|
|
||||||
-- Add module loader
|
|
||||||
env.__go_load_module = __go_load_module
|
|
||||||
|
|
||||||
-- Add custom modules from sandbox registry
|
|
||||||
if __sandbox_modules then
|
|
||||||
for name, module in pairs(__sandbox_modules) do
|
|
||||||
env[name] = module
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Copy custom global functions
|
|
||||||
for k, v in pairs(_G) do
|
|
||||||
if (type(v) == "function" or type(v) == "table") and
|
|
||||||
k ~= "__sandbox_modules" and
|
|
||||||
k ~= "__base_env" and
|
|
||||||
k ~= "__init_base_env" and
|
|
||||||
k ~= "__create_sandbox_env" and
|
|
||||||
k ~= "__run_sandboxed" and
|
|
||||||
k ~= "__setup_secure_require" and
|
|
||||||
k ~= "__go_load_module" and
|
|
||||||
k ~= "string" and k ~= "table" and k ~= "math" and
|
|
||||||
k ~= "os" and k ~= "io" and k ~= "debug" and
|
|
||||||
k ~= "package" and k ~= "bit" and k ~= "jit" and
|
|
||||||
k ~= "coroutine" and k ~= "_G" and k ~= "_VERSION" then
|
|
||||||
env[k] = v
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
__base_env = env
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create function that builds sandbox from base env
|
|
||||||
function __create_sandbox_env(ctx)
|
|
||||||
-- Initialize base env if needed
|
|
||||||
__init_base_env()
|
|
||||||
|
|
||||||
-- Create new environment using base as prototype
|
|
||||||
local env = {}
|
|
||||||
|
|
||||||
-- Copy from base environment
|
|
||||||
for k, v in pairs(__base_env) do
|
|
||||||
env[k] = v
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Add isolated package.loaded table
|
|
||||||
env.package = {
|
|
||||||
loaded = {}
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Add context if provided
|
|
||||||
if ctx then
|
|
||||||
env.ctx = ctx
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Setup require function
|
|
||||||
env = __setup_secure_require(env)
|
|
||||||
|
|
||||||
-- Create metatable for isolation
|
|
||||||
local mt = {
|
|
||||||
__index = function(t, k)
|
|
||||||
return rawget(env, k)
|
|
||||||
end,
|
|
||||||
__newindex = function(t, k, v)
|
|
||||||
rawset(env, k, v)
|
|
||||||
end
|
|
||||||
}
|
|
||||||
|
|
||||||
setmetatable(env, mt)
|
|
||||||
return env
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Function to run code in sandbox
|
|
||||||
function __run_sandboxed(bytecode, ctx)
|
|
||||||
-- Create environment for this request
|
|
||||||
local env = __create_sandbox_env(ctx)
|
|
||||||
|
|
||||||
-- Set environment and execute
|
|
||||||
setfenv(bytecode, env)
|
|
||||||
return bytecode()
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerModules registers custom modules in the Lua state
|
|
||||||
func (s *Sandbox) registerModules(state *luajit.State) error {
|
|
||||||
// Create or get module registry table
|
|
||||||
state.GetGlobal("__sandbox_modules")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
// Table doesn't exist, create it
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__sandbox_modules")
|
|
||||||
state.GetGlobal("__sandbox_modules")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add modules to registry
|
|
||||||
for name, module := range s.modules {
|
|
||||||
state.PushString(name)
|
|
||||||
if err := state.PushValue(module); err != nil {
|
|
||||||
state.Pop(2)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pop module table
|
|
||||||
state.Pop(1)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute runs bytecode in the sandbox
|
|
||||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
|
||||||
// Update modules if needed
|
|
||||||
if !s.initialized {
|
|
||||||
if err := s.registerModules(state); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.initialized = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load bytecode
|
|
||||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create context table if provided
|
|
||||||
if len(ctx) > 0 {
|
|
||||||
state.NewTable()
|
|
||||||
for k, v := range ctx {
|
|
||||||
state.PushString(k)
|
|
||||||
if err := state.PushValue(v); err != nil {
|
|
||||||
state.Pop(3)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
state.PushNil() // No context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get sandbox function
|
|
||||||
state.GetGlobal("__run_sandboxed")
|
|
||||||
|
|
||||||
// Setup call with correct argument order
|
|
||||||
state.PushCopy(-3) // Copy bytecode function
|
|
||||||
state.PushCopy(-3) // Copy context
|
|
||||||
|
|
||||||
// Clean up stack
|
|
||||||
state.Remove(-5) // Remove original bytecode
|
|
||||||
state.Remove(-4) // Remove original context
|
|
||||||
|
|
||||||
// Call sandbox function
|
|
||||||
if err := state.Call(2, 1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := state.ToValue(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
|
@ -85,7 +85,7 @@ func main() {
|
||||||
log.Fatal("Router initialization failed: %v", err)
|
log.Fatal("Router initialization failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.GetBool("watchers", true) {
|
if cfg.GetBool("watchers", false) {
|
||||||
// Set up file watchers for automatic reloading
|
// Set up file watchers for automatic reloading
|
||||||
luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
|
luaWatcher, err := watchers.WatchLuaRouter(luaRouter, routesDir, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -104,10 +104,10 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get buffer size from config or use default
|
// Get buffer size from config or use default (used to be worker pool size)
|
||||||
bufferSize := cfg.GetInt("buffer_size", 20)
|
bufferSize := cfg.GetInt("buffer_size", 20)
|
||||||
|
|
||||||
// Initialize Lua runner
|
// Initialize Lua runner (replacing worker pool)
|
||||||
runner, err := runner.NewRunner(
|
runner, err := runner.NewRunner(
|
||||||
runner.WithBufferSize(bufferSize),
|
runner.WithBufferSize(bufferSize),
|
||||||
runner.WithLibDirs("./libs"),
|
runner.WithLibDirs("./libs"),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user