464 lines
10 KiB
Go
464 lines
10 KiB
Go
package lualibs
|
|
|
|
import (
|
|
"Moonshark/logger"
|
|
"crypto/hmac"
|
|
"crypto/md5"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"hash"
|
|
"math"
|
|
mrand "math/rand/v2"
|
|
"sync"
|
|
"time"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
var (
|
|
// Map to store state-specific RNGs
|
|
stateRngs = make(map[*luajit.State]*mrand.PCG)
|
|
stateRngsMu sync.Mutex
|
|
)
|
|
|
|
// RegisterCryptoFunctions registers all crypto functions with the Lua state
|
|
func RegisterCryptoFunctions(state *luajit.State) error {
|
|
// Create a state-specific RNG
|
|
stateRngsMu.Lock()
|
|
stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32))
|
|
stateRngsMu.Unlock()
|
|
|
|
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register hash functions
|
|
if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register HMAC functions
|
|
if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register UUID generation
|
|
if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register random functions
|
|
if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Override Lua's math.random
|
|
if err := OverrideLuaRandom(state); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CleanupCrypto cleans up resources when a state is closed
|
|
func CleanupCrypto(state *luajit.State) {
|
|
stateRngsMu.Lock()
|
|
delete(stateRngs, state)
|
|
stateRngsMu.Unlock()
|
|
}
|
|
|
|
// cryptoHash generates hash digests using various algorithms
|
|
func cryptoHash(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(2); err != nil {
|
|
return state.PushError("hash: %v", err)
|
|
}
|
|
|
|
data, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("hash: data must be string")
|
|
}
|
|
|
|
algorithm, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("hash: algorithm must be string")
|
|
}
|
|
|
|
var h hash.Hash
|
|
switch algorithm {
|
|
case "md5":
|
|
h = md5.New()
|
|
case "sha1":
|
|
h = sha1.New()
|
|
case "sha256":
|
|
h = sha256.New()
|
|
case "sha512":
|
|
h = sha512.New()
|
|
default:
|
|
return state.PushError("unsupported algorithm: %s", algorithm)
|
|
}
|
|
|
|
h.Write([]byte(data))
|
|
hashBytes := h.Sum(nil)
|
|
|
|
// Output format
|
|
outputFormat := "hex"
|
|
if state.GetTop() >= 3 {
|
|
if format, err := state.SafeToString(3); err == nil {
|
|
outputFormat = format
|
|
}
|
|
}
|
|
|
|
switch outputFormat {
|
|
case "hex":
|
|
state.PushString(hex.EncodeToString(hashBytes))
|
|
case "binary":
|
|
state.PushString(string(hashBytes))
|
|
default:
|
|
state.PushString(hex.EncodeToString(hashBytes))
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// cryptoHmac generates HMAC using various hash algorithms
|
|
func cryptoHmac(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(3); err != nil {
|
|
return state.PushError("hmac: %v", err)
|
|
}
|
|
|
|
data, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("hmac: data must be string")
|
|
}
|
|
|
|
key, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("hmac: key must be string")
|
|
}
|
|
|
|
algorithm, err := state.SafeToString(3)
|
|
if err != nil {
|
|
return state.PushError("hmac: algorithm must be string")
|
|
}
|
|
|
|
var h func() hash.Hash
|
|
switch algorithm {
|
|
case "md5":
|
|
h = md5.New
|
|
case "sha1":
|
|
h = sha1.New
|
|
case "sha256":
|
|
h = sha256.New
|
|
case "sha512":
|
|
h = sha512.New
|
|
default:
|
|
return state.PushError("unsupported algorithm: %s", algorithm)
|
|
}
|
|
|
|
mac := hmac.New(h, []byte(key))
|
|
mac.Write([]byte(data))
|
|
macBytes := mac.Sum(nil)
|
|
|
|
// Output format
|
|
outputFormat := "hex"
|
|
if state.GetTop() >= 4 {
|
|
if format, err := state.SafeToString(4); err == nil {
|
|
outputFormat = format
|
|
}
|
|
}
|
|
|
|
switch outputFormat {
|
|
case "hex":
|
|
state.PushString(hex.EncodeToString(macBytes))
|
|
case "binary":
|
|
state.PushString(string(macBytes))
|
|
default:
|
|
state.PushString(hex.EncodeToString(macBytes))
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// cryptoUuid generates a random UUID v4
|
|
func cryptoUuid(state *luajit.State) int {
|
|
uuid := make([]byte, 16)
|
|
if _, err := rand.Read(uuid); err != nil {
|
|
return state.PushError("uuid: generation error: %v", err)
|
|
}
|
|
|
|
// Set version (4) and variant (RFC 4122)
|
|
uuid[6] = (uuid[6] & 0x0F) | 0x40
|
|
uuid[8] = (uuid[8] & 0x3F) | 0x80
|
|
|
|
uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x",
|
|
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
|
|
|
|
state.PushString(uuidStr)
|
|
return 1
|
|
}
|
|
|
|
// cryptoRandomBytes generates random bytes
|
|
func cryptoRandomBytes(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(1); err != nil {
|
|
return state.PushError("random_bytes: %v", err)
|
|
}
|
|
|
|
length, err := state.SafeToNumber(1)
|
|
if err != nil {
|
|
return state.PushError("random_bytes: length must be number")
|
|
}
|
|
|
|
if length <= 0 {
|
|
return state.PushError("random_bytes: length must be positive")
|
|
}
|
|
|
|
// Check if secure
|
|
secure := true
|
|
if state.GetTop() >= 2 && state.IsBoolean(2) {
|
|
secure = state.ToBoolean(2)
|
|
}
|
|
|
|
bytes := make([]byte, int(length))
|
|
|
|
if secure {
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return state.PushError("random_bytes: error: %v", err)
|
|
}
|
|
} else {
|
|
stateRngsMu.Lock()
|
|
stateRng, ok := stateRngs[state]
|
|
stateRngsMu.Unlock()
|
|
|
|
if !ok {
|
|
return state.PushError("random_bytes: RNG not initialized")
|
|
}
|
|
|
|
for i := range bytes {
|
|
bytes[i] = byte(stateRng.Uint64() & 0xFF)
|
|
}
|
|
}
|
|
|
|
// Output format
|
|
outputFormat := "binary"
|
|
if state.GetTop() >= 3 {
|
|
if format, err := state.SafeToString(3); err == nil {
|
|
outputFormat = format
|
|
}
|
|
}
|
|
|
|
switch outputFormat {
|
|
case "binary":
|
|
state.PushString(string(bytes))
|
|
case "hex":
|
|
state.PushString(hex.EncodeToString(bytes))
|
|
default:
|
|
state.PushString(string(bytes))
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// cryptoRandomInt generates a random integer in range [min, max]
|
|
func cryptoRandomInt(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(2); err != nil {
|
|
return state.PushError("random_int: %v", err)
|
|
}
|
|
|
|
minVal, err := state.SafeToNumber(1)
|
|
if err != nil {
|
|
return state.PushError("random_int: min must be number")
|
|
}
|
|
|
|
maxVal, err := state.SafeToNumber(2)
|
|
if err != nil {
|
|
return state.PushError("random_int: max must be number")
|
|
}
|
|
|
|
min := int64(minVal)
|
|
max := int64(maxVal)
|
|
|
|
if max <= min {
|
|
return state.PushError("random_int: max must be greater than min")
|
|
}
|
|
|
|
// Check if secure
|
|
secure := true
|
|
if state.GetTop() >= 3 && state.IsBoolean(3) {
|
|
secure = state.ToBoolean(3)
|
|
}
|
|
|
|
range_size := max - min + 1
|
|
var result int64
|
|
|
|
if secure {
|
|
bytes := make([]byte, 8)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return state.PushError("random_int: error: %v", err)
|
|
}
|
|
|
|
val := binary.BigEndian.Uint64(bytes)
|
|
result = min + int64(val%uint64(range_size))
|
|
} else {
|
|
stateRngsMu.Lock()
|
|
stateRng, ok := stateRngs[state]
|
|
stateRngsMu.Unlock()
|
|
|
|
if !ok {
|
|
return state.PushError("random_int: RNG not initialized")
|
|
}
|
|
|
|
result = min + int64(stateRng.Uint64()%uint64(range_size))
|
|
}
|
|
|
|
state.PushNumber(float64(result))
|
|
return 1
|
|
}
|
|
|
|
// cryptoRandom implements math.random functionality
|
|
func cryptoRandom(state *luajit.State) int {
|
|
numArgs := state.GetTop()
|
|
|
|
// Check if secure
|
|
secure := false
|
|
|
|
// math.random() - return [0,1)
|
|
if numArgs == 0 {
|
|
if secure {
|
|
bytes := make([]byte, 8)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return state.PushError("random: error: %v", err)
|
|
}
|
|
val := binary.BigEndian.Uint64(bytes)
|
|
state.PushNumber(float64(val) / float64(math.MaxUint64))
|
|
} else {
|
|
stateRngsMu.Lock()
|
|
stateRng, ok := stateRngs[state]
|
|
stateRngsMu.Unlock()
|
|
|
|
if !ok {
|
|
return state.PushError("random: RNG not initialized")
|
|
}
|
|
|
|
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
|
|
}
|
|
return 1
|
|
}
|
|
|
|
// math.random(n) - return integer [1,n]
|
|
if numArgs == 1 && state.IsNumber(1) {
|
|
n := int64(state.ToNumber(1))
|
|
if n < 1 {
|
|
return state.PushError("random: upper bound must be >= 1")
|
|
}
|
|
|
|
state.PushNumber(1) // min
|
|
state.PushNumber(float64(n)) // max
|
|
state.PushBoolean(secure) // secure flag
|
|
return cryptoRandomInt(state)
|
|
}
|
|
|
|
// math.random(m, n) - return integer [m,n]
|
|
if numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2) {
|
|
state.PushBoolean(secure) // secure flag
|
|
return cryptoRandomInt(state)
|
|
}
|
|
|
|
return state.PushError("random: invalid arguments")
|
|
}
|
|
|
|
// cryptoRandomSeed sets seed for non-secure RNG
|
|
func cryptoRandomSeed(state *luajit.State) int {
|
|
if err := state.CheckExactArgs(1); err != nil {
|
|
return state.PushError("randomseed: %v", err)
|
|
}
|
|
|
|
seed, err := state.SafeToNumber(1)
|
|
if err != nil {
|
|
return state.PushError("randomseed: seed must be number")
|
|
}
|
|
|
|
seedVal := uint64(seed)
|
|
stateRngsMu.Lock()
|
|
stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32)
|
|
stateRngsMu.Unlock()
|
|
|
|
return 0
|
|
}
|
|
|
|
// OverrideLuaRandom replaces Lua's math.random with Go implementation
|
|
func OverrideLuaRandom(state *luajit.State) error {
|
|
if err := state.RegisterGoFunction("go_math_random", cryptoRandom); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Replace original functions
|
|
return state.DoString(`
|
|
-- Save original functions
|
|
_G._original_math_random = math.random
|
|
_G._original_math_randomseed = math.randomseed
|
|
|
|
-- Replace with Go implementations
|
|
math.random = go_math_random
|
|
math.randomseed = go_math_randomseed
|
|
|
|
-- Clean up global namespace
|
|
go_math_random = nil
|
|
go_math_randomseed = nil
|
|
`)
|
|
}
|
|
|
|
// generateToken creates a cryptographically secure random token
|
|
func generateToken(state *luajit.State) int {
|
|
// Get the length from the Lua arguments (default to 32)
|
|
length := 32
|
|
if state.GetTop() >= 1 {
|
|
if lengthVal, err := state.SafeToNumber(1); err == nil {
|
|
length = int(lengthVal)
|
|
}
|
|
}
|
|
|
|
// Enforce minimum length for security
|
|
if length < 16 {
|
|
length = 16
|
|
}
|
|
|
|
// Generate secure random bytes
|
|
tokenBytes := make([]byte, length)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
logger.Errorf("Failed to generate secure token: %v", err)
|
|
state.PushString("")
|
|
return 1 // Return empty string on error
|
|
}
|
|
|
|
// Encode as base64
|
|
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
|
|
|
// Trim to requested length (base64 might be longer)
|
|
if len(token) > length {
|
|
token = token[:length]
|
|
}
|
|
|
|
// Push the token to the Lua stack
|
|
state.PushString(token)
|
|
return 1 // One return value
|
|
}
|