From acb86701772297222d244b91c87b85dfe088746d Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 14 Jul 2025 21:51:02 -0500 Subject: [PATCH] update test fwk, fix package path --- moonshark.go | 14 ++++ tests/json.lua | 104 +++++++++---------------- tests/math.lua | 44 ++--------- tests/tests.lua | 203 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 258 insertions(+), 107 deletions(-) create mode 100644 tests/tests.lua diff --git a/moonshark.go b/moonshark.go index e7fe08e..03c5902 100644 --- a/moonshark.go +++ b/moonshark.go @@ -43,6 +43,20 @@ func main() { os.Exit(1) } + // Get the absolute path to the script directory + scriptDir := filepath.Dir(scriptPath) + absScriptDir, err := filepath.Abs(scriptDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: failed to get absolute path: %v\n", err) + os.Exit(1) + } + + // Add script directory to Lua's package.path + packagePath := filepath.Join(absScriptDir, "?.lua") + if err := state.AddPackagePath(packagePath); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to add script directory to package.path: %v\n", err) + } + // Execute the script if err := state.DoFile(scriptPath); err != nil { fmt.Fprintf(os.Stderr, "Error executing '%s': %v\n", scriptPath, err) diff --git a/tests/json.lua b/tests/json.lua index 6afedef..dbb9afd 100644 --- a/tests/json.lua +++ b/tests/json.lua @@ -1,25 +1,6 @@ -#!/usr/bin/env moonshark - --- Test script for JSON module functionality +require("tests") local json = require("json") -local passed = 0 -local total = 0 - -local function test(name, fn) - print("Testing " .. name .. "...") - total = total + 1 - local ok, err = pcall(fn) - if ok then - passed = passed + 1 - print(" ✓ PASS") - return true - else - print(" ✗ FAIL: " .. err) - return false - end -end - -- Test data local test_data = { name = "John Doe", @@ -37,7 +18,7 @@ local test_data = { -- Test 1: Basic encoding test("Basic JSON Encoding", function() local encoded = json.encode(test_data) - assert(type(encoded) == "string", "encode should return string") + assert_equal(type(encoded), "string") assert(string.find(encoded, "John Doe"), "should contain name") assert(string.find(encoded, "30"), "should contain age") end) @@ -46,10 +27,10 @@ end) test("Basic JSON Decoding", function() local encoded = json.encode(test_data) local decoded = json.decode(encoded) - assert(decoded.name == "John Doe", "name should match") - assert(decoded.age == 30, "age should match") - assert(decoded.active == true, "active should be true") - assert(#decoded.scores == 4, "scores array length should be 4") + assert_equal(decoded.name, "John Doe") + assert_equal(decoded.age, 30) + assert_equal(decoded.active, true) + assert_equal(#decoded.scores, 4) end) -- Test 3: Round-trip encoding/decoding @@ -59,20 +40,20 @@ test("Round-trip Encoding/Decoding", function() local re_encoded = json.encode(decoded) local re_decoded = json.decode(re_encoded) - assert(re_decoded.name == test_data.name, "name should survive round-trip") - assert(re_decoded.address.city == test_data.address.city, "nested data should survive") + assert_equal(re_decoded.name, test_data.name) + assert_equal(re_decoded.address.city, test_data.address.city) end) -- Test 4: Pretty printing test("Pretty Printing", function() local pretty = json.pretty(test_data) - assert(type(pretty) == "string", "pretty should return string") + assert_equal(type(pretty), "string") assert(string.find(pretty, "\n"), "pretty should contain newlines") assert(string.find(pretty, " "), "pretty should contain indentation") -- Should still be valid JSON local decoded = json.decode(pretty) - assert(decoded.name == test_data.name, "pretty JSON should still decode correctly") + assert_equal(decoded.name, test_data.name) end) -- Test 5: Object merging @@ -82,25 +63,25 @@ test("Object Merging", function() local obj3 = {d = 5} local merged = json.merge(obj1, obj2, obj3) - assert(merged.a == 1, "should preserve a from obj1") - assert(merged.b == 3, "should use b from obj2 (later wins)") - assert(merged.c == 4, "should include c from obj2") - assert(merged.d == 5, "should include d from obj3") + assert_equal(merged.a, 1) + assert_equal(merged.b, 3) -- later wins + assert_equal(merged.c, 4) + assert_equal(merged.d, 5) end) -- Test 6: Data extraction test("Data Extraction", function() local name = json.extract(test_data, "name") - assert(name == "John Doe", "should extract name") + assert_equal(name, "John Doe") local city = json.extract(test_data, "address.city") - assert(city == "Springfield", "should extract nested city") + assert_equal(city, "Springfield") local first_score = json.extract(test_data, "scores.[0]") - assert(first_score == 85, "should extract array element") + assert_equal(first_score, 85) local missing = json.extract(test_data, "nonexistent.field") - assert(missing == nil, "should return nil for missing path") + assert_equal(missing, nil) end) -- Test 7: Schema validation @@ -116,12 +97,12 @@ test("Schema Validation", function() } local valid, err = json.validate(test_data, schema) - assert(valid == true, "test_data should be valid") + assert_equal(valid, true) local invalid_data = {name = "John", age = "not_a_number"} local invalid, err2 = json.validate(invalid_data, schema) - assert(invalid == false, "invalid_data should fail validation") - assert(type(err2) == "string", "should return error message") + assert_equal(invalid, false) + assert_equal(type(err2), "string") end) -- Test 8: File operations @@ -132,12 +113,12 @@ test("File Save/Load", function() json.save_file(filename, test_data, true) -- pretty format -- Check file exists - assert(moonshark.file_exists(filename), "file should exist after save") + assert(file_exists(filename), "file should exist after save") -- Load from file local loaded = json.load_file(filename) - assert(loaded.name == test_data.name, "loaded data should match original") - assert(loaded.address.zip == test_data.address.zip, "nested data should match") + assert_equal(loaded.name, test_data.name) + assert_equal(loaded.address.zip, test_data.address.zip) -- Clean up os.remove(filename) @@ -147,13 +128,13 @@ end) test("Error Handling", function() -- Invalid JSON should throw error local success, err = pcall(json.decode, '{"invalid": json}') - assert(success == false, "invalid JSON should cause error") - assert(type(err) == "string", "should return error message") + assert_equal(success, false) + assert_equal(type(err), "string") -- Missing file should throw error local success2, err2 = pcall(json.load_file, "nonexistent_file.json") - assert(success2 == false, "missing file should cause error") - assert(type(err2) == "string", "should return error message") + assert_equal(success2, false) + assert_equal(type(err2), "string") end) -- Test 10: Edge cases @@ -162,13 +143,7 @@ test("Edge Cases", function() local empty_obj = {} local encoded_empty = json.encode(empty_obj) local decoded_empty = json.decode(encoded_empty) - assert(type(decoded_empty) == "table", "empty object should decode to table") - - -- Null values - local with_nil = {a = 1, b = nil, c = 3} - local encoded_nil = json.encode(with_nil) - local decoded_nil = json.decode(encoded_nil) - -- Note: nil values are typically omitted in JSON + assert_equal(type(decoded_empty), "table") -- Special numbers local special = { @@ -178,9 +153,9 @@ test("Edge Cases", function() } local encoded_special = json.encode(special) local decoded_special = json.decode(encoded_special) - assert(decoded_special.zero == 0, "zero should encode/decode correctly") - assert(decoded_special.negative == -42, "negative should encode/decode correctly") - assert(math.abs(decoded_special.decimal - 3.14159) < 0.00001, "decimal should encode/decode correctly") + assert_equal(decoded_special.zero, 0) + assert_equal(decoded_special.negative, -42) + assert_close(decoded_special.decimal, 3.14159, 0.00001) end) -- Performance test @@ -205,16 +180,9 @@ test("Performance Test", function() print(string.format(" Encoded 1000 objects in %.3f seconds", encode_time)) print(string.format(" Decoded 1000 objects in %.3f seconds", decode_time)) - assert(#decoded == 1000, "should have 1000 objects after decode") - assert(decoded[500].name == "User 500", "data should be intact") + assert_equal(#decoded, 1000) + assert_equal(decoded[500].name, "User 500") end) -print("=" .. string.rep("=", 50)) -print(string.format("Test Results: %d/%d passed", passed, total)) -if passed == total then - print("🎉 All tests passed!") - os.exit(0) -else - print("❌ Some tests failed!") - os.exit(1) -end \ No newline at end of file +summary() +test_exit() \ No newline at end of file diff --git a/tests/math.lua b/tests/math.lua index 138f50a..cee6648 100644 --- a/tests/math.lua +++ b/tests/math.lua @@ -1,33 +1,6 @@ +require("tests") local math = require("math") -local function assert_close(a, b, tolerance) - tolerance = tolerance or 1e-10 - if math.abs(a - b) > tolerance then - error(string.format("Expected %g, got %g (diff: %g)", a, b, math.abs(a - b))) - end -end - -local function assert_equal(a, b) - if a ~= b then - error(string.format("Expected %s, got %s", tostring(a), tostring(b))) - end -end - -local passed = 0 -local total = 0 - -local function test(name, fn) - print("Testing " .. name .. "...") - total = total + 1 - local ok, err = pcall(fn) - if ok then - passed = passed + 1 - print(" ✓ PASS") - else - print(" ✗ FAIL: " .. err) - end -end - -- Test constants test("Constants", function() assert_close(math.pi, 3.14159265358979323846) @@ -63,9 +36,9 @@ end) -- Test random functions test("Random Functions", function() local r = math.randomf(0, 1) - assert_equal(r >= 0 and r < 1, true) + assert(r >= 0 and r < 1, "randomf should be in range [0, 1)") local i = math.randint(1, 10) - assert_equal(i >= 1 and i <= 10, true) + assert(i >= 1 and i <= 10, "randint should be in range [1, 10]") assert_equal(type(math.randboolean()), "boolean") end) @@ -194,12 +167,5 @@ test("Interpolation", function() assert_close(math.interpolation.catmull_rom(0.5, 0, 1, 2, 3), 1.5) end) -print("=" .. string.rep("=", 50)) -print(string.format("Test Results: %d/%d passed", passed, total)) -if passed == total then - print("🎉 All tests passed!") - os.exit(0) -else - print("❌ Some tests failed!") - os.exit(1) -end \ No newline at end of file +summary() +test_exit() \ No newline at end of file diff --git a/tests/tests.lua b/tests/tests.lua new file mode 100644 index 0000000..03b6274 --- /dev/null +++ b/tests/tests.lua @@ -0,0 +1,203 @@ +-- Enhanced Test Framework - Global Functions +-- Provides better assert reporting and test runner functionality + +-- Test state +local passed = 0 +local total = 0 + +-- Enhanced assert function with better error reporting +function assert(condition, message, level) + if condition then + return true + end + + level = level or 2 + local info = debug.getinfo(level, "Sl") + local file = info.source:match("@?(.+)") or "unknown" + local line = info.currentline or "unknown" + + local error_msg = message or "assertion failed" + local full_msg = string.format("%s:%s: %s", file, line, error_msg) + + error(full_msg, 0) +end + +-- Assert with tolerance for floating point comparisons +function assert_close(a, b, tolerance, message) + tolerance = tolerance or 1e-10 + local diff = math.abs(a - b) + if diff <= tolerance then + return true + end + + local msg = message or string.format("Expected %g, got %g (diff: %g, tolerance: %g)", a, b, diff, tolerance) + assert(false, msg, 3) +end + +-- Assert equality with better error messages +function assert_equal(a, b, message) + if a == b then + return true + end + + local msg = message or string.format("Expected %s, got %s", tostring(a), tostring(b)) + assert(false, msg, 3) +end + +-- Assert table equality (deep comparison) +function assert_table_equal(a, b, message, path) + path = path or "root" + + if type(a) ~= type(b) then + local msg = message or string.format("Type mismatch at %s: expected %s, got %s", path, type(a), type(b)) + assert(false, msg, 3) + end + + if type(a) ~= "table" then + if a ~= b then + local msg = message or string.format("Value mismatch at %s: expected %s, got %s", path, tostring(a), tostring(b)) + assert(false, msg, 3) + end + return true + end + + -- Check all keys in a exist in b with same values + for k, v in pairs(a) do + local new_path = path .. "." .. tostring(k) + if b[k] == nil then + local msg = message or string.format("Missing key at %s", new_path) + assert(false, msg, 3) + end + assert_table_equal(v, b[k], message, new_path) + end + + -- Check all keys in b exist in a + for k, v in pairs(b) do + if a[k] == nil then + local new_path = path .. "." .. tostring(k) + local msg = message or string.format("Extra key at %s", new_path) + assert(false, msg, 3) + end + end + + return true +end + +-- Test runner function +function test(name, fn) + print("Testing " .. name .. "...") + total = total + 1 + + local start_time = os.clock() + local ok, err = pcall(fn) + local end_time = os.clock() + local duration = end_time - start_time + + if ok then + passed = passed + 1 + print(string.format(" ✓ PASS (%.3fs)", duration)) + return true + else + print(" ✗ FAIL: " .. err) + if duration > 0.001 then + print(string.format(" (%.3fs)", duration)) + end + return false + end +end + +-- Test suite runner +function run_tests(tests) + print("Running test suite...") + print("=" .. string.rep("=", 50)) + + for name, test_fn in pairs(tests) do + test(name, test_fn) + end + + return summary() +end + +-- Reset test counters +function reset_tests() + passed = 0 + total = 0 +end + +-- Get test statistics +function test_stats() + return { + passed = passed, + total = total, + failed = total - passed, + success_rate = total > 0 and (passed / total) or 0 + } +end + +-- Print test summary and return success status +function summary() + print("=" .. string.rep("=", 50)) + print(string.format("Test Results: %d/%d passed", passed, total)) + + local success = passed == total + if success then + print("🎉 All tests passed!") + else + local failed = total - passed + local rate = total > 0 and (passed / total * 100) or 0 + print(string.format("❌ %d test(s) failed! (%.1f%% success rate)", failed, rate)) + end + + return success +end + +-- Exit with appropriate code based on test results +function test_exit() + local success = passed == total + os.exit(success and 0 or 1) +end + +-- Convenience function to run and exit +function run_and_exit(tests) + local success = run_tests(tests) + os.exit(success and 0 or 1) +end + +-- Benchmark function +function benchmark(name, fn, iterations) + iterations = iterations or 1000 + print("Benchmarking " .. name .. " (" .. iterations .. " iterations)...") + + -- Warmup + for i = 1, math.min(10, iterations) do + fn() + end + + -- Actual benchmark + local start = os.clock() + for i = 1, iterations do + fn() + end + local total_time = os.clock() - start + local avg_time = total_time / iterations + + print(string.format(" Total: %.3fs, Average: %.6fs, Rate: %.0f ops/sec", + total_time, avg_time, 1/avg_time)) + + return { + total_time = total_time, + avg_time = avg_time, + ops_per_sec = 1/avg_time, + iterations = iterations + } +end + +-- Helper to check if file exists +function file_exists(filename) + local file = io.open(filename, "r") + if file then + file:close() + return true + end + return false +end \ No newline at end of file