From da602278c59b4deb613634c11b8892cc304a14a4 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 14 Jul 2025 21:34:02 -0500 Subject: [PATCH] go functions first pass --- functions/crypto.go | 80 +++++++++++++++ functions/fs.go | 67 +++++++++++++ functions/json.go | 55 +++++++++++ functions/math.go | 85 ++++++++++++++++ functions/registry.go | 34 +++++++ functions/string.go | 155 +++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + modules.go | 39 +++++++- modules/json.lua | 171 ++++++++++++++++++++++++++++++++ tests/json.lua | 220 ++++++++++++++++++++++++++++++++++++++++++ tests/math.lua | 19 +++- 12 files changed, 921 insertions(+), 7 deletions(-) create mode 100644 functions/crypto.go create mode 100644 functions/fs.go create mode 100644 functions/json.go create mode 100644 functions/math.go create mode 100644 functions/registry.go create mode 100644 functions/string.go create mode 100644 modules/json.lua create mode 100644 tests/json.lua diff --git a/functions/crypto.go b/functions/crypto.go new file mode 100644 index 0000000..8ae6698 --- /dev/null +++ b/functions/crypto.go @@ -0,0 +1,80 @@ +// functions/crypto.go +package functions + +import ( + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// GetCryptoFunctions returns all cryptographic Go functions +func GetCryptoFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "base64_encode": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("base64_encode: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("base64_encode: argument must be a string") + } + + encoded := base64.StdEncoding.EncodeToString([]byte(str)) + s.PushString(encoded) + return 1 + }, + + "base64_decode": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("base64_decode: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("base64_decode: argument must be a string") + } + + decoded, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return s.PushError("base64_decode: %v", err) + } + + s.PushString(string(decoded)) + return 1 + }, + + "md5_hash": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("md5_hash: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("md5_hash: argument must be a string") + } + + hash := md5.Sum([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 + }, + + "sha256_hash": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("sha256_hash: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("sha256_hash: argument must be a string") + } + + hash := sha256.Sum256([]byte(str)) + s.PushString(hex.EncodeToString(hash[:])) + return 1 + }, + } +} diff --git a/functions/fs.go b/functions/fs.go new file mode 100644 index 0000000..8227a4d --- /dev/null +++ b/functions/fs.go @@ -0,0 +1,67 @@ +package functions + +import ( + "os" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// GetFSFunctions returns all file system Go functions +func GetFSFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "file_exists": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("file_exists: %v", err) + } + + path, err := s.SafeToString(1) + if err != nil { + return s.PushError("file_exists: argument must be a string") + } + + _, err = os.Stat(path) + s.PushBoolean(err == nil) + return 1 + }, + + "file_size": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("file_size: %v", err) + } + + path, err := s.SafeToString(1) + if err != nil { + return s.PushError("file_size: argument must be a string") + } + + info, err := os.Stat(path) + if err != nil { + s.PushNumber(-1) + return 1 + } + + s.PushNumber(float64(info.Size())) + return 1 + }, + + "file_is_dir": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("file_is_dir: %v", err) + } + + path, err := s.SafeToString(1) + if err != nil { + return s.PushError("file_is_dir: argument must be a string") + } + + info, err := os.Stat(path) + if err != nil { + s.PushBoolean(false) + return 1 + } + + s.PushBoolean(info.IsDir()) + return 1 + }, + } +} diff --git a/functions/json.go b/functions/json.go new file mode 100644 index 0000000..c8aa163 --- /dev/null +++ b/functions/json.go @@ -0,0 +1,55 @@ +package functions + +import ( + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "github.com/goccy/go-json" +) + +// GetJSONFunctions returns all JSON-related Go functions +func GetJSONFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "json_encode": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("json_encode: %v", err) + } + + value, err := s.ToValue(1) + if err != nil { + return s.PushError("json_encode: failed to read value: %v", err) + } + + data, err := json.Marshal(value) + if err != nil { + return s.PushError("json_encode: %v", err) + } + + s.PushString(string(data)) + return 1 + }, + + "json_decode": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("json_decode: %v", err) + } + + jsonStr, err := s.SafeToString(1) + if err != nil { + return s.PushError("json_decode: input must be a string") + } + + var result interface{} + if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { + // Return nil and error string instead of PushError for JSON parsing errors + s.PushNil() + s.PushString(err.Error()) + return 2 + } + + if err := s.PushValue(result); err != nil { + return s.PushError("json_decode: failed to push result: %v", err) + } + + return 1 + }, + } +} diff --git a/functions/math.go b/functions/math.go new file mode 100644 index 0000000..fd9ffb3 --- /dev/null +++ b/functions/math.go @@ -0,0 +1,85 @@ +package functions + +import luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + +// GetMathFunctions returns all math-related Go functions +func GetMathFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "math_factorial": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("math_factorial: %v", err) + } + + n, err := s.SafeToNumber(1) + if err != nil || n < 0 || n != float64(int(n)) { + return s.PushError("math_factorial: argument must be a non-negative integer") + } + + if n > 170 { + return s.PushError("math_factorial: argument too large (max 170)") + } + + result := 1.0 + for i := 2; i <= int(n); i++ { + result *= float64(i) + } + + s.PushNumber(result) + return 1 + }, + + "math_gcd": func(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("math_gcd: %v", err) + } + + a, err := s.SafeToNumber(1) + if err != nil || a != float64(int(a)) { + return s.PushError("math_gcd: first argument must be an integer") + } + + b, err := s.SafeToNumber(2) + if err != nil || b != float64(int(b)) { + return s.PushError("math_gcd: second argument must be an integer") + } + + ia, ib := int(a), int(b) + for ib != 0 { + ia, ib = ib, ia%ib + } + + s.PushNumber(float64(ia)) + return 1 + }, + + "math_lcm": func(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("math_lcm: %v", err) + } + + a, err := s.SafeToNumber(1) + if err != nil || a != float64(int(a)) { + return s.PushError("math_lcm: first argument must be an integer") + } + + b, err := s.SafeToNumber(2) + if err != nil || b != float64(int(b)) { + return s.PushError("math_lcm: second argument must be an integer") + } + + ia, ib := int(a), int(b) + + // Calculate GCD + gcd := func(x, y int) int { + for y != 0 { + x, y = y, x%y + } + return x + } + + result := ia * ib / gcd(ia, ib) + s.PushNumber(float64(result)) + return 1 + }, + } +} diff --git a/functions/registry.go b/functions/registry.go new file mode 100644 index 0000000..c8d8ac2 --- /dev/null +++ b/functions/registry.go @@ -0,0 +1,34 @@ +package functions + +import luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + +// Registry holds all available Go functions for Lua modules +type Registry map[string]luajit.GoFunction + +// GetAll returns all registered functions +func GetAll() Registry { + registry := make(Registry) + + // Register function groups + for name, fn := range GetJSONFunctions() { + registry[name] = fn + } + + for name, fn := range GetStringFunctions() { + registry[name] = fn + } + + for name, fn := range GetMathFunctions() { + registry[name] = fn + } + + for name, fn := range GetFSFunctions() { + registry[name] = fn + } + + for name, fn := range GetCryptoFunctions() { + registry[name] = fn + } + + return registry +} diff --git a/functions/string.go b/functions/string.go new file mode 100644 index 0000000..44deb05 --- /dev/null +++ b/functions/string.go @@ -0,0 +1,155 @@ +// functions/string.go +package functions + +import ( + "fmt" + "strings" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// GetStringFunctions returns all string manipulation Go functions +func GetStringFunctions() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "string_split": func(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("string_split: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_split: first argument must be a string") + } + + sep, err := s.SafeToString(2) + if err != nil { + return s.PushError("string_split: second argument must be a string") + } + + parts := strings.Split(str, sep) + if err := s.PushValue(parts); err != nil { + return s.PushError("string_split: failed to push result: %v", err) + } + + return 1 + }, + + "string_join": func(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("string_join: %v", err) + } + + arr, err := s.SafeToTable(1) + if err != nil { + return s.PushError("string_join: first argument must be a table") + } + + sep, err := s.SafeToString(2) + if err != nil { + return s.PushError("string_join: second argument must be a string") + } + + var parts []string + if slice, ok := arr.([]string); ok { + parts = slice + } else if anySlice, ok := arr.([]interface{}); ok { + parts = make([]string, len(anySlice)) + for i, v := range anySlice { + parts[i] = fmt.Sprintf("%v", v) + } + } else { + return s.PushError("string_join: first argument must be an array") + } + + result := strings.Join(parts, sep) + s.PushString(result) + return 1 + }, + + "string_trim": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("string_trim: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_trim: argument must be a string") + } + + s.PushString(strings.TrimSpace(str)) + return 1 + }, + + "string_upper": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("string_upper: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_upper: argument must be a string") + } + + s.PushString(strings.ToUpper(str)) + return 1 + }, + + "string_lower": func(s *luajit.State) int { + if err := s.CheckMinArgs(1); err != nil { + return s.PushError("string_lower: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_lower: argument must be a string") + } + + s.PushString(strings.ToLower(str)) + return 1 + }, + + "string_contains": func(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("string_contains: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_contains: first argument must be a string") + } + + substr, err := s.SafeToString(2) + if err != nil { + return s.PushError("string_contains: second argument must be a string") + } + + s.PushBoolean(strings.Contains(str, substr)) + return 1 + }, + + "string_replace": func(s *luajit.State) int { + if err := s.CheckExactArgs(3); err != nil { + return s.PushError("string_replace: %v", err) + } + + str, err := s.SafeToString(1) + if err != nil { + return s.PushError("string_replace: first argument must be a string") + } + + old, err := s.SafeToString(2) + if err != nil { + return s.PushError("string_replace: second argument must be a string") + } + + new, err := s.SafeToString(3) + if err != nil { + return s.PushError("string_replace: third argument must be a string") + } + + result := strings.ReplaceAll(str, old, new) + s.PushString(result) + return 1 + }, + } +} diff --git a/go.mod b/go.mod index 00c2bd1..b9869e1 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,4 @@ module Moonshark go 1.24.1 require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 +require github.com/goccy/go-json v0.10.5 diff --git a/go.sum b/go.sum index 30d332f..0ea3bab 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU= git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= diff --git a/modules.go b/modules.go index 2d39411..3667cb2 100644 --- a/modules.go +++ b/modules.go @@ -6,22 +6,28 @@ import ( "path/filepath" "strings" + "Moonshark/functions" + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) //go:embed modules/*.lua var builtinModules embed.FS -// ModuleRegistry manages built-in modules +// ModuleRegistry manages built-in modules and Go functions type ModuleRegistry struct { modules map[string]string + goFuncs map[string]luajit.GoFunction } // NewModuleRegistry creates a new module registry func NewModuleRegistry() *ModuleRegistry { - return &ModuleRegistry{ + mr := &ModuleRegistry{ modules: make(map[string]string), + goFuncs: functions.GetAll(), } + + return mr } // RegisterModule adds a module by name and source code @@ -29,12 +35,16 @@ func (mr *ModuleRegistry) RegisterModule(name, source string) { mr.modules[name] = source } +// RegisterGoFunction adds a Go function that modules can use +func (mr *ModuleRegistry) RegisterGoFunction(name string, fn luajit.GoFunction) { + mr.goFuncs[name] = fn +} + // LoadEmbeddedModules loads all modules from the embedded filesystem func (mr *ModuleRegistry) LoadEmbeddedModules() error { entries, err := builtinModules.ReadDir("modules") if err != nil { - fmt.Printf("Failed to read modules directory: %v\n", err) - return err + return fmt.Errorf("failed to read modules directory: %w", err) } for _, entry := range entries { @@ -60,6 +70,11 @@ func (mr *ModuleRegistry) InstallModules(state *luajit.State) error { state.NewTable() state.SetGlobal("moonshark") + // Install Go functions first + if err := mr.installGoFunctions(state); err != nil { + return fmt.Errorf("failed to install Go functions: %w", err) + } + // Register require function that checks our built-in modules first err := state.RegisterGoFunction("require", func(s *luajit.State) int { if err := s.CheckMinArgs(1); err != nil { @@ -101,6 +116,22 @@ func (mr *ModuleRegistry) InstallModules(state *luajit.State) error { return err } +// installGoFunctions installs all registered Go functions into the Lua state +func (mr *ModuleRegistry) installGoFunctions(state *luajit.State) error { + // Install functions in moonshark namespace + state.GetGlobal("moonshark") + + for name, fn := range mr.goFuncs { + if err := state.PushGoFunction(fn); err != nil { + return fmt.Errorf("failed to register Go function '%s': %w", name, err) + } + state.SetField(-2, name) + } + + state.Pop(1) // Remove moonshark table + return nil +} + // BackupOriginalRequire saves the original require function func BackupOriginalRequire(state *luajit.State) { state.GetGlobal("require") diff --git a/modules/json.lua b/modules/json.lua new file mode 100644 index 0000000..46437f1 --- /dev/null +++ b/modules/json.lua @@ -0,0 +1,171 @@ +-- modules/json.lua - High-performance JSON module using Go functions + +local json = {} + +-- Use the fast Go JSON encoder/decoder +function json.encode(value) + return moonshark.json_encode(value) +end + +function json.decode(str) + local result, err = moonshark.json_decode(str) + if result == nil and err then + error("json_decode: " .. err) + end + return result +end + +-- Pretty print JSON with indentation +function json.pretty(value, indent) + indent = indent or 2 + local encoded = json.encode(value) + local result = {} + local depth = 0 + local in_string = false + local escape_next = false + + for i = 1, #encoded do + local char = encoded:sub(i, i) + + if escape_next then + table.insert(result, char) + escape_next = false + elseif char == "\\" and in_string then + table.insert(result, char) + escape_next = true + elseif char == '"' then + table.insert(result, char) + in_string = not in_string + elseif not in_string then + if char == "{" or char == "[" then + table.insert(result, char) + depth = depth + 1 + table.insert(result, "\n" .. string.rep(" ", depth * indent)) + elseif char == "}" or char == "]" then + depth = depth - 1 + table.insert(result, "\n" .. string.rep(" ", depth * indent)) + table.insert(result, char) + elseif char == "," then + table.insert(result, char) + table.insert(result, "\n" .. string.rep(" ", depth * indent)) + elseif char == ":" then + table.insert(result, char .. " ") + else + table.insert(result, char) + end + else + table.insert(result, char) + end + end + + return table.concat(result) +end + +-- Load JSON from file +function json.load_file(filename) + if not moonshark.file_exists(filename) then + error("File not found: " .. filename) + end + + local file = io.open(filename, "r") + if not file then + error("Cannot open file: " .. filename) + end + + local content = file:read("*all") + file:close() + + return json.decode(content) +end + +-- Save data to JSON file +function json.save_file(filename, data, pretty) + local content + if pretty then + content = json.pretty(data) + else + content = json.encode(data) + end + + local file = io.open(filename, "w") + if not file then + error("Cannot write to file: " .. filename) + end + + file:write(content) + file:close() +end + +-- Merge JSON objects +function json.merge(...) + local result = {} + for i = 1, select("#", ...) do + local obj = select(i, ...) + if type(obj) == "table" then + for k, v in pairs(obj) do + result[k] = v + end + end + end + return result +end + +-- Extract values by JSONPath-like syntax (simplified) +function json.extract(data, path) + local parts = moonshark.string_split(path, ".") + local current = data + + for _, part in ipairs(parts) do + if type(current) ~= "table" then + return nil + end + + -- Handle array indices [0], [1], etc. + local array_match = part:match("^%[(%d+)%]$") + if array_match then + local index = tonumber(array_match) + 1 -- Lua is 1-indexed + current = current[index] + else + current = current[part] + end + + if current == nil then + return nil + end + end + + return current +end + +-- Validate JSON structure against schema (basic) +function json.validate(data, schema) + local function validate_value(value, schema_value) + local value_type = type(value) + local schema_type = schema_value.type + + if schema_type and value_type ~= schema_type then + return false, "Expected " .. schema_type .. ", got " .. value_type + end + + if schema_type == "table" and schema_value.properties then + for prop, prop_schema in pairs(schema_value.properties) do + if schema_value.required and schema_value.required[prop] and value[prop] == nil then + return false, "Missing required property: " .. prop + end + + if value[prop] ~= nil then + local valid, err = validate_value(value[prop], prop_schema) + if not valid then + return false, "Property " .. prop .. ": " .. err + end + end + end + end + + return true + end + + return validate_value(data, schema) +end + +return json \ No newline at end of file diff --git a/tests/json.lua b/tests/json.lua new file mode 100644 index 0000000..6afedef --- /dev/null +++ b/tests/json.lua @@ -0,0 +1,220 @@ +#!/usr/bin/env moonshark + +-- Test script for JSON module functionality +local json = require("json") + +local passed = 0 +local total = 0 + +local function test(name, fn) + print("Testing " .. name .. "...") + total = total + 1 + local ok, err = pcall(fn) + if ok then + passed = passed + 1 + print(" ✓ PASS") + return true + else + print(" ✗ FAIL: " .. err) + return false + end +end + +-- Test data +local test_data = { + name = "John Doe", + age = 30, + active = true, + scores = {85, 92, 78, 90}, + address = { + street = "123 Main St", + city = "Springfield", + zip = "12345" + }, + tags = {"developer", "golang", "lua"} +} + +-- Test 1: Basic encoding +test("Basic JSON Encoding", function() + local encoded = json.encode(test_data) + assert(type(encoded) == "string", "encode should return string") + assert(string.find(encoded, "John Doe"), "should contain name") + assert(string.find(encoded, "30"), "should contain age") +end) + +-- Test 2: Basic decoding +test("Basic JSON Decoding", function() + local encoded = json.encode(test_data) + local decoded = json.decode(encoded) + assert(decoded.name == "John Doe", "name should match") + assert(decoded.age == 30, "age should match") + assert(decoded.active == true, "active should be true") + assert(#decoded.scores == 4, "scores array length should be 4") +end) + +-- Test 3: Round-trip encoding/decoding +test("Round-trip Encoding/Decoding", function() + local encoded = json.encode(test_data) + local decoded = json.decode(encoded) + local re_encoded = json.encode(decoded) + local re_decoded = json.decode(re_encoded) + + assert(re_decoded.name == test_data.name, "name should survive round-trip") + assert(re_decoded.address.city == test_data.address.city, "nested data should survive") +end) + +-- Test 4: Pretty printing +test("Pretty Printing", function() + local pretty = json.pretty(test_data) + assert(type(pretty) == "string", "pretty should return string") + assert(string.find(pretty, "\n"), "pretty should contain newlines") + assert(string.find(pretty, " "), "pretty should contain indentation") + + -- Should still be valid JSON + local decoded = json.decode(pretty) + assert(decoded.name == test_data.name, "pretty JSON should still decode correctly") +end) + +-- Test 5: Object merging +test("Object Merging", function() + local obj1 = {a = 1, b = 2} + local obj2 = {b = 3, c = 4} + local obj3 = {d = 5} + + local merged = json.merge(obj1, obj2, obj3) + assert(merged.a == 1, "should preserve a from obj1") + assert(merged.b == 3, "should use b from obj2 (later wins)") + assert(merged.c == 4, "should include c from obj2") + assert(merged.d == 5, "should include d from obj3") +end) + +-- Test 6: Data extraction +test("Data Extraction", function() + local name = json.extract(test_data, "name") + assert(name == "John Doe", "should extract name") + + local city = json.extract(test_data, "address.city") + assert(city == "Springfield", "should extract nested city") + + local first_score = json.extract(test_data, "scores.[0]") + assert(first_score == 85, "should extract array element") + + local missing = json.extract(test_data, "nonexistent.field") + assert(missing == nil, "should return nil for missing path") +end) + +-- Test 7: Schema validation +test("Schema Validation", function() + local schema = { + type = "table", + properties = { + name = {type = "string"}, + age = {type = "number"}, + active = {type = "boolean"} + }, + required = {name = true, age = true} + } + + local valid, err = json.validate(test_data, schema) + assert(valid == true, "test_data should be valid") + + local invalid_data = {name = "John", age = "not_a_number"} + local invalid, err2 = json.validate(invalid_data, schema) + assert(invalid == false, "invalid_data should fail validation") + assert(type(err2) == "string", "should return error message") +end) + +-- Test 8: File operations +test("File Save/Load", function() + local filename = "test_output.json" + + -- Save to file + json.save_file(filename, test_data, true) -- pretty format + + -- Check file exists + assert(moonshark.file_exists(filename), "file should exist after save") + + -- Load from file + local loaded = json.load_file(filename) + assert(loaded.name == test_data.name, "loaded data should match original") + assert(loaded.address.zip == test_data.address.zip, "nested data should match") + + -- Clean up + os.remove(filename) +end) + +-- Test 9: Error handling +test("Error Handling", function() + -- Invalid JSON should throw error + local success, err = pcall(json.decode, '{"invalid": json}') + assert(success == false, "invalid JSON should cause error") + assert(type(err) == "string", "should return error message") + + -- Missing file should throw error + local success2, err2 = pcall(json.load_file, "nonexistent_file.json") + assert(success2 == false, "missing file should cause error") + assert(type(err2) == "string", "should return error message") +end) + +-- Test 10: Edge cases +test("Edge Cases", function() + -- Empty objects + local empty_obj = {} + local encoded_empty = json.encode(empty_obj) + local decoded_empty = json.decode(encoded_empty) + assert(type(decoded_empty) == "table", "empty object should decode to table") + + -- Null values + local with_nil = {a = 1, b = nil, c = 3} + local encoded_nil = json.encode(with_nil) + local decoded_nil = json.decode(encoded_nil) + -- Note: nil values are typically omitted in JSON + + -- Special numbers + local special = { + zero = 0, + negative = -42, + decimal = 3.14159 + } + local encoded_special = json.encode(special) + local decoded_special = json.decode(encoded_special) + assert(decoded_special.zero == 0, "zero should encode/decode correctly") + assert(decoded_special.negative == -42, "negative should encode/decode correctly") + assert(math.abs(decoded_special.decimal - 3.14159) < 0.00001, "decimal should encode/decode correctly") +end) + +-- Performance test +test("Performance Test", function() + local large_data = {} + for i = 1, 1000 do + large_data[i] = { + id = i, + name = "User " .. i, + data = {x = i * 2, y = i * 3, z = i * 4} + } + end + + local start = os.clock() + local encoded = json.encode(large_data) + local encode_time = os.clock() - start + + start = os.clock() + local decoded = json.decode(encoded) + local decode_time = os.clock() - start + + print(string.format(" Encoded 1000 objects in %.3f seconds", encode_time)) + print(string.format(" Decoded 1000 objects in %.3f seconds", decode_time)) + + assert(#decoded == 1000, "should have 1000 objects after decode") + assert(decoded[500].name == "User 500", "data should be intact") +end) + +print("=" .. string.rep("=", 50)) +print(string.format("Test Results: %d/%d passed", passed, total)) +if passed == total then + print("🎉 All tests passed!") + os.exit(0) +else + print("❌ Some tests failed!") + os.exit(1) +end \ No newline at end of file diff --git a/tests/math.lua b/tests/math.lua index c0a0acb..138f50a 100644 --- a/tests/math.lua +++ b/tests/math.lua @@ -13,10 +13,15 @@ local function assert_equal(a, b) end end +local passed = 0 +local total = 0 + local function test(name, fn) print("Testing " .. name .. "...") + total = total + 1 local ok, err = pcall(fn) if ok then + passed = passed + 1 print(" ✓ PASS") else print(" ✗ FAIL: " .. err) @@ -70,8 +75,8 @@ test("Statistics", function() assert_equal(math.sum(data), 15) assert_equal(math.mean(data), 3) assert_equal(math.median(data), 3) - assert_close(math.variance(data), 2.5) - assert_close(math.stdev(data), math.sqrt(2.5)) + assert_close(math.variance(data), 2) + assert_close(math.stdev(data), math.sqrt(2)) local min, max = math.minmax(data) assert_equal(min, 1) assert_equal(max, 5) @@ -189,4 +194,12 @@ test("Interpolation", function() assert_close(math.interpolation.catmull_rom(0.5, 0, 1, 2, 3), 1.5) end) -print("\nAll tests completed!") \ No newline at end of file +print("=" .. string.rep("=", 50)) +print(string.format("Test Results: %d/%d passed", passed, total)) +if passed == total then + print("🎉 All tests passed!") + os.exit(0) +else + print("❌ Some tests failed!") + os.exit(1) +end \ No newline at end of file