Moonshark/runner/sandbox.go

299 lines
7.3 KiB
Go

package runner
import (
"fmt"
"sync"
"github.com/goccy/go-json"
"github.com/valyala/fasthttp"
"Moonshark/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Error is a string-backed error type. Values of this type compare equal by
// their text, so sentinels declared with it can be checked with == or
// errors.Is.
type Error string

// Error implements the error interface by returning the underlying text.
func (e Error) Error() string {
	msg := string(e)
	return msg
}

// Sentinel errors used by the runner package.
var (
	// ErrSandboxNotInitialized is the error value for a sandbox that has not
	// been initialized.
	ErrSandboxNotInitialized = Error("sandbox not initialized")
)
// Sandbox describes the environment that Lua scripts run in. It holds Go
// values registered through AddModule; Setup publishes each of them as a Lua
// global of the same name.
type Sandbox struct {
	modules map[string]any // global name -> value pushed into Lua by Setup
	debug   bool           // debug flag; NewSandbox leaves it false
	mu      sync.RWMutex   // guards modules
}

// NewSandbox returns a Sandbox with no registered modules and debug
// disabled.
func NewSandbox() *Sandbox {
	sb := &Sandbox{}
	// Small initial capacity; the map grows as modules are added.
	sb.modules = make(map[string]any, 8)
	return sb
}
// AddModule records module under name. Setup later pushes every recorded
// module into a Lua state as a global called name. Adding the same name
// twice replaces the earlier value. It is safe to call concurrently with
// other AddModule calls and with Setup.
func (s *Sandbox) AddModule(name string, module any) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.modules[name] = module
	logger.Debugf("Added module: %s", name)
}
// Setup prepares state for script execution in three steps:
//  1. loadSandboxIntoState installs the sandbox environment,
//  2. registerCoreFunctions installs the built-in Go functions,
//  3. every module added with AddModule is pushed and bound to a Lua global
//     of the same name.
//
// The first failure is logged and returned unchanged. When verbose is true,
// start and completion messages are also logged at debug level.
func (s *Sandbox) Setup(state *luajit.State, verbose bool) error {
	if verbose {
		logger.Debugf("Setting up sandbox...")
	}

	if err := loadSandboxIntoState(state, verbose); err != nil {
		logger.Errorf("Failed to load sandbox: %v", err)
		return err
	}

	if err := s.registerCoreFunctions(state); err != nil {
		logger.Errorf("Failed to register core functions: %v", err)
		return err
	}

	// Scoping the read lock to a closure guarantees it is released on every
	// return path without manual unlocks.
	publishModules := func() error {
		s.mu.RLock()
		defer s.mu.RUnlock()

		for modName, mod := range s.modules {
			logger.Debugf("Registering module: %s", modName)
			if pushErr := state.PushValue(mod); pushErr != nil {
				logger.Errorf("Failed to register module %s: %v", modName, pushErr)
				return pushErr
			}
			// Pops the pushed value and stores it as the global modName.
			state.SetGlobal(modName)
		}
		return nil
	}
	if err := publishModules(); err != nil {
		return err
	}

	if verbose {
		logger.Debugf("Sandbox setup complete")
	}
	return nil
}
// registerCoreFunctions installs the sandbox's built-in Go functions into
// state.
//
// First the HTTP, token and JSON helpers are registered under their Lua
// global names. Then each feature-area registrar runs in order: SQLite,
// filesystem, password, utility, crypto, environment. The first error aborts
// registration and is returned as-is.
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
	builtins := []struct {
		luaName string
		impl    func(*luajit.State) int
	}{
		{"__http_request", httpRequest},
		{"__generate_token", generateToken},
		{"__json_marshal", jsonMarshal},
		{"__json_unmarshal", jsonUnmarshal},
	}
	for _, b := range builtins {
		if err := state.RegisterGoFunction(b.luaName, b.impl); err != nil {
			return err
		}
	}

	registrars := []func(*luajit.State) error{
		RegisterSQLiteFunctions,
		RegisterFSFunctions,
		RegisterPasswordFunctions,
		RegisterUtilFunctions,
		RegisterCryptoFunctions,
		RegisterEnvFunctions,
	}
	for _, register := range registrars {
		if err := register(state); err != nil {
			return err
		}
	}
	return nil
}
// Execute runs precompiled Lua bytecode inside the sandbox.
//
// The chunk is loaded and handed, together with ctx.Values, to the Lua global
// __execute_script, which is called as __execute_script(chunk, context) with
// one expected result. If that result converts to a Go value it becomes the
// response body. Status, headers, cookies, metadata and session data are then
// read from the __http_response global.
//
// The Lua stack is cleared before returning, on success and on every error
// path alike, so a failed run cannot leave stale values behind for the next
// use of state. A nil ctx is rejected with an error instead of panicking.
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
	if ctx == nil {
		return nil, fmt.Errorf("failed to push context: nil context")
	}

	// Load bytecode; this pushes the compiled chunk. Stack: [chunk]
	if err := state.LoadBytecode(bytecode, "script"); err != nil {
		state.SetTop(0)
		return nil, fmt.Errorf("failed to load bytecode: %w", err)
	}

	state.GetGlobal("__execute_script") // Stack: [chunk, __execute_script]
	state.PushCopy(-2)                  // Stack: [chunk, __execute_script, chunk]

	if err := state.PushValue(ctx.Values); err != nil {
		// PushValue may fail part-way through building a value, so reset the
		// stack rather than assuming exactly three slots are present.
		state.SetTop(0)
		return nil, fmt.Errorf("failed to push context: %w", err)
	}
	// Stack: [chunk, __execute_script, chunk, context]

	// Call __execute_script(chunk, context), keeping one result.
	if err := state.Call(2, 1); err != nil {
		// Whether the binding leaves the Lua error object on the stack is not
		// visible here, so clear everything instead of popping a fixed count.
		state.SetTop(0)
		return nil, fmt.Errorf("script execution failed: %w", err)
	}
	// Stack: [chunk, result]

	response := NewResponse()
	// Best effort: a result that cannot be converted leaves Body unset.
	if result, err := state.ToValue(-1); err == nil {
		response.Body = result
	}
	state.SetTop(0) // clear stack

	extractHTTPResponseData(state, response)
	return response, nil
}
// extractHTTPResponseData copies response details from the Lua global
// __http_response into response.
//
// If __http_response is not a table, it is popped and response is left as
// is. Otherwise these fields are read from it:
//   - status: via GetFieldNumber with a fallback of 200
//   - headers: copied into response.Headers; from a map[string]any only
//     string values are kept
//   - cookies: every table element of the array is handed to extractCookie
//   - metadata: copied into response.Metadata when it is a map[string]any
//   - session: copied into response.SessionData for map[string]any,
//     map[string]string or map[string]int; any other shape is logged at
//     debug level and ignored
//
// __http_response is popped before returning.
//
// NOTE(review): the writes into response.Headers, Metadata and SessionData
// assume NewResponse initialised those maps — confirm, since a nil map would
// panic here.
func extractHTTPResponseData(state *luajit.State, response *Response) {
	state.GetGlobal("__http_response")
	if !state.IsTable(-1) {
		state.Pop(1)
		return
	}
	// Field getters take the table's stack index and a fallback value.
	response.Status = int(state.GetFieldNumber(-1, "status", 200))
	// GetFieldTable can yield different Go map types depending on the table's
	// contents, so both shapes are handled.
	if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
		switch headers := headerTable.(type) {
		case map[string]any:
			for k, v := range headers {
				if str, ok := v.(string); ok {
					response.Headers[k] = str
				}
			}
		case map[string]string:
			maps.Copy(response.Headers, headers)
		}
	}
	// Push the cookies field (stack: [__http_response, cookies]) and walk it
	// as an array. The callback inspects index -1, so it assumes the visited
	// element is on top of the stack during each call.
	state.GetField(-1, "cookies")
	if state.IsTable(-1) {
		state.ForEachArray(-1, func(i int, s *luajit.State) bool {
			if s.IsTable(-1) {
				extractCookie(s, response)
			}
			return true // presumably "continue iterating"
		})
	}
	state.Pop(1) // pop the value pushed by GetField(-1, "cookies")
	// Extract metadata
	if metadata, ok := state.GetFieldTable(-1, "metadata"); ok {
		if metaMap, ok := metadata.(map[string]any); ok {
			maps.Copy(response.Metadata, metaMap)
		}
	}
	// Extract session data. maps.Copy needs matching value types, so the
	// typed maps are copied element by element into SessionData.
	if session, ok := state.GetFieldTable(-1, "session"); ok {
		switch sessMap := session.(type) {
		case map[string]any:
			maps.Copy(response.SessionData, sessMap)
		case map[string]string:
			for k, v := range sessMap {
				response.SessionData[k] = v
			}
		case map[string]int:
			for k, v := range sessMap {
				response.SessionData[k] = v
			}
		default:
			logger.Debugf("Unexpected session type: %T", session)
		}
	}
	state.Pop(1) // Pop __http_response
}
// extractCookie converts the Lua table at the top of the stack into a
// fasthttp cookie and appends it to response.Cookies.
//
// Recognised fields (with defaults): name (required; a table without a
// non-empty name is skipped), value (""), path ("/"), domain (""),
// http_only (false), secure (false), max_age (0).
func extractCookie(state *luajit.State, response *Response) {
	// Validate before taking a cookie from the pool so a skipped table costs
	// no acquire/release round-trip.
	name := state.GetFieldString(-1, "name", "")
	if name == "" {
		return
	}

	c := fasthttp.AcquireCookie()
	c.SetKey(name)
	c.SetValue(state.GetFieldString(-1, "value", ""))
	c.SetPath(state.GetFieldString(-1, "path", "/"))
	c.SetDomain(state.GetFieldString(-1, "domain", ""))
	c.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false))
	c.SetSecure(state.GetFieldBool(-1, "secure", false))
	c.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0)))

	response.Cookies = append(response.Cookies, c)
}
// jsonMarshal implements the Lua global __json_marshal.
//
// It expects exactly one argument. A table is converted with SafeToTable;
// if that fails, the argument is converted with ToValue instead. The
// converted value is JSON-encoded and pushed as a string, giving one return
// value. Argument, conversion and encoding failures are reported through
// state.PushError.
func jsonMarshal(state *luajit.State) int {
	if argErr := state.CheckExactArgs(1); argErr != nil {
		return state.PushError("json marshal: %v", argErr)
	}

	var (
		goValue any
		convErr error
	)
	if goValue, convErr = state.SafeToTable(1); convErr != nil {
		// Not convertible as a table: fall back to a generic conversion.
		if goValue, convErr = state.ToValue(1); convErr != nil {
			return state.PushError("json marshal error: %v", convErr)
		}
	}

	encoded, marshalErr := json.Marshal(goValue)
	if marshalErr != nil {
		return state.PushError("json marshal error: %v", marshalErr)
	}

	state.PushString(string(encoded))
	return 1
}
// jsonUnmarshal implements the Lua global __json_unmarshal.
//
// It expects exactly one string argument containing JSON. The JSON is
// decoded into a generic Go value, which is then pushed onto the Lua stack
// as the single return value. A wrong argument count or type, invalid JSON,
// or a value PushValue cannot represent is reported through state.PushError.
func jsonUnmarshal(state *luajit.State) int {
	if argErr := state.CheckExactArgs(1); argErr != nil {
		return state.PushError("json unmarshal: %v", argErr)
	}

	text, strErr := state.SafeToString(1)
	if strErr != nil {
		return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
	}

	var decoded any
	if decodeErr := json.Unmarshal([]byte(text), &decoded); decodeErr != nil {
		return state.PushError("json unmarshal error: %v", decodeErr)
	}

	if pushErr := state.PushValue(decoded); pushErr != nil {
		return state.PushError("json unmarshal error: %v", pushErr)
	}
	return 1
}