Moonshark/runner/crypto.go
2025-05-10 13:02:09 -05:00

405 lines
8.8 KiB
Go

package runner
import (
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/binary"
"encoding/hex"
"fmt"
"hash"
"math"
mrand "math/rand/v2"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Per-state generators backing the non-secure random paths (random_bytes and
// random_int with secure=false, and the math.random override). The map itself
// is only read or written while holding stateRngsMu; the *mrand.PCG values are
// used after the lock is released, which assumes each Lua state is driven by a
// single goroutine at a time (NOTE(review): confirm against the state pool).
var (
	stateRngsMu sync.Mutex
	stateRngs   = make(map[*luajit.State]*mrand.PCG)
)
// RegisterCryptoFunctions registers all crypto functions with the Lua state
func RegisterCryptoFunctions(state *luajit.State) error {
// Create a state-specific RNG
stateRngsMu.Lock()
stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32))
stateRngsMu.Unlock()
// Register hash functions
if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil {
return err
}
// Register HMAC functions
if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil {
return err
}
// Register UUID generation
if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil {
return err
}
// Register random functions
if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil {
return err
}
// Override Lua's math.random
if err := OverrideLuaRandom(state); err != nil {
return err
}
return nil
}
// CleanupCrypto forgets the non-secure generator stored for state (created by
// RegisterCryptoFunctions or replaced by randomseed). Call it when the Lua
// state is closed. Otherwise the map entry, and the *luajit.State key it
// references, stays reachable for the life of the process.
func CleanupCrypto(state *luajit.State) {
	stateRngsMu.Lock()
	defer stateRngsMu.Unlock()
	delete(stateRngs, state)
}
// cryptoHash returns the digest of a Lua string.
//
// Lua arguments: (data string, algorithm string [, format string]).
// algorithm is one of "md5", "sha1", "sha256", "sha512". format "binary"
// yields the raw digest bytes; any other value (default "hex") yields
// lowercase hex. Failures are reported by pushing a message string in place
// of the digest.
func cryptoHash(state *luajit.State) int {
	if !(state.IsString(1) && state.IsString(2)) {
		state.PushString("hash: expected (string data, string algorithm)")
		return 1
	}
	data, algorithm := state.ToString(1), state.ToString(2)

	var newHash func() hash.Hash
	switch algorithm {
	case "md5":
		newHash = md5.New
	case "sha1":
		newHash = sha1.New
	case "sha256":
		newHash = sha256.New
	case "sha512":
		newHash = sha512.New
	default:
		state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
		return 1
	}

	digester := newHash()
	digester.Write([]byte(data))
	digest := digester.Sum(nil)

	// Only "binary" changes the encoding; unknown formats fall back to hex.
	if state.GetTop() >= 3 && state.IsString(3) && state.ToString(3) == "binary" {
		state.PushString(string(digest))
	} else {
		state.PushString(hex.EncodeToString(digest))
	}
	return 1
}
// cryptoHmac computes a keyed HMAC over a Lua string.
//
// Lua arguments: (data string, key string, algorithm string [, format string]).
// algorithm is one of "md5", "sha1", "sha256", "sha512". format "binary"
// yields the raw MAC bytes; any other value (default "hex") yields lowercase
// hex. Failures are reported by pushing a message string in place of the MAC.
func cryptoHmac(state *luajit.State) int {
	if !(state.IsString(1) && state.IsString(2) && state.IsString(3)) {
		state.PushString("hmac: expected (string data, string key, string algorithm)")
		return 1
	}
	data, key, algorithm := state.ToString(1), state.ToString(2), state.ToString(3)

	var newHash func() hash.Hash
	switch algorithm {
	case "md5":
		newHash = md5.New
	case "sha1":
		newHash = sha1.New
	case "sha256":
		newHash = sha256.New
	case "sha512":
		newHash = sha512.New
	default:
		state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
		return 1
	}

	signer := hmac.New(newHash, []byte(key))
	signer.Write([]byte(data))
	tag := signer.Sum(nil)

	// Only "binary" changes the encoding; unknown formats fall back to hex.
	wantBinary := state.GetTop() >= 4 && state.IsString(4) && state.ToString(4) == "binary"
	if wantBinary {
		state.PushString(string(tag))
	} else {
		state.PushString(hex.EncodeToString(tag))
	}
	return 1
}
// cryptoUuid generates a random UUID v4
func cryptoUuid(state *luajit.State) int {
uuid := make([]byte, 16)
_, err := rand.Read(uuid)
if err != nil {
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
return 1
}
// Set version (4) and variant (RFC 4122)
uuid[6] = (uuid[6] & 0x0F) | 0x40
uuid[8] = (uuid[8] & 0x3F) | 0x80
uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
state.PushString(uuidStr)
return 1
}
// cryptoRandomBytes generates random bytes
func cryptoRandomBytes(state *luajit.State) int {
if !state.IsNumber(1) {
state.PushString("random_bytes: expected (number length)")
return 1
}
length := int(state.ToNumber(1))
if length <= 0 {
state.PushString("random_bytes: length must be positive")
return 1
}
// Check if secure
secure := true
if state.GetTop() >= 2 && state.IsBoolean(2) {
secure = state.ToBoolean(2)
}
bytes := make([]byte, length)
if secure {
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
return 1
}
} else {
stateRngsMu.Lock()
stateRng, ok := stateRngs[state]
stateRngsMu.Unlock()
if !ok {
state.PushString("random_bytes: RNG not initialized")
return 1
}
for i := range bytes {
bytes[i] = byte(stateRng.Uint64() & 0xFF)
}
}
// Output format
outputFormat := "binary"
if state.GetTop() >= 3 && state.IsString(3) {
outputFormat = state.ToString(3)
}
switch outputFormat {
case "binary":
state.PushString(string(bytes))
case "hex":
state.PushString(hex.EncodeToString(bytes))
default:
state.PushString(string(bytes))
}
return 1
}
// cryptoRandomInt generates a random integer in the inclusive range [min, max].
//
// Lua arguments: (min number, max number [, secure boolean]).
//   - min and max are truncated toward zero and must fit in int64. min == max
//     is allowed and yields that value.
//   - secure defaults to true (crypto/rand); false draws from the state's
//     seedable PCG.
//
// The result is pushed as a Lua number; magnitudes beyond 2^53 lose precision
// in that conversion. Failures are reported by pushing a message string in
// place of the result.
func cryptoRandomInt(state *luajit.State) int {
	if !state.IsNumber(1) || !state.IsNumber(2) {
		state.PushString("random_int: expected (number min, number max)")
		return 1
	}
	lo, loOK := cryptoNumberToInt64(state.ToNumber(1))
	hi, hiOK := cryptoNumberToInt64(state.ToNumber(2))
	if !loOK || !hiOK {
		state.PushString("random_int: min and max must be finite and fit in int64")
		return 1
	}

	secure := true
	if state.GetTop() >= 3 && state.IsBoolean(3) {
		secure = state.ToBoolean(3)
	}
	return cryptoPushRandomInt(state, "random_int", lo, hi, secure)
}

// cryptoRandom implements math.random:
//   - math.random()     -> float in [0, 1)
//   - math.random(n)    -> integer in [1, n]
//   - math.random(m, n) -> integer in [m, n]; arguments past the second are ignored
//
// Every form draws from the state's PCG (never crypto/rand), so
// math.randomseed keeps sequences reproducible. Failures are reported by
// pushing a message string in place of the result.
func cryptoRandom(state *luajit.State) int {
	const secure = false

	numArgs := state.GetTop()
	switch {
	case numArgs == 0:
		rng, ok := cryptoLookupRNG(state)
		if !ok {
			state.PushString("random: RNG not initialized")
			return 1
		}
		// Use only the top 53 bits. Every such value is exactly representable,
		// and the largest is strictly below 1. Dividing the full 64-bit value
		// by MaxUint64 can round to exactly 1.0.
		state.PushNumber(float64(rng.Uint64()>>11) / (1 << 53))
		return 1

	case numArgs == 1 && state.IsNumber(1):
		n, ok := cryptoNumberToInt64(state.ToNumber(1))
		if !ok {
			state.PushString("random: bounds must be finite and fit in int64")
			return 1
		}
		if n < 1 {
			state.PushString("random: upper bound must be >= 1")
			return 1
		}
		// Bounds are passed as Go values. Pushing them onto the Lua stack
		// would leave the caller's own argument in slot 1 and invert the range.
		return cryptoPushRandomInt(state, "random", 1, n, secure)

	case numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2):
		m, mOK := cryptoNumberToInt64(state.ToNumber(1))
		n, nOK := cryptoNumberToInt64(state.ToNumber(2))
		if !mOK || !nOK {
			state.PushString("random: bounds must be finite and fit in int64")
			return 1
		}
		return cryptoPushRandomInt(state, "random", m, n, secure)
	}

	state.PushString("random: invalid arguments")
	return 1
}

// cryptoPushRandomInt pushes a uniformly distributed integer from [lo, hi],
// or an error string prefixed with prefix, and returns 1 (values pushed).
func cryptoPushRandomInt(state *luajit.State, prefix string, lo, hi int64, secure bool) int {
	if hi < lo {
		state.PushString(prefix + ": max must be >= min")
		return 1
	}

	var next func() (uint64, error)
	if secure {
		next = func() (uint64, error) {
			var buf [8]byte
			if _, err := rand.Read(buf[:]); err != nil {
				return 0, err
			}
			return binary.BigEndian.Uint64(buf[:]), nil
		}
	} else {
		rng, ok := cryptoLookupRNG(state)
		if !ok {
			state.PushString(prefix + ": RNG not initialized")
			return 1
		}
		next = func() (uint64, error) { return rng.Uint64(), nil }
	}

	// hi >= lo, so the unsigned difference is the exact span even when hi-lo
	// would overflow int64.
	offset, err := cryptoUniformUint64(next, uint64(hi)-uint64(lo))
	if err != nil {
		state.PushString(fmt.Sprintf("%s: error: %v", prefix, err))
		return 1
	}
	// Wrapping unsigned addition is exact: the true result lies in [lo, hi].
	state.PushNumber(float64(int64(uint64(lo) + offset)))
	return 1
}

// cryptoLookupRNG returns the non-secure generator registered for state.
func cryptoLookupRNG(state *luajit.State) (*mrand.PCG, bool) {
	stateRngsMu.Lock()
	defer stateRngsMu.Unlock()
	rng, ok := stateRngs[state]
	return rng, ok
}

// cryptoUniformUint64 draws a value uniformly from [0, span] using rejection
// sampling. A plain `v % (span+1)` favors small residues whenever span+1 does
// not divide 2^64.
func cryptoUniformUint64(next func() (uint64, error), span uint64) (uint64, error) {
	if span == math.MaxUint64 {
		// The full 64-bit range: every raw value is already uniform.
		return next()
	}
	n := span + 1
	// threshold = 2^64 mod n. Raw values below it form the one incomplete
	// bucket; the remaining range [threshold, 2^64) is a whole multiple of n.
	threshold := -n % n
	for {
		v, err := next()
		if err != nil {
			return 0, err
		}
		if v >= threshold {
			return v % n, nil
		}
	}
}

// cryptoNumberToInt64 truncates a Lua number toward zero. It reports false for
// NaN, ±Inf and values outside int64, whose conversion is
// implementation-defined in Go.
func cryptoNumberToInt64(f float64) (int64, bool) {
	if math.IsNaN(f) || f < math.MinInt64 || f >= math.MaxInt64 {
		return 0, false
	}
	return int64(f), true
}
// cryptoRandomSeed replaces the state's non-secure PCG with one derived from
// the given seed. It backs math.randomseed and __crypto_random_seed.
//
// Lua arguments: (seed number). On success it pushes nothing (returns 0). A
// non-number argument pushes an error string instead.
func cryptoRandomSeed(state *luajit.State) int {
	if !state.IsNumber(1) {
		state.PushString("randomseed: expected (number seed)")
		return 1
	}
	seed := cryptoSeedFromNumber(state.ToNumber(1))

	stateRngsMu.Lock()
	stateRngs[state] = mrand.NewPCG(seed, seed>>32)
	stateRngsMu.Unlock()
	return 0
}

// cryptoSeedFromNumber maps a Lua number onto a 64-bit seed portably.
// Converting a negative or non-finite float64 directly to uint64 is
// implementation-defined in Go, so the same Lua seed could produce different
// sequences on different platforms. Values in [0, 2^64) truncate exactly as a
// direct conversion would.
func cryptoSeedFromNumber(f float64) uint64 {
	switch {
	case f >= 0 && f < 1<<64:
		return uint64(f)
	case f < 0 && f >= math.MinInt64:
		// Two's-complement reinterpretation, e.g. -1 seeds like MaxUint64.
		return uint64(int64(f))
	default:
		// NaN, ±Inf and magnitudes beyond 64 bits: use the IEEE-754 bit
		// pattern, which is well-defined on every platform.
		return math.Float64bits(f)
	}
}
// OverrideLuaRandom swaps Lua's math.random and math.randomseed for the Go
// implementations cryptoRandom and cryptoRandomSeed.
//
// The Go functions are first exposed under the temporary globals
// go_math_random and go_math_randomseed. A short Lua chunk then stashes the
// originals in _G._original_math_random / _G._original_math_randomseed,
// installs the replacements and clears the temporaries. The first error from
// registration or from running the chunk is returned unchanged.
func OverrideLuaRandom(state *luajit.State) error {
	const swapScript = `
-- Save original functions
_G._original_math_random = math.random
_G._original_math_randomseed = math.randomseed
-- Replace with Go implementations
math.random = go_math_random
math.randomseed = go_math_randomseed
-- Clean up global namespace
go_math_random = nil
go_math_randomseed = nil
`

	err := state.RegisterGoFunction("go_math_random", cryptoRandom)
	if err == nil {
		err = state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed)
	}
	if err != nil {
		return err
	}
	return state.DoString(swapScript)
}