initial sql database support - sqlite, postgres, mysql
This commit is contained in:
parent
09646394a5
commit
cf203d7899
1
.gitignore
vendored
1
.gitignore
vendored
@ -27,3 +27,4 @@ go.work
|
||||
test_fs_dir
|
||||
public
|
||||
test
|
||||
test.db
|
||||
|
23
go.mod
23
go.mod
@ -6,6 +6,29 @@ require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6
|
||||
|
||||
require github.com/google/uuid v1.6.0
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
modernc.org/libc v1.65.7 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.37.1 // indirect
|
||||
zombiezen.com/go/sqlite v1.4.2 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
|
46
go.sum
46
go.sum
@ -1,18 +1,64 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
||||
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og=
|
||||
github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/libc v1.65.7 h1:Ia9Z4yzZtWNtUIuiPuQ7Qf7kxYrxP1/jeHZzG8bFu00=
|
||||
modernc.org/libc v1.65.7/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
|
||||
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
|
||||
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
|
||||
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=
|
||||
|
@ -263,7 +263,6 @@ func handleRequest(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Lua 404, try static handlers if not already checked
|
||||
if !isLikelyStaticFile(path) && tryStaticHandler(ctx, path) {
|
||||
return
|
||||
}
|
||||
|
1003
modules/mysql/mysql.lua
Normal file
1003
modules/mysql/mysql.lua
Normal file
File diff suppressed because it is too large
Load Diff
847
modules/postgres/postgres.lua
Normal file
847
modules/postgres/postgres.lua
Normal file
@ -0,0 +1,847 @@
|
||||
local str = require("string")
|
||||
local postgres = {}
|
||||
|
||||
local Connection = {}
|
||||
Connection.__index = Connection
|
||||
|
||||
function Connection:close()
|
||||
if self._id then
|
||||
local ok = moonshark.sql_close(self._id)
|
||||
self._id = nil
|
||||
return ok
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:ping()
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_ping(self._id)
|
||||
end
|
||||
|
||||
function Connection:query(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
query_str = str.normalize_whitespace(query_str)
|
||||
return moonshark.sql_query(self._id, query_str, ...)
|
||||
end
|
||||
|
||||
function Connection:exec(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
query_str = str.normalize_whitespace(query_str)
|
||||
return moonshark.sql_exec(self._id, query_str, ...)
|
||||
end
|
||||
|
||||
function Connection:query_row(query_str, ...)
|
||||
local results = self:query(query_str, ...)
|
||||
if results and #results > 0 then
|
||||
return results[1]
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:query_value(query_str, ...)
|
||||
local row = self:query_row(query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced transaction support with savepoints
|
||||
function Connection:begin()
|
||||
local result = self:exec("BEGIN")
|
||||
if result then
|
||||
return {
|
||||
conn = self,
|
||||
active = true,
|
||||
|
||||
commit = function(tx)
|
||||
if tx.active then
|
||||
local result = tx.conn:exec("COMMIT")
|
||||
tx.active = false
|
||||
return result
|
||||
end
|
||||
return false
|
||||
end,
|
||||
|
||||
rollback = function(tx)
|
||||
if tx.active then
|
||||
local result = tx.conn:exec("ROLLBACK")
|
||||
tx.active = false
|
||||
return result
|
||||
end
|
||||
return false
|
||||
end,
|
||||
|
||||
savepoint = function(tx, name)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
if str.is_blank(name) then
|
||||
error("Savepoint name cannot be empty")
|
||||
end
|
||||
return tx.conn:exec(str.template("SAVEPOINT ${name}", {name = name}))
|
||||
end,
|
||||
|
||||
rollback_to = function(tx, name)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
if str.is_blank(name) then
|
||||
error("Savepoint name cannot be empty")
|
||||
end
|
||||
return tx.conn:exec(str.template("ROLLBACK TO SAVEPOINT ${name}", {name = name}))
|
||||
end,
|
||||
|
||||
query = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query(query_str, ...)
|
||||
end,
|
||||
|
||||
exec = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:exec(query_str, ...)
|
||||
end,
|
||||
|
||||
query_row = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query_row(query_str, ...)
|
||||
end,
|
||||
|
||||
query_value = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query_value(query_str, ...)
|
||||
end
|
||||
}
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced query builder helpers with PostgreSQL parameter numbering
|
||||
local function build_postgres_params(data)
|
||||
local keys = {}
|
||||
local values = {}
|
||||
local placeholders = {}
|
||||
local param_count = 0
|
||||
|
||||
for key, value in pairs(data) do
|
||||
table.insert(keys, key)
|
||||
table.insert(values, value)
|
||||
param_count = param_count + 1
|
||||
table.insert(placeholders, str.template("$${num}", {num = tostring(param_count)}))
|
||||
end
|
||||
|
||||
return keys, values, placeholders, param_count
|
||||
end
|
||||
|
||||
function Connection:insert(table_name, data, returning)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys, values, placeholders = build_postgres_params(data)
|
||||
|
||||
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||
table = table_name,
|
||||
columns = str.join(keys, ", "),
|
||||
placeholders = str.join(placeholders, ", ")
|
||||
})
|
||||
|
||||
if returning and not str.is_blank(returning) then
|
||||
query = str.template("${query} RETURNING ${returning}", {
|
||||
query = query,
|
||||
returning = returning
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:upsert(table_name, data, conflict_columns, returning)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys, values, placeholders = build_postgres_params(data)
|
||||
local updates = {}
|
||||
|
||||
for _, key in ipairs(keys) do
|
||||
table.insert(updates, str.template("${key} = EXCLUDED.${key}", {key = key}))
|
||||
end
|
||||
|
||||
local conflict_clause = ""
|
||||
if conflict_columns then
|
||||
if type(conflict_columns) == "string" then
|
||||
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
|
||||
else
|
||||
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")})
|
||||
end
|
||||
end
|
||||
|
||||
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
||||
table = table_name,
|
||||
columns = str.join(keys, ", "),
|
||||
placeholders = str.join(placeholders, ", "),
|
||||
conflict = conflict_clause,
|
||||
updates = str.join(updates, ", ")
|
||||
})
|
||||
|
||||
if returning and not str.is_blank(returning) then
|
||||
query = str.template("${query} RETURNING ${returning}", {
|
||||
query = query,
|
||||
returning = returning
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:update(table_name, data, where_clause, returning, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if str.is_blank(where_clause) then
|
||||
error("WHERE clause cannot be empty for UPDATE")
|
||||
end
|
||||
|
||||
local sets = {}
|
||||
local values = {}
|
||||
local param_count = 0
|
||||
|
||||
for key, value in pairs(data) do
|
||||
param_count = param_count + 1
|
||||
table.insert(sets, str.template("${key} = $${num}", {
|
||||
key = key,
|
||||
num = tostring(param_count)
|
||||
}))
|
||||
table.insert(values, value)
|
||||
end
|
||||
|
||||
-- Handle WHERE clause parameters
|
||||
local where_args = {...}
|
||||
local where_clause_with_params = where_clause
|
||||
for i = 1, #where_args do
|
||||
param_count = param_count + 1
|
||||
table.insert(values, where_args[i])
|
||||
-- Replace ? with numbered parameter if needed
|
||||
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||
str.template("$${num}", {num = tostring(param_count)}), 1)
|
||||
end
|
||||
|
||||
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
||||
table = table_name,
|
||||
sets = str.join(sets, ", "),
|
||||
where = where_clause_with_params
|
||||
})
|
||||
|
||||
if returning and not str.is_blank(returning) then
|
||||
query = str.template("${query} RETURNING ${returning}", {
|
||||
query = query,
|
||||
returning = returning
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:delete(table_name, where_clause, returning, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if str.is_blank(where_clause) then
|
||||
error("WHERE clause cannot be empty for DELETE")
|
||||
end
|
||||
|
||||
-- Handle WHERE clause parameters
|
||||
local where_args = {...}
|
||||
local values = {}
|
||||
local where_clause_with_params = where_clause
|
||||
for i = 1, #where_args do
|
||||
table.insert(values, where_args[i])
|
||||
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||
str.template("$${num}", {num = tostring(i)}), 1)
|
||||
end
|
||||
|
||||
local query = str.template("DELETE FROM ${table} WHERE ${where}", {
|
||||
table = table_name,
|
||||
where = where_clause_with_params
|
||||
})
|
||||
|
||||
if returning and not str.is_blank(returning) then
|
||||
query = str.template("${query} RETURNING ${returning}", {
|
||||
query = query,
|
||||
returning = returning
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:select(table_name, columns, where_clause, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
columns = columns or "*"
|
||||
if type(columns) == "table" then
|
||||
columns = str.join(columns, ", ")
|
||||
end
|
||||
|
||||
local query
|
||||
if where_clause and not str.is_blank(where_clause) then
|
||||
-- Handle WHERE clause parameters
|
||||
local where_args = {...}
|
||||
local values = {}
|
||||
local where_clause_with_params = where_clause
|
||||
for i = 1, #where_args do
|
||||
table.insert(values, where_args[i])
|
||||
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||
str.template("$${num}", {num = tostring(i)}), 1)
|
||||
end
|
||||
|
||||
query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
||||
columns = columns,
|
||||
table = table_name,
|
||||
where = where_clause_with_params
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
query = str.template("SELECT ${columns} FROM ${table}", {
|
||||
columns = columns,
|
||||
table = table_name
|
||||
})
|
||||
return self:query(query)
|
||||
end
|
||||
end
|
||||
|
||||
-- Enhanced PostgreSQL schema helpers
|
||||
function Connection:table_exists(table_name, schema_name)
|
||||
if str.is_blank(table_name) then
|
||||
return false
|
||||
end
|
||||
|
||||
schema_name = schema_name or "public"
|
||||
local result = self:query_value(
|
||||
"SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2",
|
||||
str.trim(schema_name), str.trim(table_name)
|
||||
)
|
||||
return result ~= nil
|
||||
end
|
||||
|
||||
function Connection:column_exists(table_name, column_name, schema_name)
|
||||
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||
return false
|
||||
end
|
||||
|
||||
schema_name = schema_name or "public"
|
||||
local result = self:query_value([[
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
|
||||
]], str.trim(schema_name), str.trim(table_name), str.trim(column_name))
|
||||
return result ~= nil
|
||||
end
|
||||
|
||||
function Connection:create_table(table_name, schema)
|
||||
if str.is_blank(table_name) or str.is_blank(schema) then
|
||||
error("Table name and schema cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
||||
table = table_name,
|
||||
schema = str.trim(schema)
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_table(table_name, cascade)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
local query = str.template("DROP TABLE IF EXISTS ${table}${cascade}", {
|
||||
table = table_name,
|
||||
cascade = cascade_clause
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:add_column(table_name, column_def)
|
||||
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${column}", {
|
||||
table = table_name,
|
||||
column = str.trim(column_def)
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_column(table_name, column_name, cascade)
|
||||
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||
error("Table name and column name cannot be empty")
|
||||
end
|
||||
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
local query = str.template("ALTER TABLE ${table} DROP COLUMN IF EXISTS ${column}${cascade}", {
|
||||
table = table_name,
|
||||
column = column_name,
|
||||
cascade = cascade_clause
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:create_index(index_name, table_name, columns, unique, method)
|
||||
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
|
||||
local unique_clause = unique and "UNIQUE " or ""
|
||||
local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or ""
|
||||
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns)
|
||||
|
||||
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", {
|
||||
unique = unique_clause,
|
||||
index = index_name,
|
||||
table = table_name,
|
||||
method = method_clause,
|
||||
columns = columns_str
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_index(index_name, cascade)
|
||||
if str.is_blank(index_name) then
|
||||
error("Index name cannot be empty")
|
||||
end
|
||||
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
local query = str.template("DROP INDEX IF EXISTS ${index}${cascade}", {
|
||||
index = index_name,
|
||||
cascade = cascade_clause
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
-- Enhanced PostgreSQL-specific functions
|
||||
function Connection:vacuum(table_name, analyze)
|
||||
local analyze_clause = analyze and " ANALYZE" or ""
|
||||
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
|
||||
return self:exec(str.template("VACUUM${analyze}${table}", {
|
||||
analyze = analyze_clause,
|
||||
table = table_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:analyze(table_name)
|
||||
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
|
||||
return self:exec(str.template("ANALYZE${table}", {table = table_clause}))
|
||||
end
|
||||
|
||||
function Connection:reindex(name, type)
|
||||
if str.is_blank(name) then
|
||||
error("Name cannot be empty for REINDEX")
|
||||
end
|
||||
|
||||
type = type or "INDEX"
|
||||
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
|
||||
local type_upper = str.upper(type)
|
||||
|
||||
local valid = false
|
||||
for _, valid_type in ipairs(valid_types) do
|
||||
if type_upper == valid_type then
|
||||
valid = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not valid then
|
||||
error(str.template("Invalid REINDEX type: ${type}", {type = type}))
|
||||
end
|
||||
|
||||
return self:exec(str.template("REINDEX ${type} ${name}", {
|
||||
type = type_upper,
|
||||
name = name
|
||||
}))
|
||||
end
|
||||
|
||||
-- PostgreSQL settings and introspection
|
||||
function Connection:show(setting)
|
||||
if str.is_blank(setting) then
|
||||
error("Setting name cannot be empty")
|
||||
end
|
||||
return self:query_value(str.template("SHOW ${setting}", {setting = setting}))
|
||||
end
|
||||
|
||||
function Connection:set(setting, value)
|
||||
if str.is_blank(setting) then
|
||||
error("Setting name cannot be empty")
|
||||
end
|
||||
return self:exec(str.template("SET ${setting} = ${value}", {
|
||||
setting = setting,
|
||||
value = tostring(value)
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:current_database()
|
||||
return self:query_value("SELECT current_database()")
|
||||
end
|
||||
|
||||
function Connection:current_schema()
|
||||
return self:query_value("SELECT current_schema()")
|
||||
end
|
||||
|
||||
function Connection:version()
|
||||
return self:query_value("SELECT version()")
|
||||
end
|
||||
|
||||
function Connection:list_schemas()
|
||||
return self:query("SELECT schema_name FROM information_schema.schemata ORDER BY schema_name")
|
||||
end
|
||||
|
||||
function Connection:list_tables(schema_name)
|
||||
schema_name = schema_name or "public"
|
||||
return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
|
||||
str.trim(schema_name))
|
||||
end
|
||||
|
||||
function Connection:describe_table(table_name, schema_name)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
schema_name = schema_name or "public"
|
||||
return self:query([[
|
||||
SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
ORDER BY ordinal_position
|
||||
]], str.trim(schema_name), str.trim(table_name))
|
||||
end
|
||||
|
||||
-- Enhanced JSON/JSONB helpers
|
||||
function Connection:json_extract(column, path)
|
||||
if str.is_blank(column) or str.is_blank(path) then
|
||||
error("Column and path cannot be empty")
|
||||
end
|
||||
return str.template("${column}->'${path}'", {column = column, path = path})
|
||||
end
|
||||
|
||||
function Connection:json_extract_text(column, path)
|
||||
if str.is_blank(column) or str.is_blank(path) then
|
||||
error("Column and path cannot be empty")
|
||||
end
|
||||
return str.template("${column}->>'${path}'", {column = column, path = path})
|
||||
end
|
||||
|
||||
function Connection:jsonb_contains(column, value)
|
||||
if str.is_blank(column) or str.is_blank(value) then
|
||||
error("Column and value cannot be empty")
|
||||
end
|
||||
return str.template("${column} @> '${value}'", {column = column, value = value})
|
||||
end
|
||||
|
||||
function Connection:jsonb_contained_by(column, value)
|
||||
if str.is_blank(column) or str.is_blank(value) then
|
||||
error("Column and value cannot be empty")
|
||||
end
|
||||
return str.template("${column} <@ '${value}'", {column = column, value = value})
|
||||
end
|
||||
|
||||
-- Enhanced Array helpers
|
||||
function Connection:array_contains(column, value)
|
||||
if str.is_blank(column) then
|
||||
error("Column cannot be empty")
|
||||
end
|
||||
return str.template("$1 = ANY(${column})", {column = column})
|
||||
end
|
||||
|
||||
function Connection:array_length(column)
|
||||
if str.is_blank(column) then
|
||||
error("Column cannot be empty")
|
||||
end
|
||||
return str.template("array_length(${column}, 1)", {column = column})
|
||||
end
|
||||
|
||||
-- Enhanced connection management with DSN parsing
|
||||
function postgres.parse_dsn(dsn)
|
||||
if str.is_blank(dsn) then
|
||||
return nil, "DSN cannot be empty"
|
||||
end
|
||||
|
||||
local parts = {}
|
||||
|
||||
-- Split by spaces and handle key=value pairs
|
||||
for pair in str.trim(dsn):gmatch("[^%s]+") do
|
||||
local key, value = pair:match("([^=]+)=(.+)")
|
||||
if key and value then
|
||||
parts[str.trim(key)] = str.trim(value)
|
||||
end
|
||||
end
|
||||
|
||||
return parts
|
||||
end
|
||||
|
||||
function postgres.connect(dsn)
|
||||
if str.is_blank(dsn) then
|
||||
error("DSN cannot be empty")
|
||||
end
|
||||
|
||||
local conn_id = moonshark.sql_connect("postgres", str.trim(dsn))
|
||||
if conn_id then
|
||||
local conn = {_id = conn_id}
|
||||
setmetatable(conn, Connection)
|
||||
return conn
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
postgres.open = postgres.connect
|
||||
|
||||
-- Enhanced quick execution functions
|
||||
function postgres.query(dsn, query_str, ...)
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database")
|
||||
end
|
||||
|
||||
local results = conn:query(query_str, ...)
|
||||
conn:close()
|
||||
return results
|
||||
end
|
||||
|
||||
function postgres.exec(dsn, query_str, ...)
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database")
|
||||
end
|
||||
|
||||
local result = conn:exec(query_str, ...)
|
||||
conn:close()
|
||||
return result
|
||||
end
|
||||
|
||||
function postgres.query_row(dsn, query_str, ...)
|
||||
local results = postgres.query(dsn, query_str, ...)
|
||||
if results and #results > 0 then
|
||||
return results[1]
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function postgres.query_value(dsn, query_str, ...)
|
||||
local row = postgres.query_row(dsn, query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced migration helpers
|
||||
function postgres.migrate(dsn, migrations, schema)
|
||||
schema = schema or "public"
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database for migration")
|
||||
end
|
||||
|
||||
-- Create migrations table
|
||||
conn:create_table("_migrations",
|
||||
"id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()")
|
||||
|
||||
local tx = conn:begin()
|
||||
if not tx then
|
||||
conn:close()
|
||||
error("Failed to begin migration transaction")
|
||||
end
|
||||
|
||||
local success = true
|
||||
local error_msg = ""
|
||||
|
||||
for _, migration in ipairs(migrations) do
|
||||
if not migration.name or str.is_blank(migration.name) then
|
||||
error_msg = "Migration must have a non-empty name"
|
||||
success = false
|
||||
break
|
||||
end
|
||||
|
||||
-- Check if migration already applied
|
||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1",
|
||||
str.trim(migration.name))
|
||||
if not existing then
|
||||
local ok, err = pcall(function()
|
||||
if type(migration.up) == "string" then
|
||||
conn:exec(migration.up)
|
||||
elseif type(migration.up) == "function" then
|
||||
migration.up(conn)
|
||||
else
|
||||
error("Migration 'up' must be string or function")
|
||||
end
|
||||
end)
|
||||
|
||||
if ok then
|
||||
conn:exec("INSERT INTO _migrations (name) VALUES ($1)", str.trim(migration.name))
|
||||
print(str.template("Applied migration: ${name}", {name = migration.name}))
|
||||
else
|
||||
success = false
|
||||
error_msg = str.template("Migration '${name}' failed: ${error}", {
|
||||
name = migration.name,
|
||||
error = err or "unknown error"
|
||||
})
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if success then
|
||||
tx:commit()
|
||||
else
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error(error_msg)
|
||||
end
|
||||
|
||||
conn:close()
|
||||
return true
|
||||
end
|
||||
|
||||
-- Result processing utilities (same enhanced versions as SQLite)
|
||||
function postgres.to_array(results, column_name)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(column_name) then
|
||||
error("Column name cannot be empty")
|
||||
end
|
||||
|
||||
local array = {}
|
||||
for i, row in ipairs(results) do
|
||||
array[i] = row[column_name]
|
||||
end
|
||||
return array
|
||||
end
|
||||
|
||||
function postgres.to_map(results, key_column, value_column)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(key_column) then
|
||||
error("Key column name cannot be empty")
|
||||
end
|
||||
|
||||
local map = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[key_column]
|
||||
map[key] = value_column and row[value_column] or row
|
||||
end
|
||||
return map
|
||||
end
|
||||
|
||||
function postgres.group_by(results, column_name)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(column_name) then
|
||||
error("Column name cannot be empty")
|
||||
end
|
||||
|
||||
local groups = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[column_name]
|
||||
if not groups[key] then
|
||||
groups[key] = {}
|
||||
end
|
||||
table.insert(groups[key], row)
|
||||
end
|
||||
return groups
|
||||
end
|
||||
|
||||
-- Enhanced debug helper (same as SQLite)
|
||||
function postgres.print_results(results)
|
||||
if not results or #results == 0 then
|
||||
print("No results")
|
||||
return
|
||||
end
|
||||
|
||||
-- Get column names from first row
|
||||
local columns = {}
|
||||
for col, _ in pairs(results[1]) do
|
||||
table.insert(columns, col)
|
||||
end
|
||||
table.sort(columns)
|
||||
|
||||
-- Calculate column widths for better formatting
|
||||
local widths = {}
|
||||
for _, col in ipairs(columns) do
|
||||
widths[col] = str.length(col)
|
||||
end
|
||||
|
||||
for _, row in ipairs(results) do
|
||||
for _, col in ipairs(columns) do
|
||||
local value = tostring(row[col] or "")
|
||||
widths[col] = math.max(widths[col], str.length(value))
|
||||
end
|
||||
end
|
||||
|
||||
-- Print header with proper spacing
|
||||
local header_parts = {}
|
||||
local separator_parts = {}
|
||||
for _, col in ipairs(columns) do
|
||||
table.insert(header_parts, str.pad_right(col, widths[col]))
|
||||
table.insert(separator_parts, str.repeat_("-", widths[col]))
|
||||
end
|
||||
|
||||
print(str.join(header_parts, " | "))
|
||||
print(str.join(separator_parts, "-+-"))
|
||||
|
||||
-- Print rows with proper spacing
|
||||
for _, row in ipairs(results) do
|
||||
local value_parts = {}
|
||||
for _, col in ipairs(columns) do
|
||||
local value = tostring(row[col] or "")
|
||||
table.insert(value_parts, str.pad_right(value, widths[col]))
|
||||
end
|
||||
print(str.join(value_parts, " | "))
|
||||
end
|
||||
end
|
||||
|
||||
function postgres.escape_identifier(name)
|
||||
if str.is_blank(name) then
|
||||
error("Identifier name cannot be empty")
|
||||
end
|
||||
return str.template('"${name}"', {name = str.replace(name, '"', '""')})
|
||||
end
|
||||
|
||||
function postgres.escape_literal(value)
|
||||
if type(value) == "string" then
|
||||
return str.template("'${value}'", {value = str.replace(value, "'", "''")})
|
||||
end
|
||||
return tostring(value)
|
||||
end
|
||||
|
||||
return postgres
|
@ -9,6 +9,7 @@ import (
|
||||
"Moonshark/modules/fs"
|
||||
"Moonshark/modules/http"
|
||||
"Moonshark/modules/math"
|
||||
"Moonshark/modules/sql"
|
||||
lua_string "Moonshark/modules/string"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
@ -17,7 +18,7 @@ import (
|
||||
// Global registry instance
|
||||
var Global *Registry
|
||||
|
||||
//go:embed crypto/*.lua fs/*.lua json/*.lua math/*.lua string/*.lua http/*.lua
|
||||
//go:embed **/*.lua
|
||||
var embeddedModules embed.FS
|
||||
|
||||
// Registry manages all Lua modules and Go functions
|
||||
@ -39,6 +40,7 @@ func New() *Registry {
|
||||
maps.Copy(r.goFuncs, crypto.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, fs.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, http.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, sql.GetFunctionList())
|
||||
|
||||
r.loadEmbeddedModules()
|
||||
return r
|
||||
|
205
modules/sql/mysql.go
Normal file
205
modules/sql/mysql.go
Normal file
@ -0,0 +1,205 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// MySQLDriver implements the Driver interface for MySQL
|
||||
type MySQLDriver struct{}
|
||||
|
||||
func (d *MySQLDriver) Name() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (d *MySQLDriver) Open(dsn string) (Connection, error) {
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Test the connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("mysql: failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &MySQLConnection{db: db}, nil
|
||||
}
|
||||
|
||||
// MySQLConnection implements the Connection interface
|
||||
type MySQLConnection struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Ping(ctx context.Context) error {
|
||||
return c.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &MySQLTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := c.db.QueryRowContext(ctx, query, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
result, err := c.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := c.db.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to prepare statement: %w", err)
|
||||
}
|
||||
return &MySQLStatement{stmt: stmt}, nil
|
||||
}
|
||||
|
||||
// MySQLTransaction implements the Transaction interface
|
||||
type MySQLTransaction struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Rollback() error {
|
||||
return t.tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := t.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: transaction query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := t.tx.QueryRowContext(ctx, query, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
result, err := t.tx.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: transaction exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := t.tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to prepare transaction statement: %w", err)
|
||||
}
|
||||
return &MySQLStatement{stmt: stmt}, nil
|
||||
}
|
||||
|
||||
// MySQLRows implements the Rows interface
|
||||
type MySQLRows struct {
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Next() bool {
|
||||
return r.rows.Next()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Scan(dest ...any) error {
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Columns() ([]string, error) {
|
||||
return r.rows.Columns()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Close() error {
|
||||
return r.rows.Close()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Err() error {
|
||||
return r.rows.Err()
|
||||
}
|
||||
|
||||
// MySQLRow implements the Row interface
|
||||
type MySQLRow struct {
|
||||
row *sql.Row
|
||||
}
|
||||
|
||||
func (r *MySQLRow) Scan(dest ...any) error {
|
||||
return r.row.Scan(dest...)
|
||||
}
|
||||
|
||||
// MySQLResult implements the Result interface
|
||||
type MySQLResult struct {
|
||||
result sql.Result
|
||||
}
|
||||
|
||||
func (r *MySQLResult) LastInsertId() (int64, error) {
|
||||
return r.result.LastInsertId()
|
||||
}
|
||||
|
||||
func (r *MySQLResult) RowsAffected() (int64, error) {
|
||||
return r.result.RowsAffected()
|
||||
}
|
||||
|
||||
// MySQLStatement implements the Statement interface
|
||||
type MySQLStatement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
rows, err := s.stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: statement query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
row := s.stmt.QueryRowContext(ctx, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
result, err := s.stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: statement exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register MySQL driver on import
|
||||
RegisterDriver("mysql", &MySQLDriver{})
|
||||
}
|
234
modules/sql/postgres.go
Normal file
234
modules/sql/postgres.go
Normal file
@ -0,0 +1,234 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// PostgresDriver implements the Driver interface for PostgreSQL
|
||||
type PostgresDriver struct{}
|
||||
|
||||
func (d *PostgresDriver) Name() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) Open(dsn string) (Connection, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to parse config: %w", err)
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to create pool: %w", err)
|
||||
}
|
||||
|
||||
return &PostgresConnection{pool: pool}, nil
|
||||
}
|
||||
|
||||
// PostgresConnection implements the Connection interface
|
||||
type PostgresConnection struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Close() error {
|
||||
c.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Ping(ctx context.Context) error {
|
||||
return c.pool.Ping(ctx)
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
tx, err := c.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &PostgresTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := c.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := c.pool.QueryRow(ctx, query, args...)
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
tag, err := c.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
// pgx doesn't have explicit prepared statements like database/sql
|
||||
// We'll store the query and use it with the pool
|
||||
return &PostgresStatement{pool: c.pool, query: query}, nil
|
||||
}
|
||||
|
||||
// PostgresTransaction implements the Transaction interface
|
||||
type PostgresTransaction struct {
|
||||
tx pgx.Tx
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Commit() error {
|
||||
return t.tx.Commit(context.Background())
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Rollback() error {
|
||||
return t.tx.Rollback(context.Background())
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := t.tx.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: transaction query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := t.tx.QueryRow(ctx, query, args...)
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
tag, err := t.tx.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: transaction exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
return &PostgresStatement{tx: t.tx, query: query}, nil
|
||||
}
|
||||
|
||||
// PostgresRows implements the Rows interface
|
||||
type PostgresRows struct {
|
||||
rows pgx.Rows
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Next() bool {
|
||||
return r.rows.Next()
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Scan(dest ...any) error {
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Columns() ([]string, error) {
|
||||
fields := r.rows.FieldDescriptions()
|
||||
columns := make([]string, len(fields))
|
||||
for i, field := range fields {
|
||||
columns[i] = field.Name
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Close() error {
|
||||
r.rows.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Err() error {
|
||||
return r.rows.Err()
|
||||
}
|
||||
|
||||
// PostgresRow implements the Row interface
|
||||
type PostgresRow struct {
|
||||
row pgx.Row
|
||||
}
|
||||
|
||||
func (r *PostgresRow) Scan(dest ...any) error {
|
||||
return r.row.Scan(dest...)
|
||||
}
|
||||
|
||||
// PostgresResult implements the Result interface
|
||||
type PostgresResult struct {
|
||||
tag pgconn.CommandTag
|
||||
}
|
||||
|
||||
func (r *PostgresResult) LastInsertId() (int64, error) {
|
||||
// PostgreSQL doesn't have AUTO_INCREMENT like MySQL
|
||||
// Users should use RETURNING clause or sequences
|
||||
return 0, fmt.Errorf("postgres: LastInsertId not supported, use RETURNING clause")
|
||||
}
|
||||
|
||||
func (r *PostgresResult) RowsAffected() (int64, error) {
|
||||
return r.tag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// PostgresStatement implements the Statement interface
|
||||
type PostgresStatement struct {
|
||||
pool *pgxpool.Pool
|
||||
tx pgx.Tx
|
||||
query string
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Close() error {
|
||||
// pgx doesn't require explicit statement cleanup
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
var rows pgx.Rows
|
||||
var err error
|
||||
|
||||
if s.tx != nil {
|
||||
rows, err = s.tx.Query(ctx, s.query, args...)
|
||||
} else {
|
||||
rows, err = s.pool.Query(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: statement query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
var row pgx.Row
|
||||
|
||||
if s.tx != nil {
|
||||
row = s.tx.QueryRow(ctx, s.query, args...)
|
||||
} else {
|
||||
row = s.pool.QueryRow(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
var tag pgconn.CommandTag
|
||||
var err error
|
||||
|
||||
if s.tx != nil {
|
||||
tag, err = s.tx.Exec(ctx, s.query, args...)
|
||||
} else {
|
||||
tag, err = s.pool.Exec(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: statement exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register PostgreSQL driver on import
|
||||
RegisterDriver("postgres", &PostgresDriver{})
|
||||
}
|
367
modules/sql/sql.go
Normal file
367
modules/sql/sql.go
Normal file
@ -0,0 +1,367 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Driver interface for SQL database implementations
|
||||
type Driver interface {
|
||||
Open(dsn string) (Connection, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Connection represents a database connection
|
||||
type Connection interface {
|
||||
Close() error
|
||||
Ping(ctx context.Context) error
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...any) Row
|
||||
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
||||
Prepare(ctx context.Context, query string) (Statement, error)
|
||||
}
|
||||
|
||||
// Transaction represents a database transaction
|
||||
type Transaction interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...any) Row
|
||||
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
||||
Prepare(ctx context.Context, query string) (Statement, error)
|
||||
}
|
||||
|
||||
// Rows represents query result rows
|
||||
type Rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...any) error
|
||||
Columns() ([]string, error)
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
// Row represents a single query result row
|
||||
type Row interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// Result represents the result of an executed statement
|
||||
type Result interface {
|
||||
LastInsertId() (int64, error)
|
||||
RowsAffected() (int64, error)
|
||||
}
|
||||
|
||||
// Statement represents a prepared statement
|
||||
type Statement interface {
|
||||
Close() error
|
||||
Query(ctx context.Context, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, args ...any) Row
|
||||
Exec(ctx context.Context, args ...any) (Result, error)
|
||||
}
|
||||
|
||||
// Registry manages database drivers and connections
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
drivers map[string]Driver
|
||||
conns map[string]Connection
|
||||
nextID int
|
||||
}
|
||||
|
||||
var global = &Registry{
|
||||
drivers: make(map[string]Driver),
|
||||
conns: make(map[string]Connection),
|
||||
}
|
||||
|
||||
// RegisterDriver registers a database driver
|
||||
func RegisterDriver(name string, driver Driver) {
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
global.drivers[name] = driver
|
||||
}
|
||||
|
||||
// GetDriver returns a registered driver
|
||||
func GetDriver(name string) (Driver, bool) {
|
||||
global.mu.RLock()
|
||||
defer global.mu.RUnlock()
|
||||
driver, exists := global.drivers[name]
|
||||
return driver, exists
|
||||
}
|
||||
|
||||
// Connect opens a new database connection
|
||||
func Connect(driverName, dsn string) (string, error) {
|
||||
driver, exists := GetDriver(driverName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("unknown driver: %s", driverName)
|
||||
}
|
||||
|
||||
conn, err := driver.Open(dsn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
|
||||
id := fmt.Sprintf("%s_%d", driverName, global.nextID)
|
||||
global.nextID++
|
||||
global.conns[id] = conn
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetConnection retrieves a connection by ID
|
||||
func GetConnection(id string) (Connection, bool) {
|
||||
global.mu.RLock()
|
||||
defer global.mu.RUnlock()
|
||||
conn, exists := global.conns[id]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// CloseConnection closes and removes a connection
|
||||
func CloseConnection(id string) error {
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
|
||||
conn, exists := global.conns[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection not found: %s", id)
|
||||
}
|
||||
|
||||
err := conn.Close()
|
||||
delete(global.conns, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// Lua function implementations
|
||||
|
||||
func luaConnect(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(2); err != nil {
|
||||
return s.PushError("connect: %v", err)
|
||||
}
|
||||
|
||||
driver, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("connect: driver must be a string")
|
||||
}
|
||||
|
||||
dsn, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("connect: dsn must be a string")
|
||||
}
|
||||
|
||||
connID, err := Connect(driver, dsn)
|
||||
if err != nil {
|
||||
return s.PushError("connect: %v", err)
|
||||
}
|
||||
|
||||
s.PushString(connID)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaClose(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(1); err != nil {
|
||||
return s.PushError("close: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("close: connection id must be a string")
|
||||
}
|
||||
|
||||
if err := CloseConnection(connID); err != nil {
|
||||
return s.PushError("close: %v", err)
|
||||
}
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaPing(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(1); err != nil {
|
||||
return s.PushError("ping: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("ping: connection id must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("ping: connection not found")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return s.PushError("ping: %v", err)
|
||||
}
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaQuery(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(2); err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("query: connection id must be a string")
|
||||
}
|
||||
|
||||
query, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("query: query must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("query: connection not found")
|
||||
}
|
||||
|
||||
// Collect arguments
|
||||
args := make([]any, s.GetTop()-2)
|
||||
for i := 3; i <= s.GetTop(); i++ {
|
||||
val, err := s.ToValue(i)
|
||||
if err != nil {
|
||||
args[i-3] = nil
|
||||
} else {
|
||||
args[i-3] = val
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := conn.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return s.PushError("query: failed to get columns: %v", err)
|
||||
}
|
||||
|
||||
// Build result array
|
||||
s.CreateTable(0, 0)
|
||||
rowIndex := 1
|
||||
|
||||
for rows.Next() {
|
||||
// Create values slice for scanning
|
||||
values := make([]any, len(columns))
|
||||
valuePtrs := make([]any, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return s.PushError("query: scan error: %v", err)
|
||||
}
|
||||
|
||||
// Create row table
|
||||
s.CreateTable(0, len(columns))
|
||||
for i, col := range columns {
|
||||
s.PushString(col)
|
||||
if err := s.PushValue(values[i]); err != nil {
|
||||
s.PushNil()
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
|
||||
// Add to result array
|
||||
s.PushNumber(float64(rowIndex))
|
||||
s.PushCopy(-2)
|
||||
s.SetTable(-4)
|
||||
s.Pop(1) // Remove row table copy
|
||||
|
||||
rowIndex++
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaExec(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(2); err != nil {
|
||||
return s.PushError("exec: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("exec: connection id must be a string")
|
||||
}
|
||||
|
||||
query, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("exec: query must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("exec: connection not found")
|
||||
}
|
||||
|
||||
// Collect arguments
|
||||
args := make([]any, s.GetTop()-2)
|
||||
for i := 3; i <= s.GetTop(); i++ {
|
||||
val, err := s.ToValue(i)
|
||||
if err != nil {
|
||||
args[i-3] = nil
|
||||
} else {
|
||||
args[i-3] = val
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := conn.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return s.PushError("exec: %v", err)
|
||||
}
|
||||
|
||||
// Return result info
|
||||
s.CreateTable(0, 2)
|
||||
|
||||
lastID, _ := result.LastInsertId()
|
||||
s.PushString("last_insert_id")
|
||||
s.PushNumber(float64(lastID))
|
||||
s.SetTable(-3)
|
||||
|
||||
affected, _ := result.RowsAffected()
|
||||
s.PushString("rows_affected")
|
||||
s.PushNumber(float64(affected))
|
||||
s.SetTable(-3)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// GetFunctionList returns all Lua-callable functions
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"sql_connect": luaConnect,
|
||||
"sql_close": luaClose,
|
||||
"sql_ping": luaPing,
|
||||
"sql_query": luaQuery,
|
||||
"sql_exec": luaExec,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register SQLite driver on import
|
||||
RegisterDriver("sqlite", &SQLiteDriver{})
|
||||
}
|
384
modules/sql/sqlite.go
Normal file
384
modules/sql/sqlite.go
Normal file
@ -0,0 +1,384 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
// SQLiteDriver implements the Driver interface for SQLite
|
||||
type SQLiteDriver struct{}
|
||||
|
||||
func (d *SQLiteDriver) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (d *SQLiteDriver) Open(dsn string) (Connection, error) {
|
||||
conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to open database: %w", err)
|
||||
}
|
||||
|
||||
return &SQLiteConnection{conn: conn}, nil
|
||||
}
|
||||
|
||||
// SQLiteConnection implements the Connection interface
|
||||
type SQLiteConnection struct {
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Ping(ctx context.Context) error {
|
||||
return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil)
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &SQLiteTransaction{conn: c.conn}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err)
|
||||
}
|
||||
|
||||
if err := c.bindArgs(stmt, args...); err != nil {
|
||||
stmt.Finalize()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SQLiteRows{stmt: stmt, hasNext: true}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
rows, err := c.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return &SQLiteRow{err: err}
|
||||
}
|
||||
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
if err := c.bindArgs(stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
// Consume all rows if any
|
||||
for hasRow {
|
||||
hasRow, err = stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SQLiteResult{
|
||||
lastInsertID: c.conn.LastInsertRowID(),
|
||||
rowsAffected: c.conn.Changes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
||||
}
|
||||
return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error {
|
||||
for i, arg := range args {
|
||||
paramIndex := i + 1
|
||||
|
||||
if arg == nil {
|
||||
stmt.BindNull(paramIndex)
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
stmt.BindInt64(paramIndex, int64(v))
|
||||
case int64:
|
||||
stmt.BindInt64(paramIndex, v)
|
||||
case float64:
|
||||
stmt.BindFloat(paramIndex, v)
|
||||
case string:
|
||||
stmt.BindText(paramIndex, v)
|
||||
case bool:
|
||||
if v {
|
||||
stmt.BindInt64(paramIndex, 1)
|
||||
} else {
|
||||
stmt.BindInt64(paramIndex, 0)
|
||||
}
|
||||
case []byte:
|
||||
stmt.BindBytes(paramIndex, v)
|
||||
default:
|
||||
return fmt.Errorf("sqlite: unsupported parameter type: %T", arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLiteTransaction implements the Transaction interface
|
||||
type SQLiteTransaction struct {
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Commit() error {
|
||||
return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Rollback() error {
|
||||
return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Query(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.QueryRow(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Exec(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Prepare(ctx, query)
|
||||
}
|
||||
|
||||
// SQLiteRows implements the Rows interface
|
||||
type SQLiteRows struct {
|
||||
stmt *sqlite.Stmt
|
||||
hasNext bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Next() bool {
|
||||
if r.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !r.hasNext {
|
||||
return false
|
||||
}
|
||||
|
||||
var err error
|
||||
r.hasNext, err = r.stmt.Step()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
return r.hasNext
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Scan(dest ...any) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
|
||||
for i, d := range dest {
|
||||
if i >= r.stmt.ColumnCount() {
|
||||
break
|
||||
}
|
||||
|
||||
switch ptr := d.(type) {
|
||||
case *any:
|
||||
*ptr = r.getValue(i)
|
||||
case *string:
|
||||
*ptr = r.stmt.ColumnText(i)
|
||||
case *int:
|
||||
*ptr = int(r.stmt.ColumnInt64(i))
|
||||
case *int64:
|
||||
*ptr = r.stmt.ColumnInt64(i)
|
||||
case *float64:
|
||||
*ptr = r.stmt.ColumnFloat(i)
|
||||
case *bool:
|
||||
*ptr = r.stmt.ColumnInt64(i) != 0
|
||||
case *[]byte:
|
||||
if r.stmt.ColumnType(i) == sqlite.TypeBlob {
|
||||
// Get blob size first
|
||||
size := r.stmt.ColumnBytes(i, nil)
|
||||
if size == 0 {
|
||||
*ptr = []byte{}
|
||||
} else {
|
||||
buf := make([]byte, size)
|
||||
r.stmt.ColumnBytes(i, buf)
|
||||
*ptr = buf
|
||||
}
|
||||
} else {
|
||||
// Convert text to bytes
|
||||
*ptr = []byte(r.stmt.ColumnText(i))
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("sqlite: unsupported scan destination type: %T", d)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) getValue(index int) any {
|
||||
switch r.stmt.ColumnType(index) {
|
||||
case sqlite.TypeInteger:
|
||||
return r.stmt.ColumnInt64(index)
|
||||
case sqlite.TypeFloat:
|
||||
return r.stmt.ColumnFloat(index)
|
||||
case sqlite.TypeText:
|
||||
return r.stmt.ColumnText(index)
|
||||
case sqlite.TypeBlob:
|
||||
// For blob columns, we need to handle this differently
|
||||
// First, get the size by calling with nil buffer
|
||||
size := r.stmt.ColumnBytes(index, nil)
|
||||
if size == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
// Now allocate buffer and get the actual data
|
||||
buf := make([]byte, size)
|
||||
r.stmt.ColumnBytes(index, buf)
|
||||
return buf
|
||||
case sqlite.TypeNull:
|
||||
return nil
|
||||
default:
|
||||
return r.stmt.ColumnText(index)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Columns() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
|
||||
columns := make([]string, r.stmt.ColumnCount())
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Close() error {
|
||||
if r.stmt != nil {
|
||||
return r.stmt.Finalize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// SQLiteRow implements the Row interface
|
||||
type SQLiteRow struct {
|
||||
rows *SQLiteRows
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *SQLiteRow) Scan(dest ...any) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
|
||||
if r.rows == nil {
|
||||
return fmt.Errorf("sqlite: no rows available")
|
||||
}
|
||||
|
||||
if !r.rows.Next() {
|
||||
if r.rows.Err() != nil {
|
||||
return r.rows.Err()
|
||||
}
|
||||
return fmt.Errorf("sqlite: no rows in result set")
|
||||
}
|
||||
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
// SQLiteResult implements the Result interface
|
||||
type SQLiteResult struct {
|
||||
lastInsertID int64
|
||||
rowsAffected int
|
||||
}
|
||||
|
||||
func (r *SQLiteResult) LastInsertId() (int64, error) {
|
||||
return r.lastInsertID, nil
|
||||
}
|
||||
|
||||
func (r *SQLiteResult) RowsAffected() (int64, error) {
|
||||
return int64(r.rowsAffected), nil
|
||||
}
|
||||
|
||||
// SQLiteStatement implements the Statement interface
|
||||
type SQLiteStatement struct {
|
||||
stmt *sqlite.Stmt
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Close() error {
|
||||
return s.stmt.Finalize()
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
conn := &SQLiteConnection{conn: s.conn}
|
||||
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
rows, err := s.Query(ctx, args...)
|
||||
if err != nil {
|
||||
return &SQLiteRow{err: err}
|
||||
}
|
||||
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
conn := &SQLiteConnection{conn: s.conn}
|
||||
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasRow, err := s.stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
// Consume all rows if any
|
||||
for hasRow {
|
||||
hasRow, err = s.stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SQLiteResult{
|
||||
lastInsertID: s.conn.LastInsertRowID(),
|
||||
rowsAffected: s.conn.Changes(),
|
||||
}, nil
|
||||
}
|
655
modules/sqlite/sqlite.lua
Normal file
655
modules/sqlite/sqlite.lua
Normal file
@ -0,0 +1,655 @@
|
||||
local str = require("string")
|
||||
local sqlite = {}
|
||||
|
||||
local Connection = {}
|
||||
Connection.__index = Connection
|
||||
|
||||
function Connection:close()
|
||||
if self._id then
|
||||
local ok = moonshark.sql_close(self._id)
|
||||
self._id = nil
|
||||
return ok
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:ping()
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_ping(self._id)
|
||||
end
|
||||
|
||||
function Connection:query(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
query_str = str.normalize_whitespace(query_str)
|
||||
return moonshark.sql_query(self._id, query_str, ...)
|
||||
end
|
||||
|
||||
function Connection:exec(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
query_str = str.normalize_whitespace(query_str)
|
||||
return moonshark.sql_exec(self._id, query_str, ...)
|
||||
end
|
||||
|
||||
function Connection:query_row(query_str, ...)
|
||||
local results = self:query(query_str, ...)
|
||||
if results and #results > 0 then
|
||||
return results[1]
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:query_value(query_str, ...)
|
||||
local row = self:query_row(query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced transaction support
|
||||
function Connection:begin()
|
||||
local result = self:exec("BEGIN")
|
||||
if result then
|
||||
return {
|
||||
conn = self,
|
||||
active = true,
|
||||
|
||||
commit = function(tx)
|
||||
if tx.active then
|
||||
local result = tx.conn:exec("COMMIT")
|
||||
tx.active = false
|
||||
return result
|
||||
end
|
||||
return false
|
||||
end,
|
||||
|
||||
rollback = function(tx)
|
||||
if tx.active then
|
||||
local result = tx.conn:exec("ROLLBACK")
|
||||
tx.active = false
|
||||
return result
|
||||
end
|
||||
return false
|
||||
end,
|
||||
|
||||
query = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query(query_str, ...)
|
||||
end,
|
||||
|
||||
exec = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:exec(query_str, ...)
|
||||
end,
|
||||
|
||||
query_row = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query_row(query_str, ...)
|
||||
end,
|
||||
|
||||
query_value = function(tx, query_str, ...)
|
||||
if not tx.active then
|
||||
error("Transaction is not active")
|
||||
end
|
||||
return tx.conn:query_value(query_str, ...)
|
||||
end
|
||||
}
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced query builder helpers with string utilities
|
||||
function Connection:insert(table_name, data)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = {}
|
||||
local values = {}
|
||||
local placeholders = {}
|
||||
|
||||
for key, value in pairs(data) do
|
||||
table.insert(keys, key)
|
||||
table.insert(values, value)
|
||||
table.insert(placeholders, "?")
|
||||
end
|
||||
|
||||
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||
table = table_name,
|
||||
columns = str.join(keys, ", "),
|
||||
placeholders = str.join(placeholders, ", ")
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:upsert(table_name, data, conflict_columns)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = {}
|
||||
local values = {}
|
||||
local placeholders = {}
|
||||
local updates = {}
|
||||
|
||||
for key, value in pairs(data) do
|
||||
table.insert(keys, key)
|
||||
table.insert(values, value)
|
||||
table.insert(placeholders, "?")
|
||||
table.insert(updates, str.template("${key} = excluded.${key}", {key = key}))
|
||||
end
|
||||
|
||||
local conflict_clause = ""
|
||||
if conflict_columns then
|
||||
if type(conflict_columns) == "string" then
|
||||
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
|
||||
else
|
||||
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")})
|
||||
end
|
||||
end
|
||||
|
||||
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
||||
table = table_name,
|
||||
columns = str.join(keys, ", "),
|
||||
placeholders = str.join(placeholders, ", "),
|
||||
conflict = conflict_clause,
|
||||
updates = str.join(updates, ", ")
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:update(table_name, data, where_clause, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if str.is_blank(where_clause) then
|
||||
error("WHERE clause cannot be empty for UPDATE")
|
||||
end
|
||||
|
||||
local sets = {}
|
||||
local values = {}
|
||||
|
||||
for key, value in pairs(data) do
|
||||
table.insert(sets, str.template("${key} = ?", {key = key}))
|
||||
table.insert(values, value)
|
||||
end
|
||||
|
||||
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
||||
table = table_name,
|
||||
sets = str.join(sets, ", "),
|
||||
where = where_clause
|
||||
})
|
||||
|
||||
-- Add WHERE clause parameters
|
||||
local where_args = {...}
|
||||
for i = 1, #where_args do
|
||||
table.insert(values, where_args[i])
|
||||
end
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:delete(table_name, where_clause, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if str.is_blank(where_clause) then
|
||||
error("WHERE clause cannot be empty for DELETE")
|
||||
end
|
||||
|
||||
local query = str.template("DELETE FROM ${table} WHERE ${where}", {
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:exec(query, ...)
|
||||
end
|
||||
|
||||
function Connection:select(table_name, columns, where_clause, ...)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
columns = columns or "*"
|
||||
if type(columns) == "table" then
|
||||
columns = str.join(columns, ", ")
|
||||
end
|
||||
|
||||
local query
|
||||
if where_clause and not str.is_blank(where_clause) then
|
||||
query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
||||
columns = columns,
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:query(query, ...)
|
||||
else
|
||||
query = str.template("SELECT ${columns} FROM ${table}", {
|
||||
columns = columns,
|
||||
table = table_name
|
||||
})
|
||||
return self:query(query)
|
||||
end
|
||||
end
|
||||
|
||||
-- Enhanced schema helpers with validation
|
||||
function Connection:table_exists(table_name)
|
||||
if str.is_blank(table_name) then
|
||||
return false
|
||||
end
|
||||
|
||||
local result = self:query_value(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
str.trim(table_name)
|
||||
)
|
||||
return result ~= nil
|
||||
end
|
||||
|
||||
function Connection:column_exists(table_name, column_name)
|
||||
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||
return false
|
||||
end
|
||||
|
||||
local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name}))
|
||||
if result then
|
||||
for _, row in ipairs(result) do
|
||||
if str.iequals(row.name, str.trim(column_name)) then
|
||||
return true
|
||||
end
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:create_table(table_name, schema)
|
||||
if str.is_blank(table_name) or str.is_blank(schema) then
|
||||
error("Table name and schema cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
||||
table = table_name,
|
||||
schema = str.trim(schema)
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_table(table_name)
|
||||
if str.is_blank(table_name) then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("DROP TABLE IF EXISTS ${table}", {table = table_name})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:add_column(table_name, column_def)
|
||||
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("ALTER TABLE ${table} ADD COLUMN ${column}", {
|
||||
table = table_name,
|
||||
column = str.trim(column_def)
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:create_index(index_name, table_name, columns, unique)
|
||||
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
|
||||
local unique_clause = unique and "UNIQUE " or ""
|
||||
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns)
|
||||
|
||||
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", {
|
||||
unique = unique_clause,
|
||||
index = index_name,
|
||||
table = table_name,
|
||||
columns = columns_str
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_index(index_name)
|
||||
if str.is_blank(index_name) then
|
||||
error("Index name cannot be empty")
|
||||
end
|
||||
|
||||
local query = str.template("DROP INDEX IF EXISTS ${index}", {index = index_name})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
-- Enhanced SQLite-specific functions
|
||||
function Connection:vacuum()
|
||||
return self:exec("VACUUM")
|
||||
end
|
||||
|
||||
function Connection:analyze()
|
||||
return self:exec("ANALYZE")
|
||||
end
|
||||
|
||||
function Connection:integrity_check()
|
||||
return self:query("PRAGMA integrity_check")
|
||||
end
|
||||
|
||||
function Connection:foreign_keys(enabled)
|
||||
local value = enabled and "ON" or "OFF"
|
||||
return self:exec(str.template("PRAGMA foreign_keys = ${value}", {value = value}))
|
||||
end
|
||||
|
||||
function Connection:journal_mode(mode)
|
||||
mode = mode or "WAL"
|
||||
if not str.contains(str.upper(mode), "DELETE") and
|
||||
not str.contains(str.upper(mode), "TRUNCATE") and
|
||||
not str.contains(str.upper(mode), "PERSIST") and
|
||||
not str.contains(str.upper(mode), "MEMORY") and
|
||||
not str.contains(str.upper(mode), "WAL") and
|
||||
not str.contains(str.upper(mode), "OFF") then
|
||||
error("Invalid journal mode: " .. mode)
|
||||
end
|
||||
return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)}))
|
||||
end
|
||||
|
||||
function Connection:synchronous(level)
|
||||
level = level or "NORMAL"
|
||||
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
|
||||
local level_upper = str.upper(level)
|
||||
|
||||
local valid = false
|
||||
for _, valid_level in ipairs(valid_levels) do
|
||||
if level_upper == valid_level then
|
||||
valid = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not valid then
|
||||
error("Invalid synchronous level: " .. level)
|
||||
end
|
||||
|
||||
return self:exec(str.template("PRAGMA synchronous = ${level}", {level = level_upper}))
|
||||
end
|
||||
|
||||
function Connection:cache_size(size)
|
||||
size = size or -64000
|
||||
if type(size) ~= "number" then
|
||||
error("Cache size must be a number")
|
||||
end
|
||||
return self:exec(str.template("PRAGMA cache_size = ${size}", {size = tostring(size)}))
|
||||
end
|
||||
|
||||
function Connection:temp_store(mode)
|
||||
mode = mode or "MEMORY"
|
||||
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
|
||||
local mode_upper = str.upper(mode)
|
||||
|
||||
local valid = false
|
||||
for _, valid_mode in ipairs(valid_modes) do
|
||||
if mode_upper == valid_mode then
|
||||
valid = true
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if not valid then
|
||||
error("Invalid temp_store mode: " .. mode)
|
||||
end
|
||||
|
||||
return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = mode_upper}))
|
||||
end
|
||||
|
||||
-- Connection management with enhanced path handling
|
||||
function sqlite.open(database_path)
|
||||
database_path = database_path or ":memory:"
|
||||
|
||||
-- Clean up path
|
||||
if database_path ~= ":memory:" then
|
||||
database_path = str.trim(database_path)
|
||||
if str.is_blank(database_path) then
|
||||
database_path = ":memory:"
|
||||
end
|
||||
end
|
||||
|
||||
local conn_id = moonshark.sql_connect("sqlite", database_path)
|
||||
if conn_id then
|
||||
local conn = {_id = conn_id}
|
||||
setmetatable(conn, Connection)
|
||||
return conn
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
sqlite.connect = sqlite.open
|
||||
|
||||
-- Enhanced quick execution functions
|
||||
function sqlite.query(database_path, query_str, ...)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error(str.template("Failed to open SQLite database: ${path}", {
|
||||
path = database_path or ":memory:"
|
||||
}))
|
||||
end
|
||||
|
||||
local results = conn:query(query_str, ...)
|
||||
conn:close()
|
||||
return results
|
||||
end
|
||||
|
||||
function sqlite.exec(database_path, query_str, ...)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error(str.template("Failed to open SQLite database: ${path}", {
|
||||
path = database_path or ":memory:"
|
||||
}))
|
||||
end
|
||||
|
||||
local result = conn:exec(query_str, ...)
|
||||
conn:close()
|
||||
return result
|
||||
end
|
||||
|
||||
function sqlite.query_row(database_path, query_str, ...)
|
||||
local results = sqlite.query(database_path, query_str, ...)
|
||||
if results and #results > 0 then
|
||||
return results[1]
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function sqlite.query_value(database_path, query_str, ...)
|
||||
local row = sqlite.query_row(database_path, query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Enhanced migration helpers
|
||||
function sqlite.migrate(database_path, migrations)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error("Failed to open SQLite database for migration")
|
||||
end
|
||||
|
||||
-- Create migrations table
|
||||
conn:create_table("_migrations",
|
||||
"id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP")
|
||||
|
||||
local tx = conn:begin()
|
||||
if not tx then
|
||||
conn:close()
|
||||
error("Failed to begin migration transaction")
|
||||
end
|
||||
|
||||
local success = true
|
||||
local error_msg = ""
|
||||
|
||||
for _, migration in ipairs(migrations) do
|
||||
if not migration.name or str.is_blank(migration.name) then
|
||||
error_msg = "Migration must have a non-empty name"
|
||||
success = false
|
||||
break
|
||||
end
|
||||
|
||||
-- Check if migration already applied
|
||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
|
||||
str.trim(migration.name))
|
||||
if not existing then
|
||||
local ok, err = pcall(function()
|
||||
if type(migration.up) == "string" then
|
||||
conn:exec(migration.up)
|
||||
elseif type(migration.up) == "function" then
|
||||
migration.up(conn)
|
||||
else
|
||||
error("Migration 'up' must be string or function")
|
||||
end
|
||||
end)
|
||||
|
||||
if ok then
|
||||
conn:exec("INSERT INTO _migrations (name) VALUES (?)", str.trim(migration.name))
|
||||
print(str.template("Applied migration: ${name}", {name = migration.name}))
|
||||
else
|
||||
success = false
|
||||
error_msg = str.template("Migration '${name}' failed: ${error}", {
|
||||
name = migration.name,
|
||||
error = err or "unknown error"
|
||||
})
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if success then
|
||||
tx:commit()
|
||||
else
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error(error_msg)
|
||||
end
|
||||
|
||||
conn:close()
|
||||
return true
|
||||
end
|
||||
|
||||
-- Enhanced result processing utilities
|
||||
function sqlite.to_array(results, column_name)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(column_name) then
|
||||
error("Column name cannot be empty")
|
||||
end
|
||||
|
||||
local array = {}
|
||||
for i, row in ipairs(results) do
|
||||
array[i] = row[column_name]
|
||||
end
|
||||
return array
|
||||
end
|
||||
|
||||
function sqlite.to_map(results, key_column, value_column)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(key_column) then
|
||||
error("Key column name cannot be empty")
|
||||
end
|
||||
|
||||
local map = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[key_column]
|
||||
map[key] = value_column and row[value_column] or row
|
||||
end
|
||||
return map
|
||||
end
|
||||
|
||||
function sqlite.group_by(results, column_name)
|
||||
if not results or #results == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
if str.is_blank(column_name) then
|
||||
error("Column name cannot be empty")
|
||||
end
|
||||
|
||||
local groups = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[column_name]
|
||||
if not groups[key] then
|
||||
groups[key] = {}
|
||||
end
|
||||
table.insert(groups[key], row)
|
||||
end
|
||||
return groups
|
||||
end
|
||||
|
||||
-- Enhanced debug helper
|
||||
function sqlite.print_results(results)
|
||||
if not results or #results == 0 then
|
||||
print("No results")
|
||||
return
|
||||
end
|
||||
|
||||
-- Get column names from first row
|
||||
local columns = {}
|
||||
for col, _ in pairs(results[1]) do
|
||||
table.insert(columns, col)
|
||||
end
|
||||
table.sort(columns)
|
||||
|
||||
-- Calculate column widths for better formatting
|
||||
local widths = {}
|
||||
for _, col in ipairs(columns) do
|
||||
widths[col] = str.length(col)
|
||||
end
|
||||
|
||||
for _, row in ipairs(results) do
|
||||
for _, col in ipairs(columns) do
|
||||
local value = tostring(row[col] or "")
|
||||
widths[col] = math.max(widths[col], str.length(value))
|
||||
end
|
||||
end
|
||||
|
||||
-- Print header with proper spacing
|
||||
local header_parts = {}
|
||||
local separator_parts = {}
|
||||
for _, col in ipairs(columns) do
|
||||
table.insert(header_parts, str.pad_right(col, widths[col]))
|
||||
table.insert(separator_parts, str.repeat_("-", widths[col]))
|
||||
end
|
||||
|
||||
print(str.join(header_parts, " | "))
|
||||
print(str.join(separator_parts, "-+-"))
|
||||
|
||||
-- Print rows with proper spacing
|
||||
for _, row in ipairs(results) do
|
||||
local value_parts = {}
|
||||
for _, col in ipairs(columns) do
|
||||
local value = tostring(row[col] or "")
|
||||
table.insert(value_parts, str.pad_right(value, widths[col]))
|
||||
end
|
||||
print(str.join(value_parts, " | "))
|
||||
end
|
||||
end
|
||||
|
||||
return sqlite
|
Loading…
x
Reference in New Issue
Block a user