diff --git a/functions/crypto.go b/functions/crypto.go deleted file mode 100644 index b3b2af3..0000000 --- a/functions/crypto.go +++ /dev/null @@ -1,394 +0,0 @@ -package functions - -import ( - "crypto/hmac" - "crypto/md5" - "crypto/rand" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/base64" - "encoding/hex" - "fmt" - "math/big" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/google/uuid" -) - -func GetCryptoFunctions() map[string]luajit.GoFunction { - return map[string]luajit.GoFunction{ - "base64_encode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_encode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("base64_encode: argument must be a string") - return 2 - } - encoded := base64.StdEncoding.EncodeToString([]byte(str)) - s.PushString(encoded) - return 1 - }, - - "base64_decode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_decode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("base64_decode: argument must be a string") - return 2 - } - decoded, err := base64.StdEncoding.DecodeString(str) - if err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_decode: %v", err)) - return 2 - } - s.PushString(string(decoded)) - return 1 - }, - - "base64_url_encode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_url_encode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("base64_url_encode: argument must be a string") - return 2 - } - encoded := base64.URLEncoding.EncodeToString([]byte(str)) - s.PushString(encoded) - return 1 - }, - - "base64_url_decode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_url_decode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("base64_url_decode: argument must be a string") - return 2 - } - decoded, err := base64.URLEncoding.DecodeString(str) - if err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("base64_url_decode: %v", err)) - return 2 - } - s.PushString(string(decoded)) - return 1 - }, - - "hex_encode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("hex_encode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("hex_encode: argument must be a string") - return 2 - } - encoded := hex.EncodeToString([]byte(str)) - s.PushString(encoded) - return 1 - }, - - "hex_decode": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("hex_decode: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("hex_decode: argument must be a string") - return 2 - } - decoded, err := hex.DecodeString(str) - if err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("hex_decode: %v", err)) - return 2 - } - s.PushString(string(decoded)) - return 1 - }, - - "md5_hash": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("md5_hash: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("md5_hash: argument must be a string") - return 2 - } - hash := md5.Sum([]byte(str)) - s.PushString(hex.EncodeToString(hash[:])) - return 1 - }, - - "sha1_hash": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("sha1_hash: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("sha1_hash: argument must be a string") - return 2 - } - hash := sha1.Sum([]byte(str)) - s.PushString(hex.EncodeToString(hash[:])) - return 1 - }, - - "sha256_hash": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("sha256_hash: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("sha256_hash: argument must be a string") - return 2 - } - hash := sha256.Sum256([]byte(str)) - s.PushString(hex.EncodeToString(hash[:])) - return 1 - }, - - "sha512_hash": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("sha512_hash: %v", err)) - return 2 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("sha512_hash: argument must be a string") - return 2 - } - hash := sha512.Sum512([]byte(str)) - s.PushString(hex.EncodeToString(hash[:])) - return 1 - }, - - "hmac_sha256": func(s *luajit.State) int { - if err := s.CheckExactArgs(2); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("hmac_sha256: %v", err)) - return 2 - } - message, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("hmac_sha256: first argument must be a string") - return 2 - } - key, err := s.SafeToString(2) - if err != nil { - s.PushNil() - s.PushString("hmac_sha256: second argument must be a string") - return 2 - } - h := hmac.New(sha256.New, []byte(key)) - h.Write([]byte(message)) - s.PushString(hex.EncodeToString(h.Sum(nil))) - return 1 - }, - - "hmac_sha1": func(s *luajit.State) int { - if err := s.CheckExactArgs(2); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("hmac_sha1: %v", err)) - return 2 - } - message, err := s.SafeToString(1) - if err != nil { - s.PushNil() - s.PushString("hmac_sha1: first argument must be a string") - return 2 - } - key, err := s.SafeToString(2) - if err != nil { - s.PushNil() - s.PushString("hmac_sha1: second argument must be a string") - return 2 - } - h := hmac.New(sha1.New, []byte(key)) - h.Write([]byte(message)) - s.PushString(hex.EncodeToString(h.Sum(nil))) - return 1 - }, - - "uuid_generate": func(s *luajit.State) int { - id := uuid.New() - s.PushString(id.String()) - return 1 - }, - - "uuid_generate_v4": func(s *luajit.State) int { - id := uuid.New() - s.PushString(id.String()) - return 1 - }, - - "uuid_validate": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushBoolean(false) - return 1 - } - str, err := s.SafeToString(1) - if err != nil { - s.PushBoolean(false) - return 1 - } - _, err = uuid.Parse(str) - s.PushBoolean(err == nil) - return 1 - }, - - "random_bytes": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("random_bytes: %v", err)) - return 2 - } - length, err := s.SafeToNumber(1) - if err != nil || length < 0 || length != float64(int(length)) { - s.PushNil() - s.PushString("random_bytes: argument must be a non-negative integer") - return 2 - } - if length > 65536 { - s.PushNil() - s.PushString("random_bytes: length too large (max 65536)") - return 2 - } - bytes := make([]byte, int(length)) - if _, err := rand.Read(bytes); err != nil { - s.PushNil() - s.PushString("random_bytes: failed to generate random bytes") - return 2 - } - s.PushString(string(bytes)) - return 1 - }, - - "random_hex": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("random_hex: %v", err)) - return 2 - } - length, err := s.SafeToNumber(1) - if err != nil || length < 0 || length != float64(int(length)) { - s.PushNil() - s.PushString("random_hex: argument must be a non-negative integer") - return 2 - } - if length > 32768 { - s.PushNil() - s.PushString("random_hex: length too large (max 32768)") - return 2 - } - bytes := make([]byte, int(length)) - if _, err := rand.Read(bytes); err != nil { - s.PushNil() - s.PushString("random_hex: failed to generate random bytes") - return 2 - } - s.PushString(hex.EncodeToString(bytes)) - return 1 - }, - - "random_string": func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - s.PushNil() - s.PushString(fmt.Sprintf("random_string: %v", err)) - return 2 - } - length, err := s.SafeToNumber(1) - if err != nil || length < 0 || length != float64(int(length)) { - s.PushNil() - s.PushString("random_string: argument must be a non-negative integer") - return 2 - } - if length > 65536 { - s.PushNil() - s.PushString("random_string: length too large (max 65536)") - return 2 - } - - charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - if s.GetTop() >= 2 { - if customCharset, err := s.SafeToString(2); err == nil { - charset = customCharset - } - } - - result := make([]byte, int(length)) - charsetLen := big.NewInt(int64(len(charset))) - for i := range result { - n, err := rand.Int(rand.Reader, charsetLen) - if err != nil { - s.PushNil() - s.PushString("random_string: failed to generate random number") - return 2 - } - result[i] = charset[n.Int64()] - } - s.PushString(string(result)) - return 1 - }, - - "secure_compare": func(s *luajit.State) int { - if err := s.CheckExactArgs(2); err != nil { - s.PushBoolean(false) - return 1 - } - a, err := s.SafeToString(1) - if err != nil { - s.PushBoolean(false) - return 1 - } - b, err := s.SafeToString(2) - if err != nil { - s.PushBoolean(false) - return 1 - } - s.PushBoolean(hmac.Equal([]byte(a), []byte(b))) - return 1 - }, - } -} diff --git a/functions/registry.go b/functions/registry.go deleted file mode 100644 index 2524852..0000000 --- a/functions/registry.go +++ /dev/null @@ -1,41 +0,0 @@ -package functions - -import ( - "maps" - "sync" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Registry holds all available Go functions for Lua modules -type Registry map[string]luajit.GoFunction - -// GetAll returns all registered functions -func GetAll() Registry { - registry := make(Registry) - - maps.Copy(registry, GetJSONFunctions()) - maps.Copy(registry, GetStringFunctions()) - maps.Copy(registry, GetMathFunctions()) - maps.Copy(registry, GetFSFunctions()) - maps.Copy(registry, GetCryptoFunctions()) - - return registry -} - -var ( - storedBytecode []byte - bytecodeMutex sync.RWMutex -) - -func SetStoredBytecode(bytecode []byte) { - bytecodeMutex.Lock() - defer bytecodeMutex.Unlock() - storedBytecode = bytecode -} - -func GetStoredBytecode() []byte { - bytecodeMutex.RLock() - defer bytecodeMutex.RUnlock() - return storedBytecode -} diff --git a/modules/crypto/crypto.go b/modules/crypto/crypto.go new file mode 100644 index 0000000..d4877e1 --- /dev/null +++ b/modules/crypto/crypto.go @@ -0,0 +1,231 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "math/big" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "github.com/google/uuid" +) + +func GetFunctionList() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "base64_encode": base64_encode, + "base64_decode": base64_decode, + "base64_url_encode": base64_url_encode, + "base64_url_decode": base64_url_decode, + "hex_encode": hex_encode, + "hex_decode": hex_decode, + "md5_hash": md5_hash, + "sha1_hash": sha1_hash, + "sha256_hash": sha256_hash, + "sha512_hash": sha512_hash, + "hmac_sha256": hmac_sha256, + "hmac_sha1": hmac_sha1, + "uuid_generate": uuid_generate, + "uuid_generate_v4": uuid_generate_v4, + "uuid_validate": uuid_validate, + "random_bytes": random_bytes, + "random_hex": random_hex, + "random_string": random_string, + "secure_compare": secure_compare, + } +} + +func base64_encode(s *luajit.State) int { + str := s.ToString(1) + encoded := base64.StdEncoding.EncodeToString([]byte(str)) + s.PushString(encoded) + return 1 +} + +func base64_decode(s *luajit.State) int { + str := s.ToString(1) + decoded, err := base64.StdEncoding.DecodeString(str) + if err != nil { + s.PushNil() + s.PushString("invalid base64 data") + return 2 + } + s.PushString(string(decoded)) + return 1 +} + +func base64_url_encode(s *luajit.State) int { + str := s.ToString(1) + encoded := base64.URLEncoding.EncodeToString([]byte(str)) + s.PushString(encoded) + return 1 +} + +func base64_url_decode(s *luajit.State) int { + str := s.ToString(1) + decoded, err := base64.URLEncoding.DecodeString(str) + if err != nil { + s.PushNil() + s.PushString("invalid base64url data") + return 2 + } + s.PushString(string(decoded)) + return 1 +} + +func hex_encode(s *luajit.State) int { + str := s.ToString(1) + encoded := hex.EncodeToString([]byte(str)) + s.PushString(encoded) + return 1 +} + +func hex_decode(s *luajit.State) int { + str := s.ToString(1) + decoded, err := hex.DecodeString(str) + if err != nil { + s.PushNil() + s.PushString("invalid hex data") + return 2 + } + s.PushString(string(decoded)) + return 1 +} + +func md5_hash(s *luajit.State) int { + str := s.ToString(1) + hash := md5.Sum([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 +} + +func sha1_hash(s *luajit.State) int { + str := s.ToString(1) + hash := sha1.Sum([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 +} + +func sha256_hash(s *luajit.State) int { + str := s.ToString(1) + hash := sha256.Sum256([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 +} + +func sha512_hash(s *luajit.State) int { + str := s.ToString(1) + hash := sha512.Sum512([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 +} + +func hmac_sha256(s *luajit.State) int { + message := s.ToString(1) + key := s.ToString(2) + h := hmac.New(sha256.New, []byte(key)) + h.Write([]byte(message)) + s.PushString(hex.EncodeToString(h.Sum(nil))) + return 1 +} + +func hmac_sha1(s *luajit.State) int { + message := s.ToString(1) + key := s.ToString(2) + h := hmac.New(sha1.New, []byte(key)) + h.Write([]byte(message)) + s.PushString(hex.EncodeToString(h.Sum(nil))) + return 1 +} + +func uuid_generate(s *luajit.State) int { + id := uuid.New() + s.PushString(id.String()) + return 1 +} + +func uuid_generate_v4(s *luajit.State) int { + id := uuid.New() + s.PushString(id.String()) + return 1 +} + +func uuid_validate(s *luajit.State) int { + str := s.ToString(1) + _, err := uuid.Parse(str) + s.PushBoolean(err == nil) + return 1 +} + +func random_bytes(s *luajit.State) int { + length := int(s.ToNumber(1)) + if length < 0 || length > 65536 { + s.PushNil() + s.PushString("invalid length") + return 2 + } + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + s.PushNil() + s.PushString("failed to generate random bytes") + return 2 + } + s.PushString(string(bytes)) + return 1 +} + +func random_hex(s *luajit.State) int { + length := int(s.ToNumber(1)) + if length < 0 || length > 32768 { + s.PushNil() + s.PushString("invalid length") + return 2 + } + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + s.PushNil() + s.PushString("failed to generate random bytes") + return 2 + } + s.PushString(hex.EncodeToString(bytes)) + return 1 +} + +func random_string(s *luajit.State) int { + length := int(s.ToNumber(1)) + if length < 0 || length > 65536 { + s.PushNil() + s.PushString("invalid length") + return 2 + } + + charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + if s.GetTop() >= 2 { + charset = s.ToString(2) + } + + result := make([]byte, length) + charsetLen := big.NewInt(int64(len(charset))) + for i := range result { + n, err := rand.Int(rand.Reader, charsetLen) + if err != nil { + s.PushNil() + s.PushString("failed to generate random number") + return 2 + } + result[i] = charset[n.Int64()] + } + s.PushString(string(result)) + return 1 +} + +func secure_compare(s *luajit.State) int { + a := s.ToString(1) + b := s.ToString(2) + s.PushBoolean(hmac.Equal([]byte(a), []byte(b))) + return 1 +} diff --git a/modules/crypto.lua b/modules/crypto/crypto.lua similarity index 100% rename from modules/crypto.lua rename to modules/crypto/crypto.lua diff --git a/functions/fs.go b/modules/fs/fs.go similarity index 99% rename from functions/fs.go rename to modules/fs/fs.go index 14929bf..898091c 100644 --- a/functions/fs.go +++ b/modules/fs/fs.go @@ -1,4 +1,4 @@ -package functions +package fs import ( "io" diff --git a/modules/fs.lua b/modules/fs/fs.lua similarity index 100% rename from modules/fs.lua rename to modules/fs/fs.lua diff --git a/functions/json.go b/modules/json/json.go similarity index 98% rename from functions/json.go rename to modules/json/json.go index 130ab62..1936c15 100644 --- a/functions/json.go +++ b/modules/json/json.go @@ -1,4 +1,4 @@ -package functions +package json import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" diff --git a/modules/json.lua b/modules/json/json.lua similarity index 100% rename from modules/json.lua rename to modules/json/json.lua diff --git a/functions/math.go b/modules/math/math.go similarity index 99% rename from functions/math.go rename to modules/math/math.go index fd9ffb3..07b2bff 100644 --- a/functions/math.go +++ b/modules/math/math.go @@ -1,4 +1,4 @@ -package functions +package math import luajit "git.sharkk.net/Sky/LuaJIT-to-Go" diff --git a/modules/math.lua b/modules/math/math.lua similarity index 100% rename from modules/math.lua rename to modules/math/math.lua diff --git a/modules/modules.go b/modules/modules.go deleted file mode 100644 index 5083c9e..0000000 --- a/modules/modules.go +++ /dev/null @@ -1,139 +0,0 @@ -package modules - -import ( - "embed" - "fmt" - "path/filepath" - "strings" - - "Moonshark/functions" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -//go:embed *.lua -var builtinModules embed.FS - -// ModuleRegistry manages built-in modules and Go functions -type ModuleRegistry struct { - modules map[string]string - goFuncs map[string]luajit.GoFunction -} - -// NewModuleRegistry creates a new module registry -func NewModuleRegistry() *ModuleRegistry { - mr := &ModuleRegistry{ - modules: make(map[string]string), - goFuncs: functions.GetAll(), - } - - return mr -} - -// RegisterModule adds a module by name and source code -func (mr *ModuleRegistry) RegisterModule(name, source string) { - mr.modules[name] = source -} - -// RegisterGoFunction adds a Go function that modules can use -func (mr *ModuleRegistry) RegisterGoFunction(name string, fn luajit.GoFunction) { - mr.goFuncs[name] = fn -} - -// LoadEmbeddedModules loads all modules from the embedded filesystem -func (mr *ModuleRegistry) LoadEmbeddedModules() error { - entries, err := builtinModules.ReadDir(".") - if err != nil { - return fmt.Errorf("failed to read modules directory: %w", err) - } - - for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".lua") { - continue - } - - moduleName := strings.TrimSuffix(entry.Name(), ".lua") - source, err := builtinModules.ReadFile(filepath.Join(".", entry.Name())) - if err != nil { - return fmt.Errorf("failed to read module %s: %w", moduleName, err) - } - - mr.RegisterModule(moduleName, string(source)) - } - - return nil -} - -// InstallModules sets up the module system in the Lua state -func (mr *ModuleRegistry) InstallModules(state *luajit.State) error { - // Create moonshark global table - state.NewTable() - state.SetGlobal("moonshark") - - // Install Go functions first - if err := mr.installGoFunctions(state); err != nil { - return fmt.Errorf("failed to install Go functions: %w", err) - } - - // Register require function that checks our built-in modules first - err := state.RegisterGoFunction("require", func(s *luajit.State) int { - if err := s.CheckMinArgs(1); err != nil { - return s.PushError("require: %v", err) - } - - moduleName, err := s.SafeToString(1) - if err != nil { - return s.PushError("require: module name must be a string") - } - - // Check if it's a built-in module - if source, exists := mr.modules[moduleName]; exists { - // Execute the module and return its result - if err := s.LoadString(source); err != nil { - return s.PushError("require: failed to load module '%s': %v", moduleName, err) - } - - if err := s.Call(0, 1); err != nil { - return s.PushError("require: failed to execute module '%s': %v", moduleName, err) - } - - return 1 // Return the module's result - } - - // Fall back to standard Lua require - s.GetGlobal("_require_original") - if s.IsFunction(-1) { - s.PushString(moduleName) - if err := s.Call(1, 1); err != nil { - return s.PushError("require: %v", err) - } - return 1 - } - - return s.PushError("require: module '%s' not found", moduleName) - }) - - return err -} - -// installGoFunctions installs all registered Go functions into the Lua state -func (mr *ModuleRegistry) installGoFunctions(state *luajit.State) error { - // Install functions in moonshark namespace - state.GetGlobal("moonshark") - - for name, fn := range mr.goFuncs { - if err := state.PushGoFunction(fn); err != nil { - return fmt.Errorf("failed to register Go function '%s': %w", name, err) - } - state.SetField(-2, name) - } - - state.Pop(1) // Remove moonshark table - return nil -} - -// BackupOriginalRequire saves the original require function -func BackupOriginalRequire(state *luajit.State) { - state.GetGlobal("require") - state.SetGlobal("_require_original") -} diff --git a/modules/registry.go b/modules/registry.go new file mode 100644 index 0000000..46644ae --- /dev/null +++ b/modules/registry.go @@ -0,0 +1,152 @@ +package registry + +import ( + "embed" + "fmt" + "maps" + "path/filepath" + "strings" + "sync" + + "Moonshark/functions" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Global registry instance +var Global *Registry + +//go:embed modules/*.lua +var embeddedModules embed.FS + +// Registry manages all Lua modules and Go functions +type Registry struct { + modules map[string]string + goFuncs map[string]luajit.GoFunction + mutex sync.RWMutex +} + +// New creates a new registry with all Go functions loaded +func New() *Registry { + r := &Registry{ + modules: make(map[string]string), + goFuncs: make(map[string]luajit.GoFunction), + } + + // Load all Go functions + maps.Copy(r.goFuncs, functions.GetJSONFunctions()) + maps.Copy(r.goFuncs, functions.GetStringFunctions()) + maps.Copy(r.goFuncs, functions.GetMathFunctions()) + maps.Copy(r.goFuncs, functions.GetFSFunctions()) + maps.Copy(r.goFuncs, functions.GetCryptoFunctions()) + + return r +} + +// LoadEmbeddedModules loads all .lua files from embedded filesystem +func (r *Registry) LoadEmbeddedModules() error { + entries, err := embeddedModules.ReadDir("modules") + if err != nil { + return fmt.Errorf("failed to read modules directory: %w", err) + } + + r.mutex.Lock() + defer r.mutex.Unlock() + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".lua") { + continue + } + + moduleName := strings.TrimSuffix(entry.Name(), ".lua") + source, err := embeddedModules.ReadFile(filepath.Join("modules", entry.Name())) + if err != nil { + return fmt.Errorf("failed to read module %s: %w", moduleName, err) + } + + r.modules[moduleName] = string(source) + } + + return nil +} + +// InstallInState sets up the complete module system in a Lua state +func (r *Registry) InstallInState(state *luajit.State) error { + r.mutex.RLock() + defer r.mutex.RUnlock() + + // Create moonshark global table + state.NewTable() + state.SetGlobal("moonshark") + + // Install Go functions + state.GetGlobal("moonshark") + for name, fn := range r.goFuncs { + if err := state.PushGoFunction(fn); err != nil { + return fmt.Errorf("failed to register Go function '%s': %w", name, err) + } + state.SetField(-2, name) + } + state.Pop(1) // Remove moonshark table + + // Backup original require + state.GetGlobal("require") + state.SetGlobal("_require_original") + + // Install custom require function + return state.RegisterGoFunction("require", func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("require: %v", err) + } + + moduleName, err := s.SafeToString(1) + if err != nil { + return s.PushError("require: module name must be a string") + } + + // Check built-in modules first + r.mutex.RLock() + source, exists := r.modules[moduleName] + r.mutex.RUnlock() + + if exists { + if err := s.LoadString(source); err != nil { + return s.PushError("require: failed to load module '%s': %v", moduleName, err) + } + if err := s.Call(0, 1); err != nil { + return s.PushError("require: failed to execute module '%s': %v", moduleName, err) + } + return 1 + } + + // Fall back to original require + s.GetGlobal("_require_original") + if s.IsFunction(-1) { + s.PushString(moduleName) + if err := s.Call(1, 1); err != nil { + return s.PushError("require: %v", err) + } + return 1 + } + + return s.PushError("require: module '%s' not found", moduleName) + }) +} + +// Initialize sets up the global registry with all modules loaded +func Initialize() error { + Global = New() + return Global.LoadEmbeddedModules() +} + +// GetGoFunctions returns all Go functions +func (r *Registry) GetGoFunctions() map[string]luajit.GoFunction { + r.mutex.RLock() + defer r.mutex.RUnlock() + + result := make(map[string]luajit.GoFunction, len(r.goFuncs)) + for k, v := range r.goFuncs { + result[k] = v + } + return result +} diff --git a/functions/string.go b/modules/string/string.go similarity index 99% rename from functions/string.go rename to modules/string/string.go index 4b60d0a..e2c64f0 100644 --- a/functions/string.go +++ b/modules/string/string.go @@ -1,4 +1,4 @@ -package functions +package string import ( "fmt" diff --git a/modules/string.lua b/modules/string/string.lua similarity index 100% rename from modules/string.lua rename to modules/string/string.lua diff --git a/moonshark.go b/moonshark.go index 7608ae2..fd143e7 100644 --- a/moonshark.go +++ b/moonshark.go @@ -1,12 +1,12 @@ package main import ( - "Moonshark/functions" - "Moonshark/modules" "fmt" "os" "path/filepath" + "Moonshark/registry" + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) @@ -24,6 +24,12 @@ func main() { os.Exit(1) } + // Initialize global registry + if err := registry.Initialize(); err != nil { + fmt.Fprintf(os.Stderr, "Error: failed to initialize registry: %v\n", err) + os.Exit(1) + } + // Create new Lua state with standard libraries state := luajit.New() if state == nil { @@ -32,15 +38,8 @@ func main() { } defer state.Close() - // Set up module system - registry := modules.NewModuleRegistry() - if err := registry.LoadEmbeddedModules(); err != nil { - fmt.Fprintf(os.Stderr, "Warning: failed to load built-in modules: %v\n", err) - } - - // Backup original require and install module system - modules.BackupOriginalRequire(state) - if err := registry.InstallModules(state); err != nil { + // Install module system in main state + if err := registry.Global.InstallInState(state); err != nil { fmt.Fprintf(os.Stderr, "Error: failed to install module system: %v\n", err) os.Exit(1) } @@ -73,9 +72,6 @@ func main() { os.Exit(1) } - // Store the bytecode so our Go module implementations can use it elsewhere - functions.SetStoredBytecode(bytecode) - // Execute the compiled bytecode if err := state.LoadAndRunBytecode(bytecode, scriptPath); err != nil { fmt.Fprintf(os.Stderr, "Error executing '%s': %v\n", scriptPath, err)