From cf203d78998e994d4675a76fb1f62cf570900121 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 24 Jul 2025 09:39:24 -0500 Subject: [PATCH] initial sql database support - sqlite, postgres, mysql --- .gitignore | 1 + go.mod | 23 + go.sum | 46 ++ modules/http/http.go | 1 - modules/mysql/mysql.lua | 1003 +++++++++++++++++++++++++++++++++ modules/postgres/postgres.lua | 847 ++++++++++++++++++++++++++++ modules/registry.go | 4 +- modules/sql/mysql.go | 205 +++++++ modules/sql/postgres.go | 234 ++++++++ modules/sql/sql.go | 367 ++++++++++++ modules/sql/sqlite.go | 384 +++++++++++++ modules/sqlite/sqlite.lua | 655 +++++++++++++++++++++ 12 files changed, 3768 insertions(+), 2 deletions(-) create mode 100644 modules/mysql/mysql.lua create mode 100644 modules/postgres/postgres.lua create mode 100644 modules/sql/mysql.go create mode 100644 modules/sql/postgres.go create mode 100644 modules/sql/sql.go create mode 100644 modules/sql/sqlite.go create mode 100644 modules/sqlite/sqlite.lua diff --git a/.gitignore b/.gitignore index 21b72dd..0802053 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ go.work test_fs_dir public test +test.db diff --git a/go.mod b/go.mod index 89e0c7c..8295656 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,29 @@ require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 require github.com/google/uuid v1.6.0 +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.5 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect + modernc.org/libc v1.65.7 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.37.1 // indirect + zombiezen.com/go/sqlite v1.4.2 // indirect +) + require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index c5b7adf..4973630 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,64 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU= git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= +github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og= github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00= +modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs= +modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g= +zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo= +zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc= diff --git a/modules/http/http.go b/modules/http/http.go index 10151cd..65f2b2b 100644 --- a/modules/http/http.go +++ b/modules/http/http.go @@ -263,7 +263,6 @@ func handleRequest(ctx *fasthttp.RequestCtx) { } } - // Lua 404, try static handlers if not already checked if !isLikelyStaticFile(path) && tryStaticHandler(ctx, path) { return } diff --git a/modules/mysql/mysql.lua b/modules/mysql/mysql.lua new file mode 100644 index 0000000..781e2a9 --- /dev/null +++ b/modules/mysql/mysql.lua @@ -0,0 +1,1003 @@ +local str = require("string") +local mysql = {} + +local Connection = {} +Connection.__index = Connection + +function Connection:close() + if self._id then + local ok = moonshark.sql_close(self._id) + self._id = nil + return ok + end + return false +end + +function Connection:ping() + if not self._id then + error("Connection is closed") + end + return moonshark.sql_ping(self._id) +end + +function Connection:query(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_query(self._id, query_str, ...) +end + +function Connection:exec(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_exec(self._id, query_str, ...) +end + +function Connection:query_row(query_str, ...) + local results = self:query(query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function Connection:query_value(query_str, ...) + local row = self:query_row(query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced transaction support with savepoints +function Connection:begin() + local result = self:exec("BEGIN") + if result then + return { + conn = self, + active = true, + + commit = function(tx) + if tx.active then + local result = tx.conn:exec("COMMIT") + tx.active = false + return result + end + return false + end, + + rollback = function(tx) + if tx.active then + local result = tx.conn:exec("ROLLBACK") + tx.active = false + return result + end + return false + end, + + savepoint = function(tx, name) + if not tx.active then + error("Transaction is not active") + end + if str.is_blank(name) then + error("Savepoint name cannot be empty") + end + return tx.conn:exec(str.template("SAVEPOINT ${name}", {name = name})) + end, + + rollback_to = function(tx, name) + if not tx.active then + error("Transaction is not active") + end + if str.is_blank(name) then + error("Savepoint name cannot be empty") + end + return tx.conn:exec(str.template("ROLLBACK TO SAVEPOINT ${name}", {name = name})) + end, + + query = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query(query_str, ...) + end, + + exec = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:exec(query_str, ...) + end, + + query_row = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_row(query_str, ...) + end, + + query_value = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_value(query_str, ...) + end + } + end + return nil +end + +-- Enhanced MySQL-specific query builder helpers +function Connection:insert(table_name, data) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys = {} + local values = {} + local placeholders = {} + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + table.insert(placeholders, "?") + end + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", ") + }) + + return self:exec(query, unpack(values)) +end + +function Connection:upsert(table_name, data, update_data) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys = {} + local values = {} + local placeholders = {} + local updates = {} + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + table.insert(placeholders, "?") + end + + -- Use update_data if provided, otherwise update with same data + local update_source = update_data or data + for key, _ in pairs(update_source) do + table.insert(updates, str.template("${key} = VALUES(${key})", {key = key})) + end + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", "), + updates = str.join(updates, ", ") + }) + + return self:exec(query, unpack(values)) +end + +function Connection:replace(table_name, data) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys = {} + local values = {} + local placeholders = {} + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + table.insert(placeholders, "?") + end + + local query = str.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", ") + }) + + return self:exec(query, unpack(values)) +end + +function Connection:update(table_name, data, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for UPDATE") + end + + local sets = {} + local values = {} + + for key, value in pairs(data) do + table.insert(sets, str.template("${key} = ?", {key = key})) + table.insert(values, value) + end + + local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { + table = table_name, + sets = str.join(sets, ", "), + where = where_clause + }) + + -- Add WHERE clause parameters + local where_args = {...} + for i = 1, #where_args do + table.insert(values, where_args[i]) + end + + return self:exec(query, unpack(values)) +end + +function Connection:delete(table_name, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for DELETE") + end + + local query = str.template("DELETE FROM ${table} WHERE ${where}", { + table = table_name, + where = where_clause + }) + return self:exec(query, ...) +end + +function Connection:select(table_name, columns, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + columns = columns or "*" + if type(columns) == "table" then + columns = str.join(columns, ", ") + end + + local query + if where_clause and not str.is_blank(where_clause) then + query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", { + columns = columns, + table = table_name, + where = where_clause + }) + return self:query(query, ...) + else + query = str.template("SELECT ${columns} FROM ${table}", { + columns = columns, + table = table_name + }) + return self:query(query) + end +end + +-- Enhanced MySQL schema helpers +function Connection:database_exists(database_name) + if str.is_blank(database_name) then + return false + end + + local result = self:query_value( + "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", + str.trim(database_name) + ) + return result ~= nil +end + +function Connection:table_exists(table_name, database_name) + if str.is_blank(table_name) then + return false + end + + database_name = database_name or self:current_database() + if not database_name then + return false + end + + local result = self:query_value( + "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", + str.trim(database_name), str.trim(table_name) + ) + return result ~= nil +end + +function Connection:column_exists(table_name, column_name, database_name) + if str.is_blank(table_name) or str.is_blank(column_name) then + return false + end + + database_name = database_name or self:current_database() + if not database_name then + return false + end + + local result = self:query_value([[ + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ? + ]], str.trim(database_name), str.trim(table_name), str.trim(column_name)) + return result ~= nil +end + +function Connection:create_database(database_name, charset, collation) + if str.is_blank(database_name) then + error("Database name cannot be empty") + end + + local charset_clause = charset and str.template(" CHARACTER SET ${charset}", {charset = charset}) or "" + local collation_clause = collation and str.template(" COLLATE ${collation}", {collation = collation}) or "" + + local query = str.template("CREATE DATABASE IF NOT EXISTS ${database}${charset}${collation}", { + database = database_name, + charset = charset_clause, + collation = collation_clause + }) + return self:exec(query) +end + +function Connection:drop_database(database_name) + if str.is_blank(database_name) then + error("Database name cannot be empty") + end + + local query = str.template("DROP DATABASE IF EXISTS ${database}", {database = database_name}) + return self:exec(query) +end + +function Connection:create_table(table_name, schema, engine, charset) + if str.is_blank(table_name) or str.is_blank(schema) then + error("Table name and schema cannot be empty") + end + + local engine_clause = engine and str.template(" ENGINE=${engine}", {engine = str.upper(engine)}) or "" + local charset_clause = charset and str.template(" CHARACTER SET ${charset}", {charset = charset}) or "" + + local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})${engine}${charset}", { + table = table_name, + schema = str.trim(schema), + engine = engine_clause, + charset = charset_clause + }) + return self:exec(query) +end + +function Connection:drop_table(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local query = str.template("DROP TABLE IF EXISTS ${table}", {table = table_name}) + return self:exec(query) +end + +function Connection:add_column(table_name, column_def, position) + if str.is_blank(table_name) or str.is_blank(column_def) then + error("Table name and column definition cannot be empty") + end + + local position_clause = position and str.template(" ${position}", {position = position}) or "" + local query = str.template("ALTER TABLE ${table} ADD COLUMN ${column}${position}", { + table = table_name, + column = str.trim(column_def), + position = position_clause + }) + return self:exec(query) +end + +function Connection:drop_column(table_name, column_name) + if str.is_blank(table_name) or str.is_blank(column_name) then + error("Table name and column name cannot be empty") + end + + local query = str.template("ALTER TABLE ${table} DROP COLUMN ${column}", { + table = table_name, + column = column_name + }) + return self:exec(query) +end + +function Connection:modify_column(table_name, column_def) + if str.is_blank(table_name) or str.is_blank(column_def) then + error("Table name and column definition cannot be empty") + end + + local query = str.template("ALTER TABLE ${table} MODIFY COLUMN ${column}", { + table = table_name, + column = str.trim(column_def) + }) + return self:exec(query) +end + +function Connection:rename_table(old_name, new_name) + if str.is_blank(old_name) or str.is_blank(new_name) then + error("Old and new table names cannot be empty") + end + + local query = str.template("RENAME TABLE ${old} TO ${new}", { + old = old_name, + new = new_name + }) + return self:exec(query) +end + +function Connection:create_index(index_name, table_name, columns, unique, type) + if str.is_blank(index_name) or str.is_blank(table_name) then + error("Index name and table name cannot be empty") + end + + local unique_clause = unique and "UNIQUE " or "" + local type_clause = type and str.template(" USING ${type}", {type = str.upper(type)}) or "" + local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) + + local query = str.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", { + unique = unique_clause, + index = index_name, + table = table_name, + columns = columns_str, + type = type_clause + }) + return self:exec(query) +end + +function Connection:drop_index(index_name, table_name) + if str.is_blank(index_name) or str.is_blank(table_name) then + error("Index name and table name cannot be empty") + end + + local query = str.template("DROP INDEX ${index} ON ${table}", { + index = index_name, + table = table_name + }) + return self:exec(query) +end + +-- Enhanced MySQL maintenance functions +function Connection:optimize(table_name) + local table_clause = table_name and str.template(" ${table}", {table = table_name}) or "" + return self:query(str.template("OPTIMIZE TABLE${table}", {table = table_clause})) +end + +function Connection:repair(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty for REPAIR") + end + return self:query(str.template("REPAIR TABLE ${table}", {table = table_name})) +end + +function Connection:check_table(table_name, options) + if str.is_blank(table_name) then + error("Table name cannot be empty for CHECK") + end + + local options_clause = "" + if options then + local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"} + local options_upper = str.upper(options) + + local valid = false + for _, valid_option in ipairs(valid_options) do + if options_upper == valid_option then + valid = true + break + end + end + + if valid then + options_clause = str.template(" ${options}", {options = options_upper}) + end + end + + return self:query(str.template("CHECK TABLE ${table}${options}", { + table = table_name, + options = options_clause + })) +end + +function Connection:analyze_table(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty for ANALYZE") + end + return self:query(str.template("ANALYZE TABLE ${table}", {table = table_name})) +end + +-- Enhanced MySQL settings and introspection +function Connection:show(what) + if str.is_blank(what) then + error("SHOW parameter cannot be empty") + end + return self:query(str.template("SHOW ${what}", {what = str.upper(what)})) +end + +function Connection:current_database() + return self:query_value("SELECT DATABASE() AS db") +end + +function Connection:version() + return self:query_value("SELECT VERSION() AS version") +end + +function Connection:connection_id() + return self:query_value("SELECT CONNECTION_ID()") +end + +function Connection:list_databases() + return self:query("SHOW DATABASES") +end + +function Connection:list_tables(database_name) + if database_name and not str.is_blank(database_name) then + return self:query(str.template("SHOW TABLES FROM ${database}", {database = database_name})) + else + return self:query("SHOW TABLES") + end +end + +function Connection:describe_table(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + return self:query(str.template("DESCRIBE ${table}", {table = table_name})) +end + +function Connection:show_create_table(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + return self:query(str.template("SHOW CREATE TABLE ${table}", {table = table_name})) +end + +function Connection:show_indexes(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + return self:query(str.template("SHOW INDEXES FROM ${table}", {table = table_name})) +end + +function Connection:show_table_status(table_name) + if table_name and not str.is_blank(table_name) then + return self:query("SHOW TABLE STATUS LIKE ?", table_name) + else + return self:query("SHOW TABLE STATUS") + end +end + +-- Enhanced MySQL user and privilege management +function Connection:create_user(username, password, host) + if str.is_blank(username) or str.is_blank(password) then + error("Username and password cannot be empty") + end + + host = host or "%" + local query = str.template("CREATE USER '${username}'@'${host}' IDENTIFIED BY ?", { + username = username, + host = host + }) + return self:exec(query, password) +end + +function Connection:drop_user(username, host) + if str.is_blank(username) then + error("Username cannot be empty") + end + + host = host or "%" + local query = str.template("DROP USER IF EXISTS '${username}'@'${host}'", { + username = username, + host = host + }) + return self:exec(query) +end + +function Connection:grant(privileges, database, table_name, username, host) + if str.is_blank(privileges) or str.is_blank(database) or str.is_blank(username) then + error("Privileges, database, and username cannot be empty") + end + + host = host or "%" + table_name = table_name or "*" + local object = str.template("${database}.${table}", {database = database, table = table_name}) + + local query = str.template("GRANT ${privileges} ON ${object} TO '${username}'@'${host}'", { + privileges = str.upper(privileges), + object = object, + username = username, + host = host + }) + return self:exec(query) +end + +function Connection:revoke(privileges, database, table_name, username, host) + if str.is_blank(privileges) or str.is_blank(database) or str.is_blank(username) then + error("Privileges, database, and username cannot be empty") + end + + host = host or "%" + table_name = table_name or "*" + local object = str.template("${database}.${table}", {database = database, table = table_name}) + + local query = str.template("REVOKE ${privileges} ON ${object} FROM '${username}'@'${host}'", { + privileges = str.upper(privileges), + object = object, + username = username, + host = host + }) + return self:exec(query) +end + +function Connection:flush_privileges() + return self:exec("FLUSH PRIVILEGES") +end + +-- Enhanced MySQL variables and configuration +function Connection:set_variable(name, value, global) + if str.is_blank(name) then + error("Variable name cannot be empty") + end + + local scope = global and "GLOBAL " or "SESSION " + return self:exec(str.template("SET ${scope}${name} = ?", { + scope = scope, + name = name + }), value) +end + +function Connection:get_variable(name, global) + if str.is_blank(name) then + error("Variable name cannot be empty") + end + + local scope = global and "global." or "session." + return self:query_value(str.template("SELECT @@${scope}${name}", { + scope = scope, + name = name + })) +end + +function Connection:show_variables(pattern) + if pattern and not str.is_blank(pattern) then + return self:query("SHOW VARIABLES LIKE ?", pattern) + else + return self:query("SHOW VARIABLES") + end +end + +function Connection:show_status(pattern) + if pattern and not str.is_blank(pattern) then + return self:query("SHOW STATUS LIKE ?", pattern) + else + return self:query("SHOW STATUS") + end +end + +-- Enhanced connection management +function mysql.connect(dsn) + if str.is_blank(dsn) then + error("DSN cannot be empty") + end + + local conn_id = moonshark.sql_connect("mysql", str.trim(dsn)) + if conn_id then + local conn = {_id = conn_id} + setmetatable(conn, Connection) + return conn + end + return nil +end + +mysql.open = mysql.connect + +-- Enhanced quick execution functions +function mysql.query(dsn, query_str, ...) + local conn = mysql.connect(dsn) + if not conn then + error("Failed to connect to MySQL database") + end + + local results = conn:query(query_str, ...) + conn:close() + return results +end + +function mysql.exec(dsn, query_str, ...) + local conn = mysql.connect(dsn) + if not conn then + error("Failed to connect to MySQL database") + end + + local result = conn:exec(query_str, ...) + conn:close() + return result +end + +function mysql.query_row(dsn, query_str, ...) + local results = mysql.query(dsn, query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function mysql.query_value(dsn, query_str, ...) + local row = mysql.query_row(dsn, query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced migration helpers +function mysql.migrate(dsn, migrations, database_name) + local conn = mysql.connect(dsn) + if not conn then + error("Failed to connect to MySQL database for migration") + end + + -- Use specified database if provided + if database_name and not str.is_blank(database_name) then + conn:exec(str.template("USE ${database}", {database = database_name})) + end + + -- Create migrations table + conn:create_table("_migrations", + "id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) UNIQUE NOT NULL, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + + local tx = conn:begin() + if not tx then + conn:close() + error("Failed to begin migration transaction") + end + + local success = true + local error_msg = "" + + for _, migration in ipairs(migrations) do + if not migration.name or str.is_blank(migration.name) then + error_msg = "Migration must have a non-empty name" + success = false + break + end + + -- Check if migration already applied + local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", + str.trim(migration.name)) + if not existing then + local ok, err = pcall(function() + if type(migration.up) == "string" then + conn:exec(migration.up) + elseif type(migration.up) == "function" then + migration.up(conn) + else + error("Migration 'up' must be string or function") + end + end) + + if ok then + conn:exec("INSERT INTO _migrations (name) VALUES (?)", str.trim(migration.name)) + print(str.template("Applied migration: ${name}", {name = migration.name})) + else + success = false + error_msg = str.template("Migration '${name}' failed: ${error}", { + name = migration.name, + error = err or "unknown error" + }) + break + end + end + end + + if success then + tx:commit() + else + tx:rollback() + conn:close() + error(error_msg) + end + + conn:close() + return true +end + +-- Result processing utilities (same enhanced versions) +function mysql.to_array(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local array = {} + for i, row in ipairs(results) do + array[i] = row[column_name] + end + return array +end + +function mysql.to_map(results, key_column, value_column) + if not results or #results == 0 then + return {} + end + + if str.is_blank(key_column) then + error("Key column name cannot be empty") + end + + local map = {} + for _, row in ipairs(results) do + local key = row[key_column] + map[key] = value_column and row[value_column] or row + end + return map +end + +function mysql.group_by(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local groups = {} + for _, row in ipairs(results) do + local key = row[column_name] + if not groups[key] then + groups[key] = {} + end + table.insert(groups[key], row) + end + return groups +end + +-- Enhanced debug helper (same as others) +function mysql.print_results(results) + if not results or #results == 0 then + print("No results") + return + end + + -- Get column names from first row + local columns = {} + for col, _ in pairs(results[1]) do + table.insert(columns, col) + end + table.sort(columns) + + -- Calculate column widths for better formatting + local widths = {} + for _, col in ipairs(columns) do + widths[col] = str.length(col) + end + + for _, row in ipairs(results) do + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + widths[col] = math.max(widths[col], str.length(value)) + end + end + + -- Print header with proper spacing + local header_parts = {} + local separator_parts = {} + for _, col in ipairs(columns) do + table.insert(header_parts, str.pad_right(col, widths[col])) + table.insert(separator_parts, str.repeat_("-", widths[col])) + end + + print(str.join(header_parts, " | ")) + print(str.join(separator_parts, "-+-")) + + -- Print rows with proper spacing + for _, row in ipairs(results) do + local value_parts = {} + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + table.insert(value_parts, str.pad_right(value, widths[col])) + end + print(str.join(value_parts, " | ")) + end +end + +-- Enhanced MySQL-specific utilities +function mysql.escape_string(str_val) + if type(str_val) ~= "string" then + return tostring(str_val) + end + -- Basic escaping - in production, use proper escaping functions + return str.replace(str_val, "'", "\\'") +end + +function mysql.escape_identifier(name) + if str.is_blank(name) then + error("Identifier name cannot be empty") + end + return str.template("`${name}`", {name = str.replace(name, "`", "``")}) +end + +-- Enhanced DSN builder helper +function mysql.build_dsn(options) + if type(options) ~= "table" then + error("Options must be a table") + end + + local parts = {} + + if options.username and not str.is_blank(options.username) then + table.insert(parts, options.username) + if options.password and not str.is_blank(options.password) then + parts[#parts] = str.template("${user}:${pass}", { + user = parts[#parts], + pass = options.password + }) + end + parts[#parts] = parts[#parts] .. "@" + end + + if options.protocol and not str.is_blank(options.protocol) then + table.insert(parts, str.template("${protocol}(", {protocol = options.protocol})) + if options.host and not str.is_blank(options.host) then + table.insert(parts, options.host) + if options.port then + parts[#parts] = str.template("${host}:${port}", { + host = parts[#parts], + port = tostring(options.port) + }) + end + end + parts[#parts] = parts[#parts] .. ")" + elseif options.host and not str.is_blank(options.host) then + local host_part = str.template("tcp(${host}", {host = options.host}) + if options.port then + host_part = str.template("${host}:${port}", { + host = host_part, + port = tostring(options.port) + }) + end + table.insert(parts, host_part .. ")") + end + + if options.database and not str.is_blank(options.database) then + table.insert(parts, str.template("/${database}", {database = options.database})) + end + + -- Add parameters + local params = {} + if options.charset and not str.is_blank(options.charset) then + table.insert(params, str.template("charset=${charset}", {charset = options.charset})) + end + if options.parseTime ~= nil then + table.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)})) + end + if options.timeout and not str.is_blank(options.timeout) then + table.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout})) + end + if options.tls and not str.is_blank(options.tls) then + table.insert(params, str.template("tls=${tls}", {tls = options.tls})) + end + + if #params > 0 then + table.insert(parts, str.template("?${params}", {params = str.join(params, "&")})) + end + + return str.join(parts, "") +end + +return mysql diff --git a/modules/postgres/postgres.lua b/modules/postgres/postgres.lua new file mode 100644 index 0000000..a5318b7 --- /dev/null +++ b/modules/postgres/postgres.lua @@ -0,0 +1,847 @@ +local str = require("string") +local postgres = {} + +local Connection = {} +Connection.__index = Connection + +function Connection:close() + if self._id then + local ok = moonshark.sql_close(self._id) + self._id = nil + return ok + end + return false +end + +function Connection:ping() + if not self._id then + error("Connection is closed") + end + return moonshark.sql_ping(self._id) +end + +function Connection:query(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_query(self._id, query_str, ...) +end + +function Connection:exec(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_exec(self._id, query_str, ...) +end + +function Connection:query_row(query_str, ...) + local results = self:query(query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function Connection:query_value(query_str, ...) + local row = self:query_row(query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced transaction support with savepoints +function Connection:begin() + local result = self:exec("BEGIN") + if result then + return { + conn = self, + active = true, + + commit = function(tx) + if tx.active then + local result = tx.conn:exec("COMMIT") + tx.active = false + return result + end + return false + end, + + rollback = function(tx) + if tx.active then + local result = tx.conn:exec("ROLLBACK") + tx.active = false + return result + end + return false + end, + + savepoint = function(tx, name) + if not tx.active then + error("Transaction is not active") + end + if str.is_blank(name) then + error("Savepoint name cannot be empty") + end + return tx.conn:exec(str.template("SAVEPOINT ${name}", {name = name})) + end, + + rollback_to = function(tx, name) + if not tx.active then + error("Transaction is not active") + end + if str.is_blank(name) then + error("Savepoint name cannot be empty") + end + return tx.conn:exec(str.template("ROLLBACK TO SAVEPOINT ${name}", {name = name})) + end, + + query = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query(query_str, ...) + end, + + exec = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:exec(query_str, ...) + end, + + query_row = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_row(query_str, ...) + end, + + query_value = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_value(query_str, ...) + end + } + end + return nil +end + +-- Enhanced query builder helpers with PostgreSQL parameter numbering +local function build_postgres_params(data) + local keys = {} + local values = {} + local placeholders = {} + local param_count = 0 + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + param_count = param_count + 1 + table.insert(placeholders, str.template("$${num}", {num = tostring(param_count)})) + end + + return keys, values, placeholders, param_count +end + +function Connection:insert(table_name, data, returning) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys, values, placeholders = build_postgres_params(data) + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", ") + }) + + if returning and not str.is_blank(returning) then + query = str.template("${query} RETURNING ${returning}", { + query = query, + returning = returning + }) + return self:query(query, unpack(values)) + else + return self:exec(query, unpack(values)) + end +end + +function Connection:upsert(table_name, data, conflict_columns, returning) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys, values, placeholders = build_postgres_params(data) + local updates = {} + + for _, key in ipairs(keys) do + table.insert(updates, str.template("${key} = EXCLUDED.${key}", {key = key})) + end + + local conflict_clause = "" + if conflict_columns then + if type(conflict_columns) == "string" then + conflict_clause = str.template("(${columns})", {columns = conflict_columns}) + else + conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")}) + end + end + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", "), + conflict = conflict_clause, + updates = str.join(updates, ", ") + }) + + if returning and not str.is_blank(returning) then + query = str.template("${query} RETURNING ${returning}", { + query = query, + returning = returning + }) + return self:query(query, unpack(values)) + else + return self:exec(query, unpack(values)) + end +end + +function Connection:update(table_name, data, where_clause, returning, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for UPDATE") + end + + local sets = {} + local values = {} + local param_count = 0 + + for key, value in pairs(data) do + param_count = param_count + 1 + table.insert(sets, str.template("${key} = $${num}", { + key = key, + num = tostring(param_count) + })) + table.insert(values, value) + end + + -- Handle WHERE clause parameters + local where_args = {...} + local where_clause_with_params = where_clause + for i = 1, #where_args do + param_count = param_count + 1 + table.insert(values, where_args[i]) + -- Replace ? with numbered parameter if needed + where_clause_with_params = str.replace(where_clause_with_params, "?", + str.template("$${num}", {num = tostring(param_count)}), 1) + end + + local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { + table = table_name, + sets = str.join(sets, ", "), + where = where_clause_with_params + }) + + if returning and not str.is_blank(returning) then + query = str.template("${query} RETURNING ${returning}", { + query = query, + returning = returning + }) + return self:query(query, unpack(values)) + else + return self:exec(query, unpack(values)) + end +end + +function Connection:delete(table_name, where_clause, returning, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for DELETE") + end + + -- Handle WHERE clause parameters + local where_args = {...} + local values = {} + local where_clause_with_params = where_clause + for i = 1, #where_args do + table.insert(values, where_args[i]) + where_clause_with_params = str.replace(where_clause_with_params, "?", + str.template("$${num}", {num = tostring(i)}), 1) + end + + local query = str.template("DELETE FROM ${table} WHERE ${where}", { + table = table_name, + where = where_clause_with_params + }) + + if returning and not str.is_blank(returning) then + query = str.template("${query} RETURNING ${returning}", { + query = query, + returning = returning + }) + return self:query(query, unpack(values)) + else + return self:exec(query, unpack(values)) + end +end + +function Connection:select(table_name, columns, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + columns = columns or "*" + if type(columns) == "table" then + columns = str.join(columns, ", ") + end + + local query + if where_clause and not str.is_blank(where_clause) then + -- Handle WHERE clause parameters + local where_args = {...} + local values = {} + local where_clause_with_params = where_clause + for i = 1, #where_args do + table.insert(values, where_args[i]) + where_clause_with_params = str.replace(where_clause_with_params, "?", + str.template("$${num}", {num = tostring(i)}), 1) + end + + query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", { + columns = columns, + table = table_name, + where = where_clause_with_params + }) + return self:query(query, unpack(values)) + else + query = str.template("SELECT ${columns} FROM ${table}", { + columns = columns, + table = table_name + }) + return self:query(query) + end +end + +-- Enhanced PostgreSQL schema helpers +function Connection:table_exists(table_name, schema_name) + if str.is_blank(table_name) then + return false + end + + schema_name = schema_name or "public" + local result = self:query_value( + "SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2", + str.trim(schema_name), str.trim(table_name) + ) + return result ~= nil +end + +function Connection:column_exists(table_name, column_name, schema_name) + if str.is_blank(table_name) or str.is_blank(column_name) then + return false + end + + schema_name = schema_name or "public" + local result = self:query_value([[ + SELECT column_name FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 AND column_name = $3 + ]], str.trim(schema_name), str.trim(table_name), str.trim(column_name)) + return result ~= nil +end + +function Connection:create_table(table_name, schema) + if str.is_blank(table_name) or str.is_blank(schema) then + error("Table name and schema cannot be empty") + end + + local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", { + table = table_name, + schema = str.trim(schema) + }) + return self:exec(query) +end + +function Connection:drop_table(table_name, cascade) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local cascade_clause = cascade and " CASCADE" or "" + local query = str.template("DROP TABLE IF EXISTS ${table}${cascade}", { + table = table_name, + cascade = cascade_clause + }) + return self:exec(query) +end + +function Connection:add_column(table_name, column_def) + if str.is_blank(table_name) or str.is_blank(column_def) then + error("Table name and column definition cannot be empty") + end + + local query = str.template("ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${column}", { + table = table_name, + column = str.trim(column_def) + }) + return self:exec(query) +end + +function Connection:drop_column(table_name, column_name, cascade) + if str.is_blank(table_name) or str.is_blank(column_name) then + error("Table name and column name cannot be empty") + end + + local cascade_clause = cascade and " CASCADE" or "" + local query = str.template("ALTER TABLE ${table} DROP COLUMN IF EXISTS ${column}${cascade}", { + table = table_name, + column = column_name, + cascade = cascade_clause + }) + return self:exec(query) +end + +function Connection:create_index(index_name, table_name, columns, unique, method) + if str.is_blank(index_name) or str.is_blank(table_name) then + error("Index name and table name cannot be empty") + end + + local unique_clause = unique and "UNIQUE " or "" + local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or "" + local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) + + local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", { + unique = unique_clause, + index = index_name, + table = table_name, + method = method_clause, + columns = columns_str + }) + return self:exec(query) +end + +function Connection:drop_index(index_name, cascade) + if str.is_blank(index_name) then + error("Index name cannot be empty") + end + + local cascade_clause = cascade and " CASCADE" or "" + local query = str.template("DROP INDEX IF EXISTS ${index}${cascade}", { + index = index_name, + cascade = cascade_clause + }) + return self:exec(query) +end + +-- Enhanced PostgreSQL-specific functions +function Connection:vacuum(table_name, analyze) + local analyze_clause = analyze and " ANALYZE" or "" + local table_clause = table_name and str.template(" ${table}", {table = table_name}) or "" + return self:exec(str.template("VACUUM${analyze}${table}", { + analyze = analyze_clause, + table = table_clause + })) +end + +function Connection:analyze(table_name) + local table_clause = table_name and str.template(" ${table}", {table = table_name}) or "" + return self:exec(str.template("ANALYZE${table}", {table = table_clause})) +end + +function Connection:reindex(name, type) + if str.is_blank(name) then + error("Name cannot be empty for REINDEX") + end + + type = type or "INDEX" + local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"} + local type_upper = str.upper(type) + + local valid = false + for _, valid_type in ipairs(valid_types) do + if type_upper == valid_type then + valid = true + break + end + end + + if not valid then + error(str.template("Invalid REINDEX type: ${type}", {type = type})) + end + + return self:exec(str.template("REINDEX ${type} ${name}", { + type = type_upper, + name = name + })) +end + +-- PostgreSQL settings and introspection +function Connection:show(setting) + if str.is_blank(setting) then + error("Setting name cannot be empty") + end + return self:query_value(str.template("SHOW ${setting}", {setting = setting})) +end + +function Connection:set(setting, value) + if str.is_blank(setting) then + error("Setting name cannot be empty") + end + return self:exec(str.template("SET ${setting} = ${value}", { + setting = setting, + value = tostring(value) + })) +end + +function Connection:current_database() + return self:query_value("SELECT current_database()") +end + +function Connection:current_schema() + return self:query_value("SELECT current_schema()") +end + +function Connection:version() + return self:query_value("SELECT version()") +end + +function Connection:list_schemas() + return self:query("SELECT schema_name FROM information_schema.schemata ORDER BY schema_name") +end + +function Connection:list_tables(schema_name) + schema_name = schema_name or "public" + return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename", + str.trim(schema_name)) +end + +function Connection:describe_table(table_name, schema_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + schema_name = schema_name or "public" + return self:query([[ + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ORDER BY ordinal_position + ]], str.trim(schema_name), str.trim(table_name)) +end + +-- Enhanced JSON/JSONB helpers +function Connection:json_extract(column, path) + if str.is_blank(column) or str.is_blank(path) then + error("Column and path cannot be empty") + end + return str.template("${column}->'${path}'", {column = column, path = path}) +end + +function Connection:json_extract_text(column, path) + if str.is_blank(column) or str.is_blank(path) then + error("Column and path cannot be empty") + end + return str.template("${column}->>'${path}'", {column = column, path = path}) +end + +function Connection:jsonb_contains(column, value) + if str.is_blank(column) or str.is_blank(value) then + error("Column and value cannot be empty") + end + return str.template("${column} @> '${value}'", {column = column, value = value}) +end + +function Connection:jsonb_contained_by(column, value) + if str.is_blank(column) or str.is_blank(value) then + error("Column and value cannot be empty") + end + return str.template("${column} <@ '${value}'", {column = column, value = value}) +end + +-- Enhanced Array helpers +function Connection:array_contains(column, value) + if str.is_blank(column) then + error("Column cannot be empty") + end + return str.template("$1 = ANY(${column})", {column = column}) +end + +function Connection:array_length(column) + if str.is_blank(column) then + error("Column cannot be empty") + end + return str.template("array_length(${column}, 1)", {column = column}) +end + +-- Enhanced connection management with DSN parsing +function postgres.parse_dsn(dsn) + if str.is_blank(dsn) then + return nil, "DSN cannot be empty" + end + + local parts = {} + + -- Split by spaces and handle key=value pairs + for pair in str.trim(dsn):gmatch("[^%s]+") do + local key, value = pair:match("([^=]+)=(.+)") + if key and value then + parts[str.trim(key)] = str.trim(value) + end + end + + return parts +end + +function postgres.connect(dsn) + if str.is_blank(dsn) then + error("DSN cannot be empty") + end + + local conn_id = moonshark.sql_connect("postgres", str.trim(dsn)) + if conn_id then + local conn = {_id = conn_id} + setmetatable(conn, Connection) + return conn + end + return nil +end + +postgres.open = postgres.connect + +-- Enhanced quick execution functions +function postgres.query(dsn, query_str, ...) + local conn = postgres.connect(dsn) + if not conn then + error("Failed to connect to PostgreSQL database") + end + + local results = conn:query(query_str, ...) + conn:close() + return results +end + +function postgres.exec(dsn, query_str, ...) + local conn = postgres.connect(dsn) + if not conn then + error("Failed to connect to PostgreSQL database") + end + + local result = conn:exec(query_str, ...) + conn:close() + return result +end + +function postgres.query_row(dsn, query_str, ...) + local results = postgres.query(dsn, query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function postgres.query_value(dsn, query_str, ...) + local row = postgres.query_row(dsn, query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced migration helpers +function postgres.migrate(dsn, migrations, schema) + schema = schema or "public" + local conn = postgres.connect(dsn) + if not conn then + error("Failed to connect to PostgreSQL database for migration") + end + + -- Create migrations table + conn:create_table("_migrations", + "id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()") + + local tx = conn:begin() + if not tx then + conn:close() + error("Failed to begin migration transaction") + end + + local success = true + local error_msg = "" + + for _, migration in ipairs(migrations) do + if not migration.name or str.is_blank(migration.name) then + error_msg = "Migration must have a non-empty name" + success = false + break + end + + -- Check if migration already applied + local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1", + str.trim(migration.name)) + if not existing then + local ok, err = pcall(function() + if type(migration.up) == "string" then + conn:exec(migration.up) + elseif type(migration.up) == "function" then + migration.up(conn) + else + error("Migration 'up' must be string or function") + end + end) + + if ok then + conn:exec("INSERT INTO _migrations (name) VALUES ($1)", str.trim(migration.name)) + print(str.template("Applied migration: ${name}", {name = migration.name})) + else + success = false + error_msg = str.template("Migration '${name}' failed: ${error}", { + name = migration.name, + error = err or "unknown error" + }) + break + end + end + end + + if success then + tx:commit() + else + tx:rollback() + conn:close() + error(error_msg) + end + + conn:close() + return true +end + +-- Result processing utilities (same enhanced versions as SQLite) +function postgres.to_array(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local array = {} + for i, row in ipairs(results) do + array[i] = row[column_name] + end + return array +end + +function postgres.to_map(results, key_column, value_column) + if not results or #results == 0 then + return {} + end + + if str.is_blank(key_column) then + error("Key column name cannot be empty") + end + + local map = {} + for _, row in ipairs(results) do + local key = row[key_column] + map[key] = value_column and row[value_column] or row + end + return map +end + +function postgres.group_by(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local groups = {} + for _, row in ipairs(results) do + local key = row[column_name] + if not groups[key] then + groups[key] = {} + end + table.insert(groups[key], row) + end + return groups +end + +-- Enhanced debug helper (same as SQLite) +function postgres.print_results(results) + if not results or #results == 0 then + print("No results") + return + end + + -- Get column names from first row + local columns = {} + for col, _ in pairs(results[1]) do + table.insert(columns, col) + end + table.sort(columns) + + -- Calculate column widths for better formatting + local widths = {} + for _, col in ipairs(columns) do + widths[col] = str.length(col) + end + + for _, row in ipairs(results) do + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + widths[col] = math.max(widths[col], str.length(value)) + end + end + + -- Print header with proper spacing + local header_parts = {} + local separator_parts = {} + for _, col in ipairs(columns) do + table.insert(header_parts, str.pad_right(col, widths[col])) + table.insert(separator_parts, str.repeat_("-", widths[col])) + end + + print(str.join(header_parts, " | ")) + print(str.join(separator_parts, "-+-")) + + -- Print rows with proper spacing + for _, row in ipairs(results) do + local value_parts = {} + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + table.insert(value_parts, str.pad_right(value, widths[col])) + end + print(str.join(value_parts, " | ")) + end +end + +function postgres.escape_identifier(name) + if str.is_blank(name) then + error("Identifier name cannot be empty") + end + return str.template('"${name}"', {name = str.replace(name, '"', '""')}) +end + +function postgres.escape_literal(value) + if type(value) == "string" then + return str.template("'${value}'", {value = str.replace(value, "'", "''")}) + end + return tostring(value) +end + +return postgres diff --git a/modules/registry.go b/modules/registry.go index c54643d..984e896 100644 --- a/modules/registry.go +++ b/modules/registry.go @@ -9,6 +9,7 @@ import ( "Moonshark/modules/fs" "Moonshark/modules/http" "Moonshark/modules/math" + "Moonshark/modules/sql" lua_string "Moonshark/modules/string" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" @@ -17,7 +18,7 @@ import ( // Global registry instance var Global *Registry -//go:embed crypto/*.lua fs/*.lua json/*.lua math/*.lua string/*.lua http/*.lua +//go:embed **/*.lua var embeddedModules embed.FS // Registry manages all Lua modules and Go functions @@ -39,6 +40,7 @@ func New() *Registry { maps.Copy(r.goFuncs, crypto.GetFunctionList()) maps.Copy(r.goFuncs, fs.GetFunctionList()) maps.Copy(r.goFuncs, http.GetFunctionList()) + maps.Copy(r.goFuncs, sql.GetFunctionList()) r.loadEmbeddedModules() return r diff --git a/modules/sql/mysql.go b/modules/sql/mysql.go new file mode 100644 index 0000000..637979d --- /dev/null +++ b/modules/sql/mysql.go @@ -0,0 +1,205 @@ +package sql + +import ( + "context" + "database/sql" + "fmt" + + _ "github.com/go-sql-driver/mysql" +) + +// MySQLDriver implements the Driver interface for MySQL +type MySQLDriver struct{} + +func (d *MySQLDriver) Name() string { + return "mysql" +} + +func (d *MySQLDriver) Open(dsn string) (Connection, error) { + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, fmt.Errorf("mysql: failed to open database: %w", err) + } + + // Test the connection + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("mysql: failed to ping database: %w", err) + } + + return &MySQLConnection{db: db}, nil +} + +// MySQLConnection implements the Connection interface +type MySQLConnection struct { + db *sql.DB +} + +func (c *MySQLConnection) Close() error { + return c.db.Close() +} + +func (c *MySQLConnection) Ping(ctx context.Context) error { + return c.db.PingContext(ctx) +} + +func (c *MySQLConnection) Begin(ctx context.Context) (Transaction, error) { + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("mysql: failed to begin transaction: %w", err) + } + return &MySQLTransaction{tx: tx}, nil +} + +func (c *MySQLConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + rows, err := c.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("mysql: query failed: %w", err) + } + return &MySQLRows{rows: rows}, nil +} + +func (c *MySQLConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + row := c.db.QueryRowContext(ctx, query, args...) + return &MySQLRow{row: row} +} + +func (c *MySQLConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) { + result, err := c.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("mysql: exec failed: %w", err) + } + return &MySQLResult{result: result}, nil +} + +func (c *MySQLConnection) Prepare(ctx context.Context, query string) (Statement, error) { + stmt, err := c.db.PrepareContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("mysql: failed to prepare statement: %w", err) + } + return &MySQLStatement{stmt: stmt}, nil +} + +// MySQLTransaction implements the Transaction interface +type MySQLTransaction struct { + tx *sql.Tx +} + +func (t *MySQLTransaction) Commit() error { + return t.tx.Commit() +} + +func (t *MySQLTransaction) Rollback() error { + return t.tx.Rollback() +} + +func (t *MySQLTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) { + rows, err := t.tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("mysql: transaction query failed: %w", err) + } + return &MySQLRows{rows: rows}, nil +} + +func (t *MySQLTransaction) QueryRow(ctx context.Context, query string, args ...any) Row { + row := t.tx.QueryRowContext(ctx, query, args...) + return &MySQLRow{row: row} +} + +func (t *MySQLTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) { + result, err := t.tx.ExecContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("mysql: transaction exec failed: %w", err) + } + return &MySQLResult{result: result}, nil +} + +func (t *MySQLTransaction) Prepare(ctx context.Context, query string) (Statement, error) { + stmt, err := t.tx.PrepareContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("mysql: failed to prepare transaction statement: %w", err) + } + return &MySQLStatement{stmt: stmt}, nil +} + +// MySQLRows implements the Rows interface +type MySQLRows struct { + rows *sql.Rows +} + +func (r *MySQLRows) Next() bool { + return r.rows.Next() +} + +func (r *MySQLRows) Scan(dest ...any) error { + return r.rows.Scan(dest...) +} + +func (r *MySQLRows) Columns() ([]string, error) { + return r.rows.Columns() +} + +func (r *MySQLRows) Close() error { + return r.rows.Close() +} + +func (r *MySQLRows) Err() error { + return r.rows.Err() +} + +// MySQLRow implements the Row interface +type MySQLRow struct { + row *sql.Row +} + +func (r *MySQLRow) Scan(dest ...any) error { + return r.row.Scan(dest...) +} + +// MySQLResult implements the Result interface +type MySQLResult struct { + result sql.Result +} + +func (r *MySQLResult) LastInsertId() (int64, error) { + return r.result.LastInsertId() +} + +func (r *MySQLResult) RowsAffected() (int64, error) { + return r.result.RowsAffected() +} + +// MySQLStatement implements the Statement interface +type MySQLStatement struct { + stmt *sql.Stmt +} + +func (s *MySQLStatement) Close() error { + return s.stmt.Close() +} + +func (s *MySQLStatement) Query(ctx context.Context, args ...any) (Rows, error) { + rows, err := s.stmt.QueryContext(ctx, args...) + if err != nil { + return nil, fmt.Errorf("mysql: statement query failed: %w", err) + } + return &MySQLRows{rows: rows}, nil +} + +func (s *MySQLStatement) QueryRow(ctx context.Context, args ...any) Row { + row := s.stmt.QueryRowContext(ctx, args...) + return &MySQLRow{row: row} +} + +func (s *MySQLStatement) Exec(ctx context.Context, args ...any) (Result, error) { + result, err := s.stmt.ExecContext(ctx, args...) + if err != nil { + return nil, fmt.Errorf("mysql: statement exec failed: %w", err) + } + return &MySQLResult{result: result}, nil +} + +func init() { + // Register MySQL driver on import + RegisterDriver("mysql", &MySQLDriver{}) +} diff --git a/modules/sql/postgres.go b/modules/sql/postgres.go new file mode 100644 index 0000000..d98cc39 --- /dev/null +++ b/modules/sql/postgres.go @@ -0,0 +1,234 @@ +package sql + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +// PostgresDriver implements the Driver interface for PostgreSQL +type PostgresDriver struct{} + +func (d *PostgresDriver) Name() string { + return "postgres" +} + +func (d *PostgresDriver) Open(dsn string) (Connection, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("postgres: failed to parse config: %w", err) + } + + pool, err := pgxpool.NewWithConfig(context.Background(), config) + if err != nil { + return nil, fmt.Errorf("postgres: failed to create pool: %w", err) + } + + return &PostgresConnection{pool: pool}, nil +} + +// PostgresConnection implements the Connection interface +type PostgresConnection struct { + pool *pgxpool.Pool +} + +func (c *PostgresConnection) Close() error { + c.pool.Close() + return nil +} + +func (c *PostgresConnection) Ping(ctx context.Context) error { + return c.pool.Ping(ctx) +} + +func (c *PostgresConnection) Begin(ctx context.Context) (Transaction, error) { + tx, err := c.pool.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("postgres: failed to begin transaction: %w", err) + } + return &PostgresTransaction{tx: tx}, nil +} + +func (c *PostgresConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + rows, err := c.pool.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: query failed: %w", err) + } + return &PostgresRows{rows: rows}, nil +} + +func (c *PostgresConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + row := c.pool.QueryRow(ctx, query, args...) + return &PostgresRow{row: row} +} + +func (c *PostgresConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) { + tag, err := c.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: exec failed: %w", err) + } + return &PostgresResult{tag: tag}, nil +} + +func (c *PostgresConnection) Prepare(ctx context.Context, query string) (Statement, error) { + // pgx doesn't have explicit prepared statements like database/sql + // We'll store the query and use it with the pool + return &PostgresStatement{pool: c.pool, query: query}, nil +} + +// PostgresTransaction implements the Transaction interface +type PostgresTransaction struct { + tx pgx.Tx +} + +func (t *PostgresTransaction) Commit() error { + return t.tx.Commit(context.Background()) +} + +func (t *PostgresTransaction) Rollback() error { + return t.tx.Rollback(context.Background()) +} + +func (t *PostgresTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) { + rows, err := t.tx.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: transaction query failed: %w", err) + } + return &PostgresRows{rows: rows}, nil +} + +func (t *PostgresTransaction) QueryRow(ctx context.Context, query string, args ...any) Row { + row := t.tx.QueryRow(ctx, query, args...) + return &PostgresRow{row: row} +} + +func (t *PostgresTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) { + tag, err := t.tx.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: transaction exec failed: %w", err) + } + return &PostgresResult{tag: tag}, nil +} + +func (t *PostgresTransaction) Prepare(ctx context.Context, query string) (Statement, error) { + return &PostgresStatement{tx: t.tx, query: query}, nil +} + +// PostgresRows implements the Rows interface +type PostgresRows struct { + rows pgx.Rows +} + +func (r *PostgresRows) Next() bool { + return r.rows.Next() +} + +func (r *PostgresRows) Scan(dest ...any) error { + return r.rows.Scan(dest...) +} + +func (r *PostgresRows) Columns() ([]string, error) { + fields := r.rows.FieldDescriptions() + columns := make([]string, len(fields)) + for i, field := range fields { + columns[i] = field.Name + } + return columns, nil +} + +func (r *PostgresRows) Close() error { + r.rows.Close() + return nil +} + +func (r *PostgresRows) Err() error { + return r.rows.Err() +} + +// PostgresRow implements the Row interface +type PostgresRow struct { + row pgx.Row +} + +func (r *PostgresRow) Scan(dest ...any) error { + return r.row.Scan(dest...) +} + +// PostgresResult implements the Result interface +type PostgresResult struct { + tag pgconn.CommandTag +} + +func (r *PostgresResult) LastInsertId() (int64, error) { + // PostgreSQL doesn't have AUTO_INCREMENT like MySQL + // Users should use RETURNING clause or sequences + return 0, fmt.Errorf("postgres: LastInsertId not supported, use RETURNING clause") +} + +func (r *PostgresResult) RowsAffected() (int64, error) { + return r.tag.RowsAffected(), nil +} + +// PostgresStatement implements the Statement interface +type PostgresStatement struct { + pool *pgxpool.Pool + tx pgx.Tx + query string +} + +func (s *PostgresStatement) Close() error { + // pgx doesn't require explicit statement cleanup + return nil +} + +func (s *PostgresStatement) Query(ctx context.Context, args ...any) (Rows, error) { + var rows pgx.Rows + var err error + + if s.tx != nil { + rows, err = s.tx.Query(ctx, s.query, args...) + } else { + rows, err = s.pool.Query(ctx, s.query, args...) + } + + if err != nil { + return nil, fmt.Errorf("postgres: statement query failed: %w", err) + } + return &PostgresRows{rows: rows}, nil +} + +func (s *PostgresStatement) QueryRow(ctx context.Context, args ...any) Row { + var row pgx.Row + + if s.tx != nil { + row = s.tx.QueryRow(ctx, s.query, args...) + } else { + row = s.pool.QueryRow(ctx, s.query, args...) + } + + return &PostgresRow{row: row} +} + +func (s *PostgresStatement) Exec(ctx context.Context, args ...any) (Result, error) { + var tag pgconn.CommandTag + var err error + + if s.tx != nil { + tag, err = s.tx.Exec(ctx, s.query, args...) + } else { + tag, err = s.pool.Exec(ctx, s.query, args...) + } + + if err != nil { + return nil, fmt.Errorf("postgres: statement exec failed: %w", err) + } + return &PostgresResult{tag: tag}, nil +} + +func init() { + // Register PostgreSQL driver on import + RegisterDriver("postgres", &PostgresDriver{}) +} diff --git a/modules/sql/sql.go b/modules/sql/sql.go new file mode 100644 index 0000000..febcd38 --- /dev/null +++ b/modules/sql/sql.go @@ -0,0 +1,367 @@ +package sql + +import ( + "context" + "fmt" + "sync" + "time" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Driver interface for SQL database implementations +type Driver interface { + Open(dsn string) (Connection, error) + Name() string +} + +// Connection represents a database connection +type Connection interface { + Close() error + Ping(ctx context.Context) error + Begin(ctx context.Context) (Transaction, error) + Query(ctx context.Context, query string, args ...any) (Rows, error) + QueryRow(ctx context.Context, query string, args ...any) Row + Exec(ctx context.Context, query string, args ...any) (Result, error) + Prepare(ctx context.Context, query string) (Statement, error) +} + +// Transaction represents a database transaction +type Transaction interface { + Commit() error + Rollback() error + Query(ctx context.Context, query string, args ...any) (Rows, error) + QueryRow(ctx context.Context, query string, args ...any) Row + Exec(ctx context.Context, query string, args ...any) (Result, error) + Prepare(ctx context.Context, query string) (Statement, error) +} + +// Rows represents query result rows +type Rows interface { + Next() bool + Scan(dest ...any) error + Columns() ([]string, error) + Close() error + Err() error +} + +// Row represents a single query result row +type Row interface { + Scan(dest ...any) error +} + +// Result represents the result of an executed statement +type Result interface { + LastInsertId() (int64, error) + RowsAffected() (int64, error) +} + +// Statement represents a prepared statement +type Statement interface { + Close() error + Query(ctx context.Context, args ...any) (Rows, error) + QueryRow(ctx context.Context, args ...any) Row + Exec(ctx context.Context, args ...any) (Result, error) +} + +// Registry manages database drivers and connections +type Registry struct { + mu sync.RWMutex + drivers map[string]Driver + conns map[string]Connection + nextID int +} + +var global = &Registry{ + drivers: make(map[string]Driver), + conns: make(map[string]Connection), +} + +// RegisterDriver registers a database driver +func RegisterDriver(name string, driver Driver) { + global.mu.Lock() + defer global.mu.Unlock() + global.drivers[name] = driver +} + +// GetDriver returns a registered driver +func GetDriver(name string) (Driver, bool) { + global.mu.RLock() + defer global.mu.RUnlock() + driver, exists := global.drivers[name] + return driver, exists +} + +// Connect opens a new database connection +func Connect(driverName, dsn string) (string, error) { + driver, exists := GetDriver(driverName) + if !exists { + return "", fmt.Errorf("unknown driver: %s", driverName) + } + + conn, err := driver.Open(dsn) + if err != nil { + return "", err + } + + global.mu.Lock() + defer global.mu.Unlock() + + id := fmt.Sprintf("%s_%d", driverName, global.nextID) + global.nextID++ + global.conns[id] = conn + + return id, nil +} + +// GetConnection retrieves a connection by ID +func GetConnection(id string) (Connection, bool) { + global.mu.RLock() + defer global.mu.RUnlock() + conn, exists := global.conns[id] + return conn, exists +} + +// CloseConnection closes and removes a connection +func CloseConnection(id string) error { + global.mu.Lock() + defer global.mu.Unlock() + + conn, exists := global.conns[id] + if !exists { + return fmt.Errorf("connection not found: %s", id) + } + + err := conn.Close() + delete(global.conns, id) + return err +} + +// Lua function implementations + +func luaConnect(s *luajit.State) int { + if err := s.CheckExactArgs(2); err != nil { + return s.PushError("connect: %v", err) + } + + driver, err := s.SafeToString(1) + if err != nil { + return s.PushError("connect: driver must be a string") + } + + dsn, err := s.SafeToString(2) + if err != nil { + return s.PushError("connect: dsn must be a string") + } + + connID, err := Connect(driver, dsn) + if err != nil { + return s.PushError("connect: %v", err) + } + + s.PushString(connID) + return 1 +} + +func luaClose(s *luajit.State) int { + if err := s.CheckExactArgs(1); err != nil { + return s.PushError("close: %v", err) + } + + connID, err := s.SafeToString(1) + if err != nil { + return s.PushError("close: connection id must be a string") + } + + if err := CloseConnection(connID); err != nil { + return s.PushError("close: %v", err) + } + + s.PushBoolean(true) + return 1 +} + +func luaPing(s *luajit.State) int { + if err := s.CheckExactArgs(1); err != nil { + return s.PushError("ping: %v", err) + } + + connID, err := s.SafeToString(1) + if err != nil { + return s.PushError("ping: connection id must be a string") + } + + conn, exists := GetConnection(connID) + if !exists { + return s.PushError("ping: connection not found") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := conn.Ping(ctx); err != nil { + return s.PushError("ping: %v", err) + } + + s.PushBoolean(true) + return 1 +} + +func luaQuery(s *luajit.State) int { + if err := s.CheckMinArgs(2); err != nil { + return s.PushError("query: %v", err) + } + + connID, err := s.SafeToString(1) + if err != nil { + return s.PushError("query: connection id must be a string") + } + + query, err := s.SafeToString(2) + if err != nil { + return s.PushError("query: query must be a string") + } + + conn, exists := GetConnection(connID) + if !exists { + return s.PushError("query: connection not found") + } + + // Collect arguments + args := make([]any, s.GetTop()-2) + for i := 3; i <= s.GetTop(); i++ { + val, err := s.ToValue(i) + if err != nil { + args[i-3] = nil + } else { + args[i-3] = val + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + rows, err := conn.Query(ctx, query, args...) + if err != nil { + return s.PushError("query: %v", err) + } + defer rows.Close() + + // Get column names + columns, err := rows.Columns() + if err != nil { + return s.PushError("query: failed to get columns: %v", err) + } + + // Build result array + s.CreateTable(0, 0) + rowIndex := 1 + + for rows.Next() { + // Create values slice for scanning + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + return s.PushError("query: scan error: %v", err) + } + + // Create row table + s.CreateTable(0, len(columns)) + for i, col := range columns { + s.PushString(col) + if err := s.PushValue(values[i]); err != nil { + s.PushNil() + } + s.SetTable(-3) + } + + // Add to result array + s.PushNumber(float64(rowIndex)) + s.PushCopy(-2) + s.SetTable(-4) + s.Pop(1) // Remove row table copy + + rowIndex++ + } + + if err := rows.Err(); err != nil { + return s.PushError("query: %v", err) + } + + return 1 +} + +func luaExec(s *luajit.State) int { + if err := s.CheckMinArgs(2); err != nil { + return s.PushError("exec: %v", err) + } + + connID, err := s.SafeToString(1) + if err != nil { + return s.PushError("exec: connection id must be a string") + } + + query, err := s.SafeToString(2) + if err != nil { + return s.PushError("exec: query must be a string") + } + + conn, exists := GetConnection(connID) + if !exists { + return s.PushError("exec: connection not found") + } + + // Collect arguments + args := make([]any, s.GetTop()-2) + for i := 3; i <= s.GetTop(); i++ { + val, err := s.ToValue(i) + if err != nil { + args[i-3] = nil + } else { + args[i-3] = val + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := conn.Exec(ctx, query, args...) + if err != nil { + return s.PushError("exec: %v", err) + } + + // Return result info + s.CreateTable(0, 2) + + lastID, _ := result.LastInsertId() + s.PushString("last_insert_id") + s.PushNumber(float64(lastID)) + s.SetTable(-3) + + affected, _ := result.RowsAffected() + s.PushString("rows_affected") + s.PushNumber(float64(affected)) + s.SetTable(-3) + + return 1 +} + +// GetFunctionList returns all Lua-callable functions +func GetFunctionList() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "sql_connect": luaConnect, + "sql_close": luaClose, + "sql_ping": luaPing, + "sql_query": luaQuery, + "sql_exec": luaExec, + } +} + +func init() { + // Register SQLite driver on import + RegisterDriver("sqlite", &SQLiteDriver{}) +} diff --git a/modules/sql/sqlite.go b/modules/sql/sqlite.go new file mode 100644 index 0000000..fcc7f6c --- /dev/null +++ b/modules/sql/sqlite.go @@ -0,0 +1,384 @@ +package sql + +import ( + "context" + "fmt" + + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" +) + +// SQLiteDriver implements the Driver interface for SQLite +type SQLiteDriver struct{} + +func (d *SQLiteDriver) Name() string { + return "sqlite" +} + +func (d *SQLiteDriver) Open(dsn string) (Connection, error) { + conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate) + if err != nil { + return nil, fmt.Errorf("sqlite: failed to open database: %w", err) + } + + return &SQLiteConnection{conn: conn}, nil +} + +// SQLiteConnection implements the Connection interface +type SQLiteConnection struct { + conn *sqlite.Conn +} + +func (c *SQLiteConnection) Close() error { + return c.conn.Close() +} + +func (c *SQLiteConnection) Ping(ctx context.Context) error { + return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil) +} + +func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) { + if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil { + return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err) + } + return &SQLiteTransaction{conn: c.conn}, nil +} + +func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + stmt, err := c.conn.Prepare(query) + if err != nil { + return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err) + } + + if err := c.bindArgs(stmt, args...); err != nil { + stmt.Finalize() + return nil, err + } + + return &SQLiteRows{stmt: stmt, hasNext: true}, nil +} + +func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + rows, err := c.Query(ctx, query, args...) + if err != nil { + return &SQLiteRow{err: err} + } + return &SQLiteRow{rows: rows.(*SQLiteRows)} +} + +func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) { + stmt, err := c.conn.Prepare(query) + if err != nil { + return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err) + } + defer stmt.Finalize() + + if err := c.bindArgs(stmt, args...); err != nil { + return nil, err + } + + hasRow, err := stmt.Step() + if err != nil { + return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err) + } + + // Consume all rows if any + for hasRow { + hasRow, err = stmt.Step() + if err != nil { + return nil, fmt.Errorf("sqlite: error stepping through results: %w", err) + } + } + + return &SQLiteResult{ + lastInsertID: c.conn.LastInsertRowID(), + rowsAffected: c.conn.Changes(), + }, nil +} + +func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) { + stmt, err := c.conn.Prepare(query) + if err != nil { + return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err) + } + return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil +} + +func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error { + for i, arg := range args { + paramIndex := i + 1 + + if arg == nil { + stmt.BindNull(paramIndex) + continue + } + + switch v := arg.(type) { + case int: + stmt.BindInt64(paramIndex, int64(v)) + case int64: + stmt.BindInt64(paramIndex, v) + case float64: + stmt.BindFloat(paramIndex, v) + case string: + stmt.BindText(paramIndex, v) + case bool: + if v { + stmt.BindInt64(paramIndex, 1) + } else { + stmt.BindInt64(paramIndex, 0) + } + case []byte: + stmt.BindBytes(paramIndex, v) + default: + return fmt.Errorf("sqlite: unsupported parameter type: %T", arg) + } + } + return nil +} + +// SQLiteTransaction implements the Transaction interface +type SQLiteTransaction struct { + conn *sqlite.Conn +} + +func (t *SQLiteTransaction) Commit() error { + return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil) +} + +func (t *SQLiteTransaction) Rollback() error { + return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil) +} + +func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) { + conn := &SQLiteConnection{conn: t.conn} + return conn.Query(ctx, query, args...) +} + +func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row { + conn := &SQLiteConnection{conn: t.conn} + return conn.QueryRow(ctx, query, args...) +} + +func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) { + conn := &SQLiteConnection{conn: t.conn} + return conn.Exec(ctx, query, args...) +} + +func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) { + conn := &SQLiteConnection{conn: t.conn} + return conn.Prepare(ctx, query) +} + +// SQLiteRows implements the Rows interface +type SQLiteRows struct { + stmt *sqlite.Stmt + hasNext bool + err error +} + +func (r *SQLiteRows) Next() bool { + if r.err != nil { + return false + } + + if !r.hasNext { + return false + } + + var err error + r.hasNext, err = r.stmt.Step() + if err != nil { + r.err = err + return false + } + + return r.hasNext +} + +func (r *SQLiteRows) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + + for i, d := range dest { + if i >= r.stmt.ColumnCount() { + break + } + + switch ptr := d.(type) { + case *any: + *ptr = r.getValue(i) + case *string: + *ptr = r.stmt.ColumnText(i) + case *int: + *ptr = int(r.stmt.ColumnInt64(i)) + case *int64: + *ptr = r.stmt.ColumnInt64(i) + case *float64: + *ptr = r.stmt.ColumnFloat(i) + case *bool: + *ptr = r.stmt.ColumnInt64(i) != 0 + case *[]byte: + if r.stmt.ColumnType(i) == sqlite.TypeBlob { + // Get blob size first + size := r.stmt.ColumnBytes(i, nil) + if size == 0 { + *ptr = []byte{} + } else { + buf := make([]byte, size) + r.stmt.ColumnBytes(i, buf) + *ptr = buf + } + } else { + // Convert text to bytes + *ptr = []byte(r.stmt.ColumnText(i)) + } + default: + return fmt.Errorf("sqlite: unsupported scan destination type: %T", d) + } + } + + return nil +} + +func (r *SQLiteRows) getValue(index int) any { + switch r.stmt.ColumnType(index) { + case sqlite.TypeInteger: + return r.stmt.ColumnInt64(index) + case sqlite.TypeFloat: + return r.stmt.ColumnFloat(index) + case sqlite.TypeText: + return r.stmt.ColumnText(index) + case sqlite.TypeBlob: + // For blob columns, we need to handle this differently + // First, get the size by calling with nil buffer + size := r.stmt.ColumnBytes(index, nil) + if size == 0 { + return []byte{} + } + // Now allocate buffer and get the actual data + buf := make([]byte, size) + r.stmt.ColumnBytes(index, buf) + return buf + case sqlite.TypeNull: + return nil + default: + return r.stmt.ColumnText(index) + } +} + +func (r *SQLiteRows) Columns() ([]string, error) { + if r.err != nil { + return nil, r.err + } + + columns := make([]string, r.stmt.ColumnCount()) + for i := range columns { + columns[i] = r.stmt.ColumnName(i) + } + + return columns, nil +} + +func (r *SQLiteRows) Close() error { + if r.stmt != nil { + return r.stmt.Finalize() + } + return nil +} + +func (r *SQLiteRows) Err() error { + return r.err +} + +// SQLiteRow implements the Row interface +type SQLiteRow struct { + rows *SQLiteRows + err error +} + +func (r *SQLiteRow) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + + if r.rows == nil { + return fmt.Errorf("sqlite: no rows available") + } + + if !r.rows.Next() { + if r.rows.Err() != nil { + return r.rows.Err() + } + return fmt.Errorf("sqlite: no rows in result set") + } + + return r.rows.Scan(dest...) +} + +// SQLiteResult implements the Result interface +type SQLiteResult struct { + lastInsertID int64 + rowsAffected int +} + +func (r *SQLiteResult) LastInsertId() (int64, error) { + return r.lastInsertID, nil +} + +func (r *SQLiteResult) RowsAffected() (int64, error) { + return int64(r.rowsAffected), nil +} + +// SQLiteStatement implements the Statement interface +type SQLiteStatement struct { + stmt *sqlite.Stmt + conn *sqlite.Conn +} + +func (s *SQLiteStatement) Close() error { + return s.stmt.Finalize() +} + +func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) { + conn := &SQLiteConnection{conn: s.conn} + if err := conn.bindArgs(s.stmt, args...); err != nil { + return nil, err + } + + return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil +} + +func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row { + rows, err := s.Query(ctx, args...) + if err != nil { + return &SQLiteRow{err: err} + } + return &SQLiteRow{rows: rows.(*SQLiteRows)} +} + +func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) { + conn := &SQLiteConnection{conn: s.conn} + if err := conn.bindArgs(s.stmt, args...); err != nil { + return nil, err + } + + hasRow, err := s.stmt.Step() + if err != nil { + return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err) + } + + // Consume all rows if any + for hasRow { + hasRow, err = s.stmt.Step() + if err != nil { + return nil, fmt.Errorf("sqlite: error stepping through results: %w", err) + } + } + + return &SQLiteResult{ + lastInsertID: s.conn.LastInsertRowID(), + rowsAffected: s.conn.Changes(), + }, nil +} diff --git a/modules/sqlite/sqlite.lua b/modules/sqlite/sqlite.lua new file mode 100644 index 0000000..0ab2bbb --- /dev/null +++ b/modules/sqlite/sqlite.lua @@ -0,0 +1,655 @@ +local str = require("string") +local sqlite = {} + +local Connection = {} +Connection.__index = Connection + +function Connection:close() + if self._id then + local ok = moonshark.sql_close(self._id) + self._id = nil + return ok + end + return false +end + +function Connection:ping() + if not self._id then + error("Connection is closed") + end + return moonshark.sql_ping(self._id) +end + +function Connection:query(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_query(self._id, query_str, ...) +end + +function Connection:exec(query_str, ...) + if not self._id then + error("Connection is closed") + end + query_str = str.normalize_whitespace(query_str) + return moonshark.sql_exec(self._id, query_str, ...) +end + +function Connection:query_row(query_str, ...) + local results = self:query(query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function Connection:query_value(query_str, ...) + local row = self:query_row(query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced transaction support +function Connection:begin() + local result = self:exec("BEGIN") + if result then + return { + conn = self, + active = true, + + commit = function(tx) + if tx.active then + local result = tx.conn:exec("COMMIT") + tx.active = false + return result + end + return false + end, + + rollback = function(tx) + if tx.active then + local result = tx.conn:exec("ROLLBACK") + tx.active = false + return result + end + return false + end, + + query = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query(query_str, ...) + end, + + exec = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:exec(query_str, ...) + end, + + query_row = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_row(query_str, ...) + end, + + query_value = function(tx, query_str, ...) + if not tx.active then + error("Transaction is not active") + end + return tx.conn:query_value(query_str, ...) + end + } + end + return nil +end + +-- Enhanced query builder helpers with string utilities +function Connection:insert(table_name, data) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys = {} + local values = {} + local placeholders = {} + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + table.insert(placeholders, "?") + end + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", ") + }) + + return self:exec(query, unpack(values)) +end + +function Connection:upsert(table_name, data, conflict_columns) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local keys = {} + local values = {} + local placeholders = {} + local updates = {} + + for key, value in pairs(data) do + table.insert(keys, key) + table.insert(values, value) + table.insert(placeholders, "?") + table.insert(updates, str.template("${key} = excluded.${key}", {key = key})) + end + + local conflict_clause = "" + if conflict_columns then + if type(conflict_columns) == "string" then + conflict_clause = str.template("(${columns})", {columns = conflict_columns}) + else + conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")}) + end + end + + local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", { + table = table_name, + columns = str.join(keys, ", "), + placeholders = str.join(placeholders, ", "), + conflict = conflict_clause, + updates = str.join(updates, ", ") + }) + + return self:exec(query, unpack(values)) +end + +function Connection:update(table_name, data, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for UPDATE") + end + + local sets = {} + local values = {} + + for key, value in pairs(data) do + table.insert(sets, str.template("${key} = ?", {key = key})) + table.insert(values, value) + end + + local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { + table = table_name, + sets = str.join(sets, ", "), + where = where_clause + }) + + -- Add WHERE clause parameters + local where_args = {...} + for i = 1, #where_args do + table.insert(values, where_args[i]) + end + + return self:exec(query, unpack(values)) +end + +function Connection:delete(table_name, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + if str.is_blank(where_clause) then + error("WHERE clause cannot be empty for DELETE") + end + + local query = str.template("DELETE FROM ${table} WHERE ${where}", { + table = table_name, + where = where_clause + }) + return self:exec(query, ...) +end + +function Connection:select(table_name, columns, where_clause, ...) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + columns = columns or "*" + if type(columns) == "table" then + columns = str.join(columns, ", ") + end + + local query + if where_clause and not str.is_blank(where_clause) then + query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", { + columns = columns, + table = table_name, + where = where_clause + }) + return self:query(query, ...) + else + query = str.template("SELECT ${columns} FROM ${table}", { + columns = columns, + table = table_name + }) + return self:query(query) + end +end + +-- Enhanced schema helpers with validation +function Connection:table_exists(table_name) + if str.is_blank(table_name) then + return false + end + + local result = self:query_value( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + str.trim(table_name) + ) + return result ~= nil +end + +function Connection:column_exists(table_name, column_name) + if str.is_blank(table_name) or str.is_blank(column_name) then + return false + end + + local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name})) + if result then + for _, row in ipairs(result) do + if str.iequals(row.name, str.trim(column_name)) then + return true + end + end + end + return false +end + +function Connection:create_table(table_name, schema) + if str.is_blank(table_name) or str.is_blank(schema) then + error("Table name and schema cannot be empty") + end + + local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", { + table = table_name, + schema = str.trim(schema) + }) + return self:exec(query) +end + +function Connection:drop_table(table_name) + if str.is_blank(table_name) then + error("Table name cannot be empty") + end + + local query = str.template("DROP TABLE IF EXISTS ${table}", {table = table_name}) + return self:exec(query) +end + +function Connection:add_column(table_name, column_def) + if str.is_blank(table_name) or str.is_blank(column_def) then + error("Table name and column definition cannot be empty") + end + + local query = str.template("ALTER TABLE ${table} ADD COLUMN ${column}", { + table = table_name, + column = str.trim(column_def) + }) + return self:exec(query) +end + +function Connection:create_index(index_name, table_name, columns, unique) + if str.is_blank(index_name) or str.is_blank(table_name) then + error("Index name and table name cannot be empty") + end + + local unique_clause = unique and "UNIQUE " or "" + local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) + + local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", { + unique = unique_clause, + index = index_name, + table = table_name, + columns = columns_str + }) + return self:exec(query) +end + +function Connection:drop_index(index_name) + if str.is_blank(index_name) then + error("Index name cannot be empty") + end + + local query = str.template("DROP INDEX IF EXISTS ${index}", {index = index_name}) + return self:exec(query) +end + +-- Enhanced SQLite-specific functions +function Connection:vacuum() + return self:exec("VACUUM") +end + +function Connection:analyze() + return self:exec("ANALYZE") +end + +function Connection:integrity_check() + return self:query("PRAGMA integrity_check") +end + +function Connection:foreign_keys(enabled) + local value = enabled and "ON" or "OFF" + return self:exec(str.template("PRAGMA foreign_keys = ${value}", {value = value})) +end + +function Connection:journal_mode(mode) + mode = mode or "WAL" + if not str.contains(str.upper(mode), "DELETE") and + not str.contains(str.upper(mode), "TRUNCATE") and + not str.contains(str.upper(mode), "PERSIST") and + not str.contains(str.upper(mode), "MEMORY") and + not str.contains(str.upper(mode), "WAL") and + not str.contains(str.upper(mode), "OFF") then + error("Invalid journal mode: " .. mode) + end + return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)})) +end + +function Connection:synchronous(level) + level = level or "NORMAL" + local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"} + local level_upper = str.upper(level) + + local valid = false + for _, valid_level in ipairs(valid_levels) do + if level_upper == valid_level then + valid = true + break + end + end + + if not valid then + error("Invalid synchronous level: " .. level) + end + + return self:exec(str.template("PRAGMA synchronous = ${level}", {level = level_upper})) +end + +function Connection:cache_size(size) + size = size or -64000 + if type(size) ~= "number" then + error("Cache size must be a number") + end + return self:exec(str.template("PRAGMA cache_size = ${size}", {size = tostring(size)})) +end + +function Connection:temp_store(mode) + mode = mode or "MEMORY" + local valid_modes = {"DEFAULT", "FILE", "MEMORY"} + local mode_upper = str.upper(mode) + + local valid = false + for _, valid_mode in ipairs(valid_modes) do + if mode_upper == valid_mode then + valid = true + break + end + end + + if not valid then + error("Invalid temp_store mode: " .. mode) + end + + return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = mode_upper})) +end + +-- Connection management with enhanced path handling +function sqlite.open(database_path) + database_path = database_path or ":memory:" + + -- Clean up path + if database_path ~= ":memory:" then + database_path = str.trim(database_path) + if str.is_blank(database_path) then + database_path = ":memory:" + end + end + + local conn_id = moonshark.sql_connect("sqlite", database_path) + if conn_id then + local conn = {_id = conn_id} + setmetatable(conn, Connection) + return conn + end + return nil +end + +sqlite.connect = sqlite.open + +-- Enhanced quick execution functions +function sqlite.query(database_path, query_str, ...) + local conn = sqlite.open(database_path) + if not conn then + error(str.template("Failed to open SQLite database: ${path}", { + path = database_path or ":memory:" + })) + end + + local results = conn:query(query_str, ...) + conn:close() + return results +end + +function sqlite.exec(database_path, query_str, ...) + local conn = sqlite.open(database_path) + if not conn then + error(str.template("Failed to open SQLite database: ${path}", { + path = database_path or ":memory:" + })) + end + + local result = conn:exec(query_str, ...) + conn:close() + return result +end + +function sqlite.query_row(database_path, query_str, ...) + local results = sqlite.query(database_path, query_str, ...) + if results and #results > 0 then + return results[1] + end + return nil +end + +function sqlite.query_value(database_path, query_str, ...) + local row = sqlite.query_row(database_path, query_str, ...) + if row then + for _, value in pairs(row) do + return value + end + end + return nil +end + +-- Enhanced migration helpers +function sqlite.migrate(database_path, migrations) + local conn = sqlite.open(database_path) + if not conn then + error("Failed to open SQLite database for migration") + end + + -- Create migrations table + conn:create_table("_migrations", + "id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP") + + local tx = conn:begin() + if not tx then + conn:close() + error("Failed to begin migration transaction") + end + + local success = true + local error_msg = "" + + for _, migration in ipairs(migrations) do + if not migration.name or str.is_blank(migration.name) then + error_msg = "Migration must have a non-empty name" + success = false + break + end + + -- Check if migration already applied + local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", + str.trim(migration.name)) + if not existing then + local ok, err = pcall(function() + if type(migration.up) == "string" then + conn:exec(migration.up) + elseif type(migration.up) == "function" then + migration.up(conn) + else + error("Migration 'up' must be string or function") + end + end) + + if ok then + conn:exec("INSERT INTO _migrations (name) VALUES (?)", str.trim(migration.name)) + print(str.template("Applied migration: ${name}", {name = migration.name})) + else + success = false + error_msg = str.template("Migration '${name}' failed: ${error}", { + name = migration.name, + error = err or "unknown error" + }) + break + end + end + end + + if success then + tx:commit() + else + tx:rollback() + conn:close() + error(error_msg) + end + + conn:close() + return true +end + +-- Enhanced result processing utilities +function sqlite.to_array(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local array = {} + for i, row in ipairs(results) do + array[i] = row[column_name] + end + return array +end + +function sqlite.to_map(results, key_column, value_column) + if not results or #results == 0 then + return {} + end + + if str.is_blank(key_column) then + error("Key column name cannot be empty") + end + + local map = {} + for _, row in ipairs(results) do + local key = row[key_column] + map[key] = value_column and row[value_column] or row + end + return map +end + +function sqlite.group_by(results, column_name) + if not results or #results == 0 then + return {} + end + + if str.is_blank(column_name) then + error("Column name cannot be empty") + end + + local groups = {} + for _, row in ipairs(results) do + local key = row[column_name] + if not groups[key] then + groups[key] = {} + end + table.insert(groups[key], row) + end + return groups +end + +-- Enhanced debug helper +function sqlite.print_results(results) + if not results or #results == 0 then + print("No results") + return + end + + -- Get column names from first row + local columns = {} + for col, _ in pairs(results[1]) do + table.insert(columns, col) + end + table.sort(columns) + + -- Calculate column widths for better formatting + local widths = {} + for _, col in ipairs(columns) do + widths[col] = str.length(col) + end + + for _, row in ipairs(results) do + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + widths[col] = math.max(widths[col], str.length(value)) + end + end + + -- Print header with proper spacing + local header_parts = {} + local separator_parts = {} + for _, col in ipairs(columns) do + table.insert(header_parts, str.pad_right(col, widths[col])) + table.insert(separator_parts, str.repeat_("-", widths[col])) + end + + print(str.join(header_parts, " | ")) + print(str.join(separator_parts, "-+-")) + + -- Print rows with proper spacing + for _, row in ipairs(results) do + local value_parts = {} + for _, col in ipairs(columns) do + local value = tostring(row[col] or "") + table.insert(value_parts, str.pad_right(value, widths[col])) + end + print(str.join(value_parts, " | ")) + end +end + +return sqlite