193 lines
4.1 KiB
Go
193 lines
4.1 KiB
Go
package lualibs
|
||
|
||
import (
|
||
"encoding/base64"
|
||
"html"
|
||
"strings"
|
||
|
||
"github.com/goccy/go-json"
|
||
|
||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||
)
|
||
|
||
// RegisterUtilFunctions registers utility functions with the Lua state
|
||
func RegisterUtilFunctions(state *luajit.State) error {
|
||
if err := state.RegisterGoFunction("__json_marshal", jsonMarshal); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := state.RegisterGoFunction("__json_unmarshal", jsonUnmarshal); err != nil {
|
||
return err
|
||
}
|
||
|
||
// HTML special chars
|
||
if err := state.RegisterGoFunction("__html_special_chars", htmlSpecialChars); err != nil {
|
||
return err
|
||
}
|
||
|
||
// HTML entities
|
||
if err := state.RegisterGoFunction("__html_entities", htmlEntities); err != nil {
|
||
return err
|
||
}
|
||
|
||
// Base64 encode
|
||
if err := state.RegisterGoFunction("__base64_encode", base64Encode); err != nil {
|
||
return err
|
||
}
|
||
|
||
// Base64 decode
|
||
if err := state.RegisterGoFunction("__base64_decode", base64Decode); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// htmlSpecialChars converts special characters to HTML entities
|
||
func htmlSpecialChars(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
input, err := state.SafeToString(1)
|
||
if err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
result := html.EscapeString(input)
|
||
state.PushString(result)
|
||
return 1
|
||
}
|
||
|
||
// htmlEntities is a more comprehensive version of htmlSpecialChars
|
||
func htmlEntities(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
input, err := state.SafeToString(1)
|
||
if err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
// First use HTML escape for standard entities
|
||
result := html.EscapeString(input)
|
||
|
||
// Additional entities beyond what html.EscapeString handles
|
||
replacements := map[string]string{
|
||
"©": "©",
|
||
"®": "®",
|
||
"™": "™",
|
||
"€": "€",
|
||
"£": "£",
|
||
"¥": "¥",
|
||
"—": "—",
|
||
"–": "–",
|
||
"…": "…",
|
||
"•": "•",
|
||
"°": "°",
|
||
"±": "±",
|
||
"¼": "¼",
|
||
"½": "½",
|
||
"¾": "¾",
|
||
}
|
||
|
||
for char, entity := range replacements {
|
||
result = strings.ReplaceAll(result, char, entity)
|
||
}
|
||
|
||
state.PushString(result)
|
||
return 1
|
||
}
|
||
|
||
// base64Encode encodes a string to base64
|
||
func base64Encode(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
input, err := state.SafeToString(1)
|
||
if err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
result := base64.StdEncoding.EncodeToString([]byte(input))
|
||
state.PushString(result)
|
||
return 1
|
||
}
|
||
|
||
// base64Decode decodes a base64 string
|
||
func base64Decode(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
input, err := state.SafeToString(1)
|
||
if err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
result, err := base64.StdEncoding.DecodeString(input)
|
||
if err != nil {
|
||
state.PushNil()
|
||
return 1
|
||
}
|
||
|
||
state.PushString(string(result))
|
||
return 1
|
||
}
|
||
|
||
// jsonMarshal converts a Lua value to a JSON string with validation
|
||
func jsonMarshal(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
return state.PushError("json marshal: %v", err)
|
||
}
|
||
|
||
value, err := state.ToTable(1)
|
||
if err != nil {
|
||
// Try as generic value if not a table
|
||
value, err = state.ToValue(1)
|
||
if err != nil {
|
||
return state.PushError("json marshal error: %v", err)
|
||
}
|
||
}
|
||
|
||
bytes, err := json.Marshal(value)
|
||
if err != nil {
|
||
return state.PushError("json marshal error: %v", err)
|
||
}
|
||
|
||
state.PushString(string(bytes))
|
||
return 1
|
||
}
|
||
|
||
// jsonUnmarshal converts a JSON string to a Lua value with validation
|
||
func jsonUnmarshal(state *luajit.State) int {
|
||
if err := state.CheckExactArgs(1); err != nil {
|
||
return state.PushError("json unmarshal: %v", err)
|
||
}
|
||
|
||
jsonStr, err := state.SafeToString(1)
|
||
if err != nil {
|
||
return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
|
||
}
|
||
|
||
var value any
|
||
if err := json.Unmarshal([]byte(jsonStr), &value); err != nil {
|
||
return state.PushError("json unmarshal error: %v", err)
|
||
}
|
||
|
||
if err := state.PushValue(value); err != nil {
|
||
return state.PushError("json unmarshal error: %v", err)
|
||
}
|
||
return 1
|
||
}
|