package runner
import (
"encoding/base64"
"html"
"strings"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// RegisterUtilFunctions binds the Go string helpers into the given Lua state
// under the names __html_special_chars, __html_entities, __base64_encode and
// __base64_decode. Bindings are registered in that order; the first error
// reported by RegisterGoFunction aborts registration and is returned as-is.
func RegisterUtilFunctions(state *luajit.State) error {
	bindings := []struct {
		name string
		fn   func(*luajit.State) int
	}{
		{name: "__html_special_chars", fn: htmlSpecialChars},
		{name: "__html_entities", fn: htmlEntities},
		{name: "__base64_encode", fn: base64Encode},
		{name: "__base64_decode", fn: base64Decode},
	}

	for _, b := range bindings {
		if err := state.RegisterGoFunction(b.name, b.fn); err != nil {
			return err
		}
	}
	return nil
}
// htmlSpecialChars is the Lua binding registered as __html_special_chars.
// It requires exactly one argument, escapes it with html.EscapeString
// (which rewrites <, >, &, ' and "), and pushes the escaped string.
// If the arity check or the string conversion fails, nil is pushed instead.
// It always returns 1, the number of values left on the Lua stack.
func htmlSpecialChars(state *luajit.State) int {
	if argErr := state.CheckExactArgs(1); argErr == nil {
		if text, convErr := state.SafeToString(1); convErr == nil {
			state.PushString(html.EscapeString(text))
			return 1
		}
	}

	// Wrong arity or an argument that could not be read as a string.
	state.PushNil()
	return 1
}
// htmlEntityReplacer rewrites selected non-ASCII typographic, currency and
// fraction characters as their named HTML entities. It is built once at
// package init rather than on every call. Because a strings.Replacer performs
// all substitutions in one pass, the result does not depend on any iteration
// order.
var htmlEntityReplacer = strings.NewReplacer(
	"©", "&copy;",
	"®", "&reg;",
	"™", "&trade;",
	"€", "&euro;",
	"£", "&pound;",
	"¥", "&yen;",
	"—", "&mdash;",
	"–", "&ndash;",
	"…", "&hellip;",
	"•", "&bull;",
	"°", "&deg;",
	"±", "&plusmn;",
	"¼", "&frac14;",
	"½", "&frac12;",
	"¾", "&frac34;",
)

// htmlEntities is the Lua binding registered as __html_entities. It is a more
// comprehensive version of htmlSpecialChars: it requires exactly one
// argument, first escapes it with html.EscapeString (<, >, &, ' and "), and
// then converts the characters listed in htmlEntityReplacer to named entities.
// If the arity check or the string conversion fails, nil is pushed instead.
// It always returns 1, the number of values left on the Lua stack.
func htmlEntities(state *luajit.State) int {
	if err := state.CheckExactArgs(1); err != nil {
		state.PushNil()
		return 1
	}
	input, err := state.SafeToString(1)
	if err != nil {
		state.PushNil()
		return 1
	}

	// Escape first so the '&' introduced by the named entities below is not
	// itself re-escaped to "&amp;".
	result := htmlEntityReplacer.Replace(html.EscapeString(input))
	state.PushString(result)
	return 1
}
// base64Encode is the Lua binding registered as __base64_encode. It requires
// exactly one argument and pushes its standard (padded) base64 encoding, as
// produced by base64.StdEncoding. If the arity check or the string
// conversion fails, nil is pushed instead. It always returns 1, the number
// of values left on the Lua stack.
func base64Encode(state *luajit.State) int {
	if state.CheckExactArgs(1) != nil {
		state.PushNil()
		return 1
	}

	raw, convErr := state.SafeToString(1)
	switch {
	case convErr != nil:
		state.PushNil()
	default:
		state.PushString(base64.StdEncoding.EncodeToString([]byte(raw)))
	}
	return 1
}
// base64Decode is the Lua binding registered as __base64_decode. It requires
// exactly one argument, decodes it with base64.StdEncoding (standard
// alphabet, padding required), and pushes the decoded bytes as a Lua string.
// It pushes nil if any of these fail: the arity check, the string
// conversion, or the decode itself. It always returns 1, the number of
// values left on the Lua stack.
func base64Decode(state *luajit.State) int {
	if argErr := state.CheckExactArgs(1); argErr == nil {
		if encoded, convErr := state.SafeToString(1); convErr == nil {
			if decoded, decErr := base64.StdEncoding.DecodeString(encoded); decErr == nil {
				state.PushString(string(decoded))
				return 1
			}
		}
	}

	// Any failure along the way is reported to Lua as a single nil.
	state.PushNil()
	return 1
}