refactor modules, update crypto module
This commit is contained in:
parent
4ff04e141d
commit
503f76d127
@ -1,394 +0,0 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetCryptoFunctions() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"base64_encode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_encode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("base64_encode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
},
|
||||
|
||||
"base64_decode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("base64_decode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
},
|
||||
|
||||
"base64_url_encode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_url_encode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("base64_url_encode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
encoded := base64.URLEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
},
|
||||
|
||||
"base64_url_decode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_url_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("base64_url_decode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
decoded, err := base64.URLEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("base64_url_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
},
|
||||
|
||||
"hex_encode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("hex_encode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hex_encode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
encoded := hex.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
},
|
||||
|
||||
"hex_decode": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("hex_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hex_decode: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
decoded, err := hex.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("hex_decode: %v", err))
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
},
|
||||
|
||||
"md5_hash": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("md5_hash: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("md5_hash: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
hash := md5.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
},
|
||||
|
||||
"sha1_hash": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("sha1_hash: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("sha1_hash: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
hash := sha1.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
},
|
||||
|
||||
"sha256_hash": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("sha256_hash: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("sha256_hash: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
hash := sha256.Sum256([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
},
|
||||
|
||||
"sha512_hash": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("sha512_hash: %v", err))
|
||||
return 2
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("sha512_hash: argument must be a string")
|
||||
return 2
|
||||
}
|
||||
hash := sha512.Sum512([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
},
|
||||
|
||||
"hmac_sha256": func(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(2); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("hmac_sha256: %v", err))
|
||||
return 2
|
||||
}
|
||||
message, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hmac_sha256: first argument must be a string")
|
||||
return 2
|
||||
}
|
||||
key, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hmac_sha256: second argument must be a string")
|
||||
return 2
|
||||
}
|
||||
h := hmac.New(sha256.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
},
|
||||
|
||||
"hmac_sha1": func(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(2); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("hmac_sha1: %v", err))
|
||||
return 2
|
||||
}
|
||||
message, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hmac_sha1: first argument must be a string")
|
||||
return 2
|
||||
}
|
||||
key, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("hmac_sha1: second argument must be a string")
|
||||
return 2
|
||||
}
|
||||
h := hmac.New(sha1.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
},
|
||||
|
||||
"uuid_generate": func(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
},
|
||||
|
||||
"uuid_generate_v4": func(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
},
|
||||
|
||||
"uuid_validate": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
str, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
_, err = uuid.Parse(str)
|
||||
s.PushBoolean(err == nil)
|
||||
return 1
|
||||
},
|
||||
|
||||
"random_bytes": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("random_bytes: %v", err))
|
||||
return 2
|
||||
}
|
||||
length, err := s.SafeToNumber(1)
|
||||
if err != nil || length < 0 || length != float64(int(length)) {
|
||||
s.PushNil()
|
||||
s.PushString("random_bytes: argument must be a non-negative integer")
|
||||
return 2
|
||||
}
|
||||
if length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("random_bytes: length too large (max 65536)")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, int(length))
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("random_bytes: failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(bytes))
|
||||
return 1
|
||||
},
|
||||
|
||||
"random_hex": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("random_hex: %v", err))
|
||||
return 2
|
||||
}
|
||||
length, err := s.SafeToNumber(1)
|
||||
if err != nil || length < 0 || length != float64(int(length)) {
|
||||
s.PushNil()
|
||||
s.PushString("random_hex: argument must be a non-negative integer")
|
||||
return 2
|
||||
}
|
||||
if length > 32768 {
|
||||
s.PushNil()
|
||||
s.PushString("random_hex: length too large (max 32768)")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, int(length))
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("random_hex: failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(hex.EncodeToString(bytes))
|
||||
return 1
|
||||
},
|
||||
|
||||
"random_string": func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString(fmt.Sprintf("random_string: %v", err))
|
||||
return 2
|
||||
}
|
||||
length, err := s.SafeToNumber(1)
|
||||
if err != nil || length < 0 || length != float64(int(length)) {
|
||||
s.PushNil()
|
||||
s.PushString("random_string: argument must be a non-negative integer")
|
||||
return 2
|
||||
}
|
||||
if length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("random_string: length too large (max 65536)")
|
||||
return 2
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if s.GetTop() >= 2 {
|
||||
if customCharset, err := s.SafeToString(2); err == nil {
|
||||
charset = customCharset
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]byte, int(length))
|
||||
charsetLen := big.NewInt(int64(len(charset)))
|
||||
for i := range result {
|
||||
n, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("random_string: failed to generate random number")
|
||||
return 2
|
||||
}
|
||||
result[i] = charset[n.Int64()]
|
||||
}
|
||||
s.PushString(string(result))
|
||||
return 1
|
||||
},
|
||||
|
||||
"secure_compare": func(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(2); err != nil {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
a, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
b, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
s.PushBoolean(hmac.Equal([]byte(a), []byte(b)))
|
||||
return 1
|
||||
},
|
||||
}
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Registry holds all available Go functions for Lua modules
|
||||
type Registry map[string]luajit.GoFunction
|
||||
|
||||
// GetAll returns all registered functions
|
||||
func GetAll() Registry {
|
||||
registry := make(Registry)
|
||||
|
||||
maps.Copy(registry, GetJSONFunctions())
|
||||
maps.Copy(registry, GetStringFunctions())
|
||||
maps.Copy(registry, GetMathFunctions())
|
||||
maps.Copy(registry, GetFSFunctions())
|
||||
maps.Copy(registry, GetCryptoFunctions())
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
var (
|
||||
storedBytecode []byte
|
||||
bytecodeMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func SetStoredBytecode(bytecode []byte) {
|
||||
bytecodeMutex.Lock()
|
||||
defer bytecodeMutex.Unlock()
|
||||
storedBytecode = bytecode
|
||||
}
|
||||
|
||||
func GetStoredBytecode() []byte {
|
||||
bytecodeMutex.RLock()
|
||||
defer bytecodeMutex.RUnlock()
|
||||
return storedBytecode
|
||||
}
|
231
modules/crypto/crypto.go
Normal file
231
modules/crypto/crypto.go
Normal file
@ -0,0 +1,231 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"base64_encode": base64_encode,
|
||||
"base64_decode": base64_decode,
|
||||
"base64_url_encode": base64_url_encode,
|
||||
"base64_url_decode": base64_url_decode,
|
||||
"hex_encode": hex_encode,
|
||||
"hex_decode": hex_decode,
|
||||
"md5_hash": md5_hash,
|
||||
"sha1_hash": sha1_hash,
|
||||
"sha256_hash": sha256_hash,
|
||||
"sha512_hash": sha512_hash,
|
||||
"hmac_sha256": hmac_sha256,
|
||||
"hmac_sha1": hmac_sha1,
|
||||
"uuid_generate": uuid_generate,
|
||||
"uuid_generate_v4": uuid_generate_v4,
|
||||
"uuid_validate": uuid_validate,
|
||||
"random_bytes": random_bytes,
|
||||
"random_hex": random_hex,
|
||||
"random_string": random_string,
|
||||
"secure_compare": secure_compare,
|
||||
}
|
||||
}
|
||||
|
||||
func base64_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid base64 data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_url_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := base64.URLEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_url_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := base64.URLEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid base64url data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hex_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := hex.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func hex_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := hex.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid hex data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func md5_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := md5.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha1_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha1.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha256_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha256.Sum256([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha512_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha512.Sum512([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hmac_sha256(s *luajit.State) int {
|
||||
message := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
h := hmac.New(sha256.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hmac_sha1(s *luajit.State) int {
|
||||
message := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
h := hmac.New(sha1.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_generate(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_generate_v4(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_validate(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
_, err := uuid.Parse(str)
|
||||
s.PushBoolean(err == nil)
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_bytes(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_hex(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 32768 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(hex.EncodeToString(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_string(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if s.GetTop() >= 2 {
|
||||
charset = s.ToString(2)
|
||||
}
|
||||
|
||||
result := make([]byte, length)
|
||||
charsetLen := big.NewInt(int64(len(charset)))
|
||||
for i := range result {
|
||||
n, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random number")
|
||||
return 2
|
||||
}
|
||||
result[i] = charset[n.Int64()]
|
||||
}
|
||||
s.PushString(string(result))
|
||||
return 1
|
||||
}
|
||||
|
||||
func secure_compare(s *luajit.State) int {
|
||||
a := s.ToString(1)
|
||||
b := s.ToString(2)
|
||||
s.PushBoolean(hmac.Equal([]byte(a), []byte(b)))
|
||||
return 1
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package functions
|
||||
package fs
|
||||
|
||||
import (
|
||||
"io"
|
@ -1,4 +1,4 @@
|
||||
package functions
|
||||
package json
|
||||
|
||||
import (
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
@ -1,4 +1,4 @@
|
||||
package functions
|
||||
package math
|
||||
|
||||
import luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
|
@ -1,139 +0,0 @@
|
||||
package modules
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"Moonshark/functions"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
//go:embed *.lua
|
||||
var builtinModules embed.FS
|
||||
|
||||
// ModuleRegistry manages built-in modules and Go functions
|
||||
type ModuleRegistry struct {
|
||||
modules map[string]string
|
||||
goFuncs map[string]luajit.GoFunction
|
||||
}
|
||||
|
||||
// NewModuleRegistry creates a new module registry
|
||||
func NewModuleRegistry() *ModuleRegistry {
|
||||
mr := &ModuleRegistry{
|
||||
modules: make(map[string]string),
|
||||
goFuncs: functions.GetAll(),
|
||||
}
|
||||
|
||||
return mr
|
||||
}
|
||||
|
||||
// RegisterModule adds a module by name and source code
|
||||
func (mr *ModuleRegistry) RegisterModule(name, source string) {
|
||||
mr.modules[name] = source
|
||||
}
|
||||
|
||||
// RegisterGoFunction adds a Go function that modules can use
|
||||
func (mr *ModuleRegistry) RegisterGoFunction(name string, fn luajit.GoFunction) {
|
||||
mr.goFuncs[name] = fn
|
||||
}
|
||||
|
||||
// LoadEmbeddedModules loads all modules from the embedded filesystem
|
||||
func (mr *ModuleRegistry) LoadEmbeddedModules() error {
|
||||
entries, err := builtinModules.ReadDir(".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read modules directory: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".lua") {
|
||||
continue
|
||||
}
|
||||
|
||||
moduleName := strings.TrimSuffix(entry.Name(), ".lua")
|
||||
source, err := builtinModules.ReadFile(filepath.Join(".", entry.Name()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read module %s: %w", moduleName, err)
|
||||
}
|
||||
|
||||
mr.RegisterModule(moduleName, string(source))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallModules sets up the module system in the Lua state
|
||||
func (mr *ModuleRegistry) InstallModules(state *luajit.State) error {
|
||||
// Create moonshark global table
|
||||
state.NewTable()
|
||||
state.SetGlobal("moonshark")
|
||||
|
||||
// Install Go functions first
|
||||
if err := mr.installGoFunctions(state); err != nil {
|
||||
return fmt.Errorf("failed to install Go functions: %w", err)
|
||||
}
|
||||
|
||||
// Register require function that checks our built-in modules first
|
||||
err := state.RegisterGoFunction("require", func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
|
||||
moduleName, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("require: module name must be a string")
|
||||
}
|
||||
|
||||
// Check if it's a built-in module
|
||||
if source, exists := mr.modules[moduleName]; exists {
|
||||
// Execute the module and return its result
|
||||
if err := s.LoadString(source); err != nil {
|
||||
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
|
||||
}
|
||||
|
||||
if err := s.Call(0, 1); err != nil {
|
||||
return s.PushError("require: failed to execute module '%s': %v", moduleName, err)
|
||||
}
|
||||
|
||||
return 1 // Return the module's result
|
||||
}
|
||||
|
||||
// Fall back to standard Lua require
|
||||
s.GetGlobal("_require_original")
|
||||
if s.IsFunction(-1) {
|
||||
s.PushString(moduleName)
|
||||
if err := s.Call(1, 1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
return s.PushError("require: module '%s' not found", moduleName)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// installGoFunctions installs all registered Go functions into the Lua state
|
||||
func (mr *ModuleRegistry) installGoFunctions(state *luajit.State) error {
|
||||
// Install functions in moonshark namespace
|
||||
state.GetGlobal("moonshark")
|
||||
|
||||
for name, fn := range mr.goFuncs {
|
||||
if err := state.PushGoFunction(fn); err != nil {
|
||||
return fmt.Errorf("failed to register Go function '%s': %w", name, err)
|
||||
}
|
||||
state.SetField(-2, name)
|
||||
}
|
||||
|
||||
state.Pop(1) // Remove moonshark table
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupOriginalRequire saves the original require function
|
||||
func BackupOriginalRequire(state *luajit.State) {
|
||||
state.GetGlobal("require")
|
||||
state.SetGlobal("_require_original")
|
||||
}
|
152
modules/registry.go
Normal file
152
modules/registry.go
Normal file
@ -0,0 +1,152 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"maps"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/functions"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Global registry instance
|
||||
var Global *Registry
|
||||
|
||||
//go:embed modules/*.lua
|
||||
var embeddedModules embed.FS
|
||||
|
||||
// Registry manages all Lua modules and Go functions
|
||||
type Registry struct {
|
||||
modules map[string]string
|
||||
goFuncs map[string]luajit.GoFunction
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new registry with all Go functions loaded
|
||||
func New() *Registry {
|
||||
r := &Registry{
|
||||
modules: make(map[string]string),
|
||||
goFuncs: make(map[string]luajit.GoFunction),
|
||||
}
|
||||
|
||||
// Load all Go functions
|
||||
maps.Copy(r.goFuncs, functions.GetJSONFunctions())
|
||||
maps.Copy(r.goFuncs, functions.GetStringFunctions())
|
||||
maps.Copy(r.goFuncs, functions.GetMathFunctions())
|
||||
maps.Copy(r.goFuncs, functions.GetFSFunctions())
|
||||
maps.Copy(r.goFuncs, functions.GetCryptoFunctions())
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// LoadEmbeddedModules loads all .lua files from embedded filesystem
|
||||
func (r *Registry) LoadEmbeddedModules() error {
|
||||
entries, err := embeddedModules.ReadDir("modules")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read modules directory: %w", err)
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".lua") {
|
||||
continue
|
||||
}
|
||||
|
||||
moduleName := strings.TrimSuffix(entry.Name(), ".lua")
|
||||
source, err := embeddedModules.ReadFile(filepath.Join("modules", entry.Name()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read module %s: %w", moduleName, err)
|
||||
}
|
||||
|
||||
r.modules[moduleName] = string(source)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallInState sets up the complete module system in a Lua state
|
||||
func (r *Registry) InstallInState(state *luajit.State) error {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Create moonshark global table
|
||||
state.NewTable()
|
||||
state.SetGlobal("moonshark")
|
||||
|
||||
// Install Go functions
|
||||
state.GetGlobal("moonshark")
|
||||
for name, fn := range r.goFuncs {
|
||||
if err := state.PushGoFunction(fn); err != nil {
|
||||
return fmt.Errorf("failed to register Go function '%s': %w", name, err)
|
||||
}
|
||||
state.SetField(-2, name)
|
||||
}
|
||||
state.Pop(1) // Remove moonshark table
|
||||
|
||||
// Backup original require
|
||||
state.GetGlobal("require")
|
||||
state.SetGlobal("_require_original")
|
||||
|
||||
// Install custom require function
|
||||
return state.RegisterGoFunction("require", func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
|
||||
moduleName, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("require: module name must be a string")
|
||||
}
|
||||
|
||||
// Check built-in modules first
|
||||
r.mutex.RLock()
|
||||
source, exists := r.modules[moduleName]
|
||||
r.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
if err := s.LoadString(source); err != nil {
|
||||
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
|
||||
}
|
||||
if err := s.Call(0, 1); err != nil {
|
||||
return s.PushError("require: failed to execute module '%s': %v", moduleName, err)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Fall back to original require
|
||||
s.GetGlobal("_require_original")
|
||||
if s.IsFunction(-1) {
|
||||
s.PushString(moduleName)
|
||||
if err := s.Call(1, 1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
return s.PushError("require: module '%s' not found", moduleName)
|
||||
})
|
||||
}
|
||||
|
||||
// Initialize sets up the global registry with all modules loaded
|
||||
func Initialize() error {
|
||||
Global = New()
|
||||
return Global.LoadEmbeddedModules()
|
||||
}
|
||||
|
||||
// GetGoFunctions returns all Go functions
|
||||
func (r *Registry) GetGoFunctions() map[string]luajit.GoFunction {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
result := make(map[string]luajit.GoFunction, len(r.goFuncs))
|
||||
for k, v := range r.goFuncs {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package functions
|
||||
package string
|
||||
|
||||
import (
|
||||
"fmt"
|
24
moonshark.go
24
moonshark.go
@ -1,12 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"Moonshark/functions"
|
||||
"Moonshark/modules"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"Moonshark/registry"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
@ -24,6 +24,12 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Initialize global registry
|
||||
if err := registry.Initialize(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: failed to initialize registry: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create new Lua state with standard libraries
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
@ -32,15 +38,8 @@ func main() {
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
// Set up module system
|
||||
registry := modules.NewModuleRegistry()
|
||||
if err := registry.LoadEmbeddedModules(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to load built-in modules: %v\n", err)
|
||||
}
|
||||
|
||||
// Backup original require and install module system
|
||||
modules.BackupOriginalRequire(state)
|
||||
if err := registry.InstallModules(state); err != nil {
|
||||
// Install module system in main state
|
||||
if err := registry.Global.InstallInState(state); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: failed to install module system: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@ -73,9 +72,6 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Store the bytecode so our Go module implementations can use it elsewhere
|
||||
functions.SetStoredBytecode(bytecode)
|
||||
|
||||
// Execute the compiled bytecode
|
||||
if err := state.LoadAndRunBytecode(bytecode, scriptPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing '%s': %v\n", scriptPath, err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user