go functions first pass

This commit is contained in:
Sky Johnson 2025-07-14 21:34:02 -05:00
parent e5388c4c23
commit da602278c5
12 changed files with 921 additions and 7 deletions

80
functions/crypto.go Normal file
View File

@ -0,0 +1,80 @@
// functions/crypto.go
package functions
import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// GetCryptoFunctions returns all cryptographic Go functions
func GetCryptoFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"base64_encode": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("base64_encode: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("base64_encode: argument must be a string")
}
encoded := base64.StdEncoding.EncodeToString([]byte(str))
s.PushString(encoded)
return 1
},
"base64_decode": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("base64_decode: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("base64_decode: argument must be a string")
}
decoded, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return s.PushError("base64_decode: %v", err)
}
s.PushString(string(decoded))
return 1
},
"md5_hash": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("md5_hash: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("md5_hash: argument must be a string")
}
hash := md5.Sum([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
},
"sha256_hash": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("sha256_hash: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("sha256_hash: argument must be a string")
}
hash := sha256.Sum256([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
},
}
}

67
functions/fs.go Normal file
View File

@ -0,0 +1,67 @@
package functions
import (
"os"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// GetFSFunctions returns all file system Go functions
func GetFSFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"file_exists": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("file_exists: %v", err)
}
path, err := s.SafeToString(1)
if err != nil {
return s.PushError("file_exists: argument must be a string")
}
_, err = os.Stat(path)
s.PushBoolean(err == nil)
return 1
},
"file_size": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("file_size: %v", err)
}
path, err := s.SafeToString(1)
if err != nil {
return s.PushError("file_size: argument must be a string")
}
info, err := os.Stat(path)
if err != nil {
s.PushNumber(-1)
return 1
}
s.PushNumber(float64(info.Size()))
return 1
},
"file_is_dir": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("file_is_dir: %v", err)
}
path, err := s.SafeToString(1)
if err != nil {
return s.PushError("file_is_dir: argument must be a string")
}
info, err := os.Stat(path)
if err != nil {
s.PushBoolean(false)
return 1
}
s.PushBoolean(info.IsDir())
return 1
},
}
}

55
functions/json.go Normal file
View File

@ -0,0 +1,55 @@
package functions
import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/goccy/go-json"
)
// GetJSONFunctions returns all JSON-related Go functions
func GetJSONFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"json_encode": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("json_encode: %v", err)
}
value, err := s.ToValue(1)
if err != nil {
return s.PushError("json_encode: failed to read value: %v", err)
}
data, err := json.Marshal(value)
if err != nil {
return s.PushError("json_encode: %v", err)
}
s.PushString(string(data))
return 1
},
"json_decode": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("json_decode: %v", err)
}
jsonStr, err := s.SafeToString(1)
if err != nil {
return s.PushError("json_decode: input must be a string")
}
var result interface{}
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
// Return nil and error string instead of PushError for JSON parsing errors
s.PushNil()
s.PushString(err.Error())
return 2
}
if err := s.PushValue(result); err != nil {
return s.PushError("json_decode: failed to push result: %v", err)
}
return 1
},
}
}

85
functions/math.go Normal file
View File

@ -0,0 +1,85 @@
package functions
import luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
// GetMathFunctions returns all math-related Go functions
func GetMathFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"math_factorial": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("math_factorial: %v", err)
}
n, err := s.SafeToNumber(1)
if err != nil || n < 0 || n != float64(int(n)) {
return s.PushError("math_factorial: argument must be a non-negative integer")
}
if n > 170 {
return s.PushError("math_factorial: argument too large (max 170)")
}
result := 1.0
for i := 2; i <= int(n); i++ {
result *= float64(i)
}
s.PushNumber(result)
return 1
},
"math_gcd": func(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("math_gcd: %v", err)
}
a, err := s.SafeToNumber(1)
if err != nil || a != float64(int(a)) {
return s.PushError("math_gcd: first argument must be an integer")
}
b, err := s.SafeToNumber(2)
if err != nil || b != float64(int(b)) {
return s.PushError("math_gcd: second argument must be an integer")
}
ia, ib := int(a), int(b)
for ib != 0 {
ia, ib = ib, ia%ib
}
s.PushNumber(float64(ia))
return 1
},
"math_lcm": func(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("math_lcm: %v", err)
}
a, err := s.SafeToNumber(1)
if err != nil || a != float64(int(a)) {
return s.PushError("math_lcm: first argument must be an integer")
}
b, err := s.SafeToNumber(2)
if err != nil || b != float64(int(b)) {
return s.PushError("math_lcm: second argument must be an integer")
}
ia, ib := int(a), int(b)
// Calculate GCD
gcd := func(x, y int) int {
for y != 0 {
x, y = y, x%y
}
return x
}
result := ia * ib / gcd(ia, ib)
s.PushNumber(float64(result))
return 1
},
}
}

34
functions/registry.go Normal file
View File

@ -0,0 +1,34 @@
package functions
import luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
// Registry holds all available Go functions for Lua modules
type Registry map[string]luajit.GoFunction
// GetAll returns all registered functions
func GetAll() Registry {
registry := make(Registry)
// Register function groups
for name, fn := range GetJSONFunctions() {
registry[name] = fn
}
for name, fn := range GetStringFunctions() {
registry[name] = fn
}
for name, fn := range GetMathFunctions() {
registry[name] = fn
}
for name, fn := range GetFSFunctions() {
registry[name] = fn
}
for name, fn := range GetCryptoFunctions() {
registry[name] = fn
}
return registry
}

155
functions/string.go Normal file
View File

@ -0,0 +1,155 @@
// functions/string.go
package functions
import (
"fmt"
"strings"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// GetStringFunctions returns all string manipulation Go functions
func GetStringFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"string_split": func(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("string_split: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_split: first argument must be a string")
}
sep, err := s.SafeToString(2)
if err != nil {
return s.PushError("string_split: second argument must be a string")
}
parts := strings.Split(str, sep)
if err := s.PushValue(parts); err != nil {
return s.PushError("string_split: failed to push result: %v", err)
}
return 1
},
"string_join": func(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("string_join: %v", err)
}
arr, err := s.SafeToTable(1)
if err != nil {
return s.PushError("string_join: first argument must be a table")
}
sep, err := s.SafeToString(2)
if err != nil {
return s.PushError("string_join: second argument must be a string")
}
var parts []string
if slice, ok := arr.([]string); ok {
parts = slice
} else if anySlice, ok := arr.([]interface{}); ok {
parts = make([]string, len(anySlice))
for i, v := range anySlice {
parts[i] = fmt.Sprintf("%v", v)
}
} else {
return s.PushError("string_join: first argument must be an array")
}
result := strings.Join(parts, sep)
s.PushString(result)
return 1
},
"string_trim": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("string_trim: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_trim: argument must be a string")
}
s.PushString(strings.TrimSpace(str))
return 1
},
"string_upper": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("string_upper: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_upper: argument must be a string")
}
s.PushString(strings.ToUpper(str))
return 1
},
"string_lower": func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("string_lower: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_lower: argument must be a string")
}
s.PushString(strings.ToLower(str))
return 1
},
"string_contains": func(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("string_contains: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_contains: first argument must be a string")
}
substr, err := s.SafeToString(2)
if err != nil {
return s.PushError("string_contains: second argument must be a string")
}
s.PushBoolean(strings.Contains(str, substr))
return 1
},
"string_replace": func(s *luajit.State) int {
if err := s.CheckExactArgs(3); err != nil {
return s.PushError("string_replace: %v", err)
}
str, err := s.SafeToString(1)
if err != nil {
return s.PushError("string_replace: first argument must be a string")
}
old, err := s.SafeToString(2)
if err != nil {
return s.PushError("string_replace: second argument must be a string")
}
new, err := s.SafeToString(3)
if err != nil {
return s.PushError("string_replace: third argument must be a string")
}
result := strings.ReplaceAll(str, old, new)
s.PushString(result)
return 1
},
}
}

1
go.mod
View File

@ -3,3 +3,4 @@ module Moonshark
go 1.24.1
require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6
require github.com/goccy/go-json v0.10.5

2
go.sum
View File

@ -1,2 +1,4 @@
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=

View File

@ -6,22 +6,28 @@ import (
"path/filepath"
"strings"
"Moonshark/functions"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
//go:embed modules/*.lua
var builtinModules embed.FS
// ModuleRegistry manages built-in modules
// ModuleRegistry manages built-in modules and Go functions
type ModuleRegistry struct {
modules map[string]string
goFuncs map[string]luajit.GoFunction
}
// NewModuleRegistry creates a new module registry
func NewModuleRegistry() *ModuleRegistry {
return &ModuleRegistry{
mr := &ModuleRegistry{
modules: make(map[string]string),
goFuncs: functions.GetAll(),
}
return mr
}
// RegisterModule adds a module by name and source code
@ -29,12 +35,16 @@ func (mr *ModuleRegistry) RegisterModule(name, source string) {
mr.modules[name] = source
}
// RegisterGoFunction adds a Go function that modules can use
func (mr *ModuleRegistry) RegisterGoFunction(name string, fn luajit.GoFunction) {
mr.goFuncs[name] = fn
}
// LoadEmbeddedModules loads all modules from the embedded filesystem
func (mr *ModuleRegistry) LoadEmbeddedModules() error {
entries, err := builtinModules.ReadDir("modules")
if err != nil {
fmt.Printf("Failed to read modules directory: %v\n", err)
return err
return fmt.Errorf("failed to read modules directory: %w", err)
}
for _, entry := range entries {
@ -60,6 +70,11 @@ func (mr *ModuleRegistry) InstallModules(state *luajit.State) error {
state.NewTable()
state.SetGlobal("moonshark")
// Install Go functions first
if err := mr.installGoFunctions(state); err != nil {
return fmt.Errorf("failed to install Go functions: %w", err)
}
// Register require function that checks our built-in modules first
err := state.RegisterGoFunction("require", func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
@ -101,6 +116,22 @@ func (mr *ModuleRegistry) InstallModules(state *luajit.State) error {
return err
}
// installGoFunctions installs all registered Go functions into the Lua state
func (mr *ModuleRegistry) installGoFunctions(state *luajit.State) error {
// Install functions in moonshark namespace
state.GetGlobal("moonshark")
for name, fn := range mr.goFuncs {
if err := state.PushGoFunction(fn); err != nil {
return fmt.Errorf("failed to register Go function '%s': %w", name, err)
}
state.SetField(-2, name)
}
state.Pop(1) // Remove moonshark table
return nil
}
// BackupOriginalRequire saves the original require function
func BackupOriginalRequire(state *luajit.State) {
state.GetGlobal("require")

171
modules/json.lua Normal file
View File

@ -0,0 +1,171 @@
-- modules/json.lua - High-performance JSON module using Go functions
local json = {}
-- Use the fast Go JSON encoder/decoder
function json.encode(value)
return moonshark.json_encode(value)
end
function json.decode(str)
local result, err = moonshark.json_decode(str)
if result == nil and err then
error("json_decode: " .. err)
end
return result
end
-- Pretty print JSON with indentation
function json.pretty(value, indent)
indent = indent or 2
local encoded = json.encode(value)
local result = {}
local depth = 0
local in_string = false
local escape_next = false
for i = 1, #encoded do
local char = encoded:sub(i, i)
if escape_next then
table.insert(result, char)
escape_next = false
elseif char == "\\" and in_string then
table.insert(result, char)
escape_next = true
elseif char == '"' then
table.insert(result, char)
in_string = not in_string
elseif not in_string then
if char == "{" or char == "[" then
table.insert(result, char)
depth = depth + 1
table.insert(result, "\n" .. string.rep(" ", depth * indent))
elseif char == "}" or char == "]" then
depth = depth - 1
table.insert(result, "\n" .. string.rep(" ", depth * indent))
table.insert(result, char)
elseif char == "," then
table.insert(result, char)
table.insert(result, "\n" .. string.rep(" ", depth * indent))
elseif char == ":" then
table.insert(result, char .. " ")
else
table.insert(result, char)
end
else
table.insert(result, char)
end
end
return table.concat(result)
end
-- Load JSON from file
function json.load_file(filename)
if not moonshark.file_exists(filename) then
error("File not found: " .. filename)
end
local file = io.open(filename, "r")
if not file then
error("Cannot open file: " .. filename)
end
local content = file:read("*all")
file:close()
return json.decode(content)
end
-- Save data to JSON file
function json.save_file(filename, data, pretty)
local content
if pretty then
content = json.pretty(data)
else
content = json.encode(data)
end
local file = io.open(filename, "w")
if not file then
error("Cannot write to file: " .. filename)
end
file:write(content)
file:close()
end
-- Merge JSON objects
function json.merge(...)
local result = {}
for i = 1, select("#", ...) do
local obj = select(i, ...)
if type(obj) == "table" then
for k, v in pairs(obj) do
result[k] = v
end
end
end
return result
end
-- Extract values by JSONPath-like syntax (simplified)
function json.extract(data, path)
local parts = moonshark.string_split(path, ".")
local current = data
for _, part in ipairs(parts) do
if type(current) ~= "table" then
return nil
end
-- Handle array indices [0], [1], etc.
local array_match = part:match("^%[(%d+)%]$")
if array_match then
local index = tonumber(array_match) + 1 -- Lua is 1-indexed
current = current[index]
else
current = current[part]
end
if current == nil then
return nil
end
end
return current
end
-- Validate JSON structure against schema (basic)
function json.validate(data, schema)
local function validate_value(value, schema_value)
local value_type = type(value)
local schema_type = schema_value.type
if schema_type and value_type ~= schema_type then
return false, "Expected " .. schema_type .. ", got " .. value_type
end
if schema_type == "table" and schema_value.properties then
for prop, prop_schema in pairs(schema_value.properties) do
if schema_value.required and schema_value.required[prop] and value[prop] == nil then
return false, "Missing required property: " .. prop
end
if value[prop] ~= nil then
local valid, err = validate_value(value[prop], prop_schema)
if not valid then
return false, "Property " .. prop .. ": " .. err
end
end
end
end
return true
end
return validate_value(data, schema)
end
return json

220
tests/json.lua Normal file
View File

@ -0,0 +1,220 @@
#!/usr/bin/env moonshark
-- Test script for JSON module functionality
local json = require("json")
local passed = 0
local total = 0
local function test(name, fn)
print("Testing " .. name .. "...")
total = total + 1
local ok, err = pcall(fn)
if ok then
passed = passed + 1
print(" ✓ PASS")
return true
else
print(" ✗ FAIL: " .. err)
return false
end
end
-- Test data
local test_data = {
name = "John Doe",
age = 30,
active = true,
scores = {85, 92, 78, 90},
address = {
street = "123 Main St",
city = "Springfield",
zip = "12345"
},
tags = {"developer", "golang", "lua"}
}
-- Test 1: Basic encoding
test("Basic JSON Encoding", function()
local encoded = json.encode(test_data)
assert(type(encoded) == "string", "encode should return string")
assert(string.find(encoded, "John Doe"), "should contain name")
assert(string.find(encoded, "30"), "should contain age")
end)
-- Test 2: Basic decoding
test("Basic JSON Decoding", function()
local encoded = json.encode(test_data)
local decoded = json.decode(encoded)
assert(decoded.name == "John Doe", "name should match")
assert(decoded.age == 30, "age should match")
assert(decoded.active == true, "active should be true")
assert(#decoded.scores == 4, "scores array length should be 4")
end)
-- Test 3: Round-trip encoding/decoding
test("Round-trip Encoding/Decoding", function()
local encoded = json.encode(test_data)
local decoded = json.decode(encoded)
local re_encoded = json.encode(decoded)
local re_decoded = json.decode(re_encoded)
assert(re_decoded.name == test_data.name, "name should survive round-trip")
assert(re_decoded.address.city == test_data.address.city, "nested data should survive")
end)
-- Test 4: Pretty printing
test("Pretty Printing", function()
local pretty = json.pretty(test_data)
assert(type(pretty) == "string", "pretty should return string")
assert(string.find(pretty, "\n"), "pretty should contain newlines")
assert(string.find(pretty, " "), "pretty should contain indentation")
-- Should still be valid JSON
local decoded = json.decode(pretty)
assert(decoded.name == test_data.name, "pretty JSON should still decode correctly")
end)
-- Test 5: Object merging
test("Object Merging", function()
local obj1 = {a = 1, b = 2}
local obj2 = {b = 3, c = 4}
local obj3 = {d = 5}
local merged = json.merge(obj1, obj2, obj3)
assert(merged.a == 1, "should preserve a from obj1")
assert(merged.b == 3, "should use b from obj2 (later wins)")
assert(merged.c == 4, "should include c from obj2")
assert(merged.d == 5, "should include d from obj3")
end)
-- Test 6: Data extraction
test("Data Extraction", function()
local name = json.extract(test_data, "name")
assert(name == "John Doe", "should extract name")
local city = json.extract(test_data, "address.city")
assert(city == "Springfield", "should extract nested city")
local first_score = json.extract(test_data, "scores.[0]")
assert(first_score == 85, "should extract array element")
local missing = json.extract(test_data, "nonexistent.field")
assert(missing == nil, "should return nil for missing path")
end)
-- Test 7: Schema validation
test("Schema Validation", function()
local schema = {
type = "table",
properties = {
name = {type = "string"},
age = {type = "number"},
active = {type = "boolean"}
},
required = {name = true, age = true}
}
local valid, err = json.validate(test_data, schema)
assert(valid == true, "test_data should be valid")
local invalid_data = {name = "John", age = "not_a_number"}
local invalid, err2 = json.validate(invalid_data, schema)
assert(invalid == false, "invalid_data should fail validation")
assert(type(err2) == "string", "should return error message")
end)
-- Test 8: File operations
test("File Save/Load", function()
local filename = "test_output.json"
-- Save to file
json.save_file(filename, test_data, true) -- pretty format
-- Check file exists
assert(moonshark.file_exists(filename), "file should exist after save")
-- Load from file
local loaded = json.load_file(filename)
assert(loaded.name == test_data.name, "loaded data should match original")
assert(loaded.address.zip == test_data.address.zip, "nested data should match")
-- Clean up
os.remove(filename)
end)
-- Test 9: Error handling
test("Error Handling", function()
-- Invalid JSON should throw error
local success, err = pcall(json.decode, '{"invalid": json}')
assert(success == false, "invalid JSON should cause error")
assert(type(err) == "string", "should return error message")
-- Missing file should throw error
local success2, err2 = pcall(json.load_file, "nonexistent_file.json")
assert(success2 == false, "missing file should cause error")
assert(type(err2) == "string", "should return error message")
end)
-- Test 10: Edge cases
test("Edge Cases", function()
-- Empty objects
local empty_obj = {}
local encoded_empty = json.encode(empty_obj)
local decoded_empty = json.decode(encoded_empty)
assert(type(decoded_empty) == "table", "empty object should decode to table")
-- Null values
local with_nil = {a = 1, b = nil, c = 3}
local encoded_nil = json.encode(with_nil)
local decoded_nil = json.decode(encoded_nil)
-- Note: nil values are typically omitted in JSON
-- Special numbers
local special = {
zero = 0,
negative = -42,
decimal = 3.14159
}
local encoded_special = json.encode(special)
local decoded_special = json.decode(encoded_special)
assert(decoded_special.zero == 0, "zero should encode/decode correctly")
assert(decoded_special.negative == -42, "negative should encode/decode correctly")
assert(math.abs(decoded_special.decimal - 3.14159) < 0.00001, "decimal should encode/decode correctly")
end)
-- Performance test
test("Performance Test", function()
local large_data = {}
for i = 1, 1000 do
large_data[i] = {
id = i,
name = "User " .. i,
data = {x = i * 2, y = i * 3, z = i * 4}
}
end
local start = os.clock()
local encoded = json.encode(large_data)
local encode_time = os.clock() - start
start = os.clock()
local decoded = json.decode(encoded)
local decode_time = os.clock() - start
print(string.format(" Encoded 1000 objects in %.3f seconds", encode_time))
print(string.format(" Decoded 1000 objects in %.3f seconds", decode_time))
assert(#decoded == 1000, "should have 1000 objects after decode")
assert(decoded[500].name == "User 500", "data should be intact")
end)
print("=" .. string.rep("=", 50))
print(string.format("Test Results: %d/%d passed", passed, total))
if passed == total then
print("🎉 All tests passed!")
os.exit(0)
else
print("❌ Some tests failed!")
os.exit(1)
end

View File

@ -13,10 +13,15 @@ local function assert_equal(a, b)
end
end
local passed = 0
local total = 0
local function test(name, fn)
print("Testing " .. name .. "...")
total = total + 1
local ok, err = pcall(fn)
if ok then
passed = passed + 1
print(" ✓ PASS")
else
print(" ✗ FAIL: " .. err)
@ -70,8 +75,8 @@ test("Statistics", function()
assert_equal(math.sum(data), 15)
assert_equal(math.mean(data), 3)
assert_equal(math.median(data), 3)
assert_close(math.variance(data), 2.5)
assert_close(math.stdev(data), math.sqrt(2.5))
assert_close(math.variance(data), 2)
assert_close(math.stdev(data), math.sqrt(2))
local min, max = math.minmax(data)
assert_equal(min, 1)
assert_equal(max, 5)
@ -189,4 +194,12 @@ test("Interpolation", function()
assert_close(math.interpolation.catmull_rom(0.5, 0, 1, 2, 3), 1.5)
end)
print("\nAll tests completed!")
print("=" .. string.rep("=", 50))
print(string.format("Test Results: %d/%d passed", passed, total))
if passed == total then
print("🎉 All tests passed!")
os.exit(0)
else
print("❌ Some tests failed!")
os.exit(1)
end