diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index 0adf5f4..dc30e89 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -27,20 +27,37 @@ local connection_mt = { error("connection:query: query must be a string", 2) end - -- Handle params (named or positional) - local processed_params = handle_params(params, ...) + -- Fast path for no parameters + if params == nil and select('#', ...) == 0 then + return __sqlite_query(self.db_name, query, nil, self.id) + end - -- Call with appropriate arguments - if type(processed_params) == "table" and processed_params[1] ~= nil then - -- Positional parameters - insert self.db_name and query at the beginning - table.insert(processed_params, 1, query) - table.insert(processed_params, 1, self.db_name) - -- Add connection ID at the end - table.insert(processed_params, self.id) - return __sqlite_query(unpack(processed_params)) + -- Handle various parameter types efficiently + if type(params) == "table" then + -- If it's an array-like table with numeric keys + if params[1] ~= nil then + -- For positional parameters, we want to include the required prefix args + local args = {self.db_name, query} + -- Append all parameters + for i=1, #params do + args[i+2] = params[i] + end + -- Add connection ID + args[#args+1] = self.id + return __sqlite_query(unpack(args)) + else + -- Named parameters + return __sqlite_query(self.db_name, query, params, self.id) + end else - -- Named parameters or no parameters - return __sqlite_query(self.db_name, query, processed_params, self.id) + -- Variadic parameters, combine with first param + local args = {self.db_name, query, params} + local n = select('#', ...) + for i=1, n do + args[i+3] = select(i, ...) + end + args[#args+1] = self.id + return __sqlite_query(unpack(args)) end end, @@ -50,20 +67,37 @@ local connection_mt = { error("connection:exec: query must be a string", 2) end - -- Handle params (named or positional) - local processed_params = handle_params(params, ...) + -- Fast path for no parameters + if params == nil and select('#', ...) == 0 then + return __sqlite_exec(self.db_name, query, nil, self.id) + end - -- Call with appropriate arguments - if type(processed_params) == "table" and processed_params[1] ~= nil then - -- Positional parameters - insert self.db_name and query at the beginning - table.insert(processed_params, 1, query) - table.insert(processed_params, 1, self.db_name) - -- Add connection ID at the end - table.insert(processed_params, self.id) - return __sqlite_exec(unpack(processed_params)) + -- Handle various parameter types efficiently + if type(params) == "table" then + -- If it's an array-like table with numeric keys + if params[1] ~= nil then + -- For positional parameters, we want to include the required prefix args + local args = {self.db_name, query} + -- Append all parameters + for i=1, #params do + args[i+2] = params[i] + end + -- Add connection ID + args[#args+1] = self.id + return __sqlite_exec(unpack(args)) + else + -- Named parameters + return __sqlite_exec(self.db_name, query, params, self.id) + end else - -- Named parameters or no parameters - return __sqlite_exec(self.db_name, query, processed_params, self.id) + -- Variadic parameters, combine with first param + local args = {self.db_name, query, params} + local n = select('#', ...) + for i=1, n do + args[i+3] = select(i, ...) + end + args[#args+1] = self.id + return __sqlite_exec(unpack(args)) end end, @@ -79,7 +113,7 @@ local connection_mt = { local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)") if index_def then - -- Parse index definition: INDEX:idx_name(col1,col2) + -- Parse index definition local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)") if index_name and columns_str then @@ -106,34 +140,32 @@ local connection_mt = { error("connection:create_table: no columns specified", 2) end - -- Create the table - local query = string.format("CREATE TABLE IF NOT EXISTS %s (%s)", - table_name, table.concat(columns, ", ")) + -- Build combined statement for table and indices + local statements = {} - local result = self:exec(query) + -- Add the CREATE TABLE statement + table.insert(statements, string.format( + "CREATE TABLE IF NOT EXISTS %s (%s)", + table_name, + table.concat(columns, ", ") + )) - -- Create indices - if #indices > 0 then - self:begin() + -- Add CREATE INDEX statements + for _, idx in ipairs(indices) do + local unique = idx.unique and "UNIQUE " or "" - for _, idx in ipairs(indices) do - local unique = idx.unique and "UNIQUE " or "" - - local index_query = string.format( - "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", - unique, - idx.name, - table_name, - table.concat(idx.columns, ", ") - ) - - self:exec(index_query) - end - - self:commit() + table.insert(statements, string.format( + "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", + unique, + idx.name, + table_name, + table.concat(idx.columns, ", ") + )) end - return result + -- Execute all statements in a single transaction + local combined_sql = table.concat(statements, ";\n") + return self:exec(combined_sql) end, -- Insert a row or multiple rows @@ -142,43 +174,58 @@ local connection_mt = { error("connection:insert: data must be a table", 2) end + -- Case 1: Named columns with array data if columns and type(columns) == "table" then - local placeholders = {} - for _ in ipairs(columns) do - table.insert(placeholders, "?") - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES (%s)", - table_name, - table.concat(columns, ", "), - table.concat(placeholders, ", ") - ) - - local use_transaction = #data > 1 and type(data[1]) == "table" - - if use_transaction then - self:begin() - end - - local affected = 0 - + -- Check if we have multiple rows if #data > 0 and type(data[1]) == "table" then - for _, row in ipairs(data) do - local result = self:exec(query, row) - affected = affected + result + -- Build a single multi-value INSERT + local placeholders = {} + local values = {} + local params = {} + local param_index = 1 + + for i, row in ipairs(data) do + local row_placeholders = {} + for j, _ in ipairs(columns) do + local param_name = "p" .. param_index + table.insert(row_placeholders, ":" .. param_name) + params[param_name] = row[j] + param_index = param_index + 1 + end + table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")") end + + local query = string.format( + "INSERT INTO %s (%s) VALUES %s", + table_name, + table.concat(columns, ", "), + table.concat(placeholders, ", ") + ) + + return self:exec(query, params) else - affected = self:exec(query, data) - end + -- Single row with defined columns + local placeholders = {} + local params = {} - if use_transaction then - self:commit() - end + for i, col in ipairs(columns) do + local param_name = "p" .. i + table.insert(placeholders, ":" .. param_name) + params[param_name] = data[i] + end - return affected + local query = string.format( + "INSERT INTO %s (%s) VALUES (%s)", + table_name, + table.concat(columns, ", "), + table.concat(placeholders, ", ") + ) + + return self:exec(query, params) + end end + -- Case 2: Object-style single row {col1=val1, col2=val2} if data[1] == nil and next(data) ~= nil then local columns = {} local placeholders = {} @@ -186,8 +233,9 @@ local connection_mt = { for col, val in pairs(data) do table.insert(columns, col) - table.insert(placeholders, ":" .. col) - params[":" .. col] = val + local param_name = "p" .. #columns + table.insert(placeholders, ":" .. param_name) + params[param_name] = val end local query = string.format( @@ -200,34 +248,74 @@ local connection_mt = { return self:exec(query, params) end + -- Case 3: Array of rows without predefined columns if #data > 0 and type(data[1]) == "table" then - self:begin() - local affected = 0 + -- Extract columns from the first row + local first_row = data[1] + local inferred_columns = {} - for _, row in ipairs(data) do - local result = self:insert(table_name, row) - affected = affected + result + -- Determine if first row is array or object + local is_array = first_row[1] ~= nil + + if is_array then + -- Cannot infer column names from array + error("connection:insert: column names required for array data", 2) + else + -- Get columns from object keys + for col, _ in pairs(first_row) do + table.insert(inferred_columns, col) + end + + -- Build multi-value INSERT + local placeholders = {} + local params = {} + local param_index = 1 + + for _, row in ipairs(data) do + local row_placeholders = {} + for _, col in ipairs(inferred_columns) do + local param_name = "p" .. param_index + table.insert(row_placeholders, ":" .. param_name) + params[param_name] = row[col] + param_index = param_index + 1 + end + table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")") + end + + local query = string.format( + "INSERT INTO %s (%s) VALUES %s", + table_name, + table.concat(inferred_columns, ", "), + table.concat(placeholders, ", ") + ) + + return self:exec(query, params) end - - self:commit() - return affected end error("connection:insert: invalid data format", 2) end, -- Update rows - update = function(self, table_name, data, where, where_params) + update = function(self, table_name, data, where, where_params, ...) if type(data) ~= "table" then error("connection:update: data must be a table", 2) end + -- Fast path for when there's no data + if next(data) == nil then + return 0 + end + local sets = {} local params = {} + local param_index = 1 for col, val in pairs(data) do - table.insert(sets, col .. " = :" .. col) - params[col] = val + local param_name = "p" .. param_index + table.insert(sets, col .. " = :" .. param_name) + params[param_name] = val + param_index = param_index + 1 end local query = string.format( @@ -240,8 +328,53 @@ local connection_mt = { query = query .. " WHERE " .. where if where_params then - for k, v in pairs(where_params) do - params[k] = v + if type(where_params) == "table" then + -- Handle named parameters in WHERE clause + for k, v in pairs(where_params) do + local param_name + if type(k) == "string" and k:sub(1, 1) == ":" then + param_name = k:sub(2) + else + param_name = "w" .. param_index + -- Replace the placeholder in the WHERE clause + where = where:gsub(":" .. k, ":" .. param_name) + end + params[param_name] = v + param_index = param_index + 1 + end + else + -- Handle positional parameters (? placeholders) + local args = {where_params, ...} + local pos = 1 + local offset = 0 + + -- Replace ? with named parameters + while true do + local start_pos, end_pos = where:find("?", pos) + if not start_pos then break end + + local param_name = "w" .. param_index + local replacement = ":" .. param_name + + where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1) + + if args[pos - offset] ~= nil then + params[param_name] = args[pos - offset] + else + params[param_name] = nil + end + + param_index = param_index + 1 + pos = start_pos + #replacement + offset = offset + 1 + end + + query = string.format( + "UPDATE %s SET %s WHERE %s", + table_name, + table.concat(sets, ", "), + where + ) end end end @@ -260,15 +393,25 @@ local connection_mt = { return self:exec(query, params) end, - -- Get one row + -- Get one row efficiently get_one = function(self, query, params, ...) - -- Handle both named and positional parameters + if type(query) ~= "string" then + error("connection:get_one: query must be a string", 2) + end + + -- Add LIMIT 1 to query if not already limited + local limited_query = query + if not query:lower():match("limit%s+%d+") then + limited_query = query .. " LIMIT 1" + end + local results if select('#', ...) > 0 then - results = self:query(query, params, ...) + results = self:query(limited_query, params, ...) else - results = self:query(query, params) + results = self:query(limited_query, params) end + return results[1] end, diff --git a/runner/sqlite.go b/runner/sqlite.go index a93db12..38a8f4c 100644 --- a/runner/sqlite.go +++ b/runner/sqlite.go @@ -13,16 +13,13 @@ import ( "Moonshark/utils/logger" - "maps" - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // SQLiteConnection tracks an active connection type SQLiteConnection struct { - DbName string - Conn *sqlite.Conn - Pool *sqlitex.Pool + Conn *sqlite.Conn + Pool *sqlitex.Pool } // SQLiteManager handles database connections @@ -54,7 +51,6 @@ func CleanupSQLite() { sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() - // Release all connections and close pools for id, conn := range sqliteManager.activeConns { if conn.Pool != nil { conn.Pool.Put(conn.Conn) @@ -79,9 +75,6 @@ func ReleaseActiveConnections(state *luajit.State) { return } - sqliteManager.mu.Lock() - defer sqliteManager.mu.Unlock() - // Get active connections table from Lua state.GetGlobal("__active_sqlite_connections") if !state.IsTable(-1) { @@ -89,6 +82,9 @@ func ReleaseActiveConnections(state *luajit.State) { return } + sqliteManager.mu.Lock() + defer sqliteManager.mu.Unlock() + // Iterate through active connections state.PushNil() // Start iteration for state.Next(-2) { @@ -113,8 +109,8 @@ func ReleaseActiveConnections(state *luajit.State) { state.SetGlobal("__active_sqlite_connections") } -// getPool returns a connection pool for the specified database -func getPool(dbName string) (*sqlitex.Pool, error) { +// getConnection returns a connection for the database +func getConnection(dbName, connID string) (*sqlite.Conn, error) { if sqliteManager == nil { return nil, errors.New("SQLite not initialized") } @@ -125,326 +121,325 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return nil, errors.New("invalid database name") } - // Check for existing pool with read lock + // Check for existing connection sqliteManager.mu.RLock() - pool, exists := sqliteManager.pools[dbName] - sqliteManager.mu.RUnlock() + conn, exists := sqliteManager.activeConns[connID] if exists { - return pool, nil + sqliteManager.mu.RUnlock() + return conn.Conn, nil } + sqliteManager.mu.RUnlock() - // Create new pool with write lock + // Get or create pool under write lock sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() - // Double check if another goroutine created it - if pool, exists = sqliteManager.pools[dbName]; exists { - return pool, nil + // Double-check if a connection was created while waiting for lock + if conn, exists = sqliteManager.activeConns[connID]; exists { + return conn.Conn, nil } - // Create database file path and pool - dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") - pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - sqliteManager.pools[dbName] = pool - return pool, nil -} - -// getConnection returns a connection from the pool -func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) { - // Check for existing connection first - sqliteManager.mu.RLock() - conn, exists := sqliteManager.activeConns[connID] - sqliteManager.mu.RUnlock() - - if exists { - return conn.Conn, conn.Pool, nil - } - - // Get the pool - pool, err := getPool(dbName) - if err != nil { - return nil, nil, err + // Get or create pool + pool, exists := sqliteManager.pools[dbName] + if !exists { + dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") + var err error + pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + sqliteManager.pools[dbName] = pool } // Get a connection dbConn, err := pool.Take(context.Background()) if err != nil { - return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) + return nil, fmt.Errorf("failed to get connection from pool: %w", err) } // Store connection - sqliteManager.mu.Lock() sqliteManager.activeConns[connID] = &SQLiteConnection{ - DbName: dbName, - Conn: dbConn, - Pool: pool, + Conn: dbConn, + Pool: pool, } - sqliteManager.mu.Unlock() - return dbConn, pool, nil + return dbConn, nil } -// processParams extracts parameters and connection ID from Lua state -func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) { - connID = defaultConnID - - // Check if using positional parameters - if state.GetTop() >= 3 && !state.IsTable(3) { - isPositional = true - paramCount := state.GetTop() - 2 // Count all args after db and query - - // Check if last param is a connection ID - lastIdx := paramCount + 2 // db(1) + query(2) + paramCount - if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { - connID = state.ToString(lastIdx) - paramCount-- // Exclude connID from param count - } - - // Create array for positional parameters - positionalParams = make([]any, paramCount) - - // Collect all parameters - for i := 0; i < paramCount; i++ { - paramIdx := i + 3 // Params start at index 3 - switch state.GetType(paramIdx) { - case luajit.TypeNumber: - positionalParams[i] = state.ToNumber(paramIdx) - case luajit.TypeString: - positionalParams[i] = state.ToString(paramIdx) - case luajit.TypeBoolean: - positionalParams[i] = state.ToBoolean(paramIdx) - case luajit.TypeNil: - positionalParams[i] = nil - default: - val, errConv := state.ToValue(paramIdx) - if errConv != nil { - return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv) - } - positionalParams[i] = val - } - } - return nil, connID, isPositional, positionalParams, nil +// releaseConnection returns a connection to its pool +func releaseConnection(connID string) { + if sqliteManager == nil { + return } - // Named parameter handling - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { - connID = state.ToString(4) + sqliteManager.mu.Lock() + defer sqliteManager.mu.Unlock() + + conn, exists := sqliteManager.activeConns[connID] + if !exists { + return } - // Get table parameters if present - if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { - params, err = state.ToTable(3) + if conn.Pool != nil { + conn.Pool.Put(conn.Conn) } - - return params, connID, isPositional, nil, err + delete(sqliteManager.activeConns, connID) } -// prepareExecOptions prepares SQLite execution options based on parameters -func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions { - execOpts := &sqlitex.ExecOptions{} - - if params == nil && !isPositional { - return execOpts - } - - // Prepare parameters - isArray := false - var namedParams map[string]any - var arrParams []any - - // Check for array parameters - if m, ok := params.(map[string]any); ok { - if arr, hasArray := m[""]; hasArray { - isArray = true - if slice, ok := arr.([]any); ok { - arrParams = slice - } else if floatSlice, ok := arr.([]float64); ok { - arrParams = make([]any, len(floatSlice)) - for i, v := range floatSlice { - arrParams[i] = v - } - } - } else { - // Process named parameters - namedParams = make(map[string]any, len(m)) - for k, v := range m { - if len(k) > 0 && k[0] != ':' { - namedParams[":"+k] = v - } else { - namedParams[k] = v - } - } - } - } else if slice, ok := params.([]any); ok { - isArray = true - arrParams = slice - } else if floatSlice, ok := params.([]float64); ok { - isArray = true - arrParams = make([]any, len(floatSlice)) - for i, v := range floatSlice { - arrParams[i] = v - } - } - - // Use positional params if explicitly provided - if isPositional { - arrParams = positionalParams - isArray = true - } - - // Limit positional params to actual placeholders - if isArray && arrParams != nil { - placeholderCount := strings.Count(query, "?") - if len(arrParams) > placeholderCount { - arrParams = arrParams[:placeholderCount] - } - execOpts.Args = arrParams - } else if namedParams != nil { - execOpts.Named = namedParams - } - - return execOpts -} - -// sqlOperation handles both query and exec operations -func sqlOperation(state *luajit.State, isQuery bool) int { - operation := "query" - if !isQuery { - operation = "exec" - } - - // Get database name - if !state.IsString(1) { - state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation)) +// sqlQuery executes a SQL query and returns results +func sqlQuery(state *luajit.State) int { + // Get required parameters + if state.GetTop() < 3 || !state.IsString(1) || !state.IsString(2) { + state.PushString("sqlite.query: requires database name, query, and optional parameters") return -1 } + dbName := state.ToString(1) - - // Get query - if !state.IsString(2) { - state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation)) - return -1 - } query := state.ToString(2) - - // Generate a temporary connection ID if needed - defaultConnID := fmt.Sprintf("temp_%p", &query) - - // Process parameters and get connection ID - params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) - return -1 - } + connID := fmt.Sprintf("temp_%p", &query) // Get connection - conn, pool, err := getConnection(dbName, connID) + conn, err := getConnection(dbName, connID) if err != nil { - state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) + state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } + // Create execution options + var execOpts sqlitex.ExecOptions + rows := make([]map[string]any, 0, 16) + // For temporary connections, defer release - if strings.HasPrefix(connID, "temp_") { - defer func() { - sqliteManager.mu.Lock() - delete(sqliteManager.activeConns, connID) - sqliteManager.mu.Unlock() - pool.Put(conn) - }() - } + defer releaseConnection(connID) - // Prepare execution options - execOpts := prepareExecOptions(query, params, isPositional, positionalParams) - - // Define rows slice outside the closure - var rows []map[string]any - - // For queries, add result function - if isQuery { - execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { - row := make(map[string]any) - columnCount := stmt.ColumnCount() - - for i := range columnCount { - columnName := stmt.ColumnName(i) - - switch stmt.ColumnType(i) { - case sqlite.TypeInteger: - row[columnName] = stmt.ColumnInt64(i) - case sqlite.TypeFloat: - row[columnName] = stmt.ColumnFloat(i) - case sqlite.TypeText: - row[columnName] = stmt.ColumnText(i) - case sqlite.TypeBlob: - blobSize := stmt.ColumnLen(i) - buf := make([]byte, blobSize) - row[columnName] = stmt.ColumnBytes(i, buf) - case sqlite.TypeNull: - row[columnName] = nil - } + // Set up parameters if provided + if state.GetTop() >= 3 { + if state.IsTable(3) { + params, err := state.ToTable(3) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error())) + return -1 } - // Add row copy to results - rowCopy := make(map[string]any, len(row)) - maps.Copy(rowCopy, row) - rows = append(rows, rowCopy) - return nil + // Check for array-style params + if arr, ok := params[""]; ok { + if arrParams, ok := arr.([]any); ok { + execOpts.Args = arrParams + } else if floatArr, ok := arr.([]float64); ok { + args := make([]any, len(floatArr)) + for i, v := range floatArr { + args[i] = v + } + execOpts.Args = args + } + } else { + // Named parameters + named := make(map[string]any, len(params)) + for k, v := range params { + if len(k) > 0 && k[0] != ':' { + named[":"+k] = v + } else { + named[k] = v + } + } + execOpts.Named = named + } + } else { + // Positional parameters + count := state.GetTop() - 2 + args := make([]any, count) + for i := 0; i < count; i++ { + idx := i + 3 + switch state.GetType(idx) { + case luajit.TypeNumber: + args[i] = state.ToNumber(idx) + case luajit.TypeString: + args[i] = state.ToString(idx) + case luajit.TypeBoolean: + args[i] = state.ToBoolean(idx) + case luajit.TypeNil: + args[i] = nil + default: + val, err := state.ToValue(idx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error())) + return -1 + } + args[i] = val + } + } + execOpts.Args = args } } + // Set up result function + execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { + row := make(map[string]any) + colCount := stmt.ColumnCount() + + for i := range colCount { + colName := stmt.ColumnName(i) + switch stmt.ColumnType(i) { + case sqlite.TypeInteger: + row[colName] = stmt.ColumnInt64(i) + case sqlite.TypeFloat: + row[colName] = stmt.ColumnFloat(i) + case sqlite.TypeText: + row[colName] = stmt.ColumnText(i) + case sqlite.TypeBlob: + blobSize := stmt.ColumnLen(i) + buf := make([]byte, blobSize) + row[colName] = stmt.ColumnBytes(i, buf) + case sqlite.TypeNull: + row[colName] = nil + } + } + rows = append(rows, row) // No need to copy, this row is used only once + return nil + } + // Execute query - var execErr error - if isQuery || execOpts.Args != nil || execOpts.Named != nil { - execErr = sqlitex.Execute(conn, query, execOpts) - } else { - // Use ExecScript for queries without parameters - execErr = sqlitex.ExecScript(conn, query) - } - - if execErr != nil { - state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error())) + if err := sqlitex.Execute(conn, query, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } - // Return results for query, affected rows for exec - if isQuery { - // Create result table with rows - state.NewTable() - for i, row := range rows { - state.PushNumber(float64(i + 1)) - if err := state.PushTable(row); err != nil { - state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) - return -1 - } - state.SetTable(-3) + // Create result table + state.NewTable() + for i, row := range rows { + state.PushNumber(float64(i + 1)) + if err := state.PushTable(row); err != nil { + state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) + return -1 } - } else { - // Return number of affected rows - state.PushNumber(float64(conn.Changes())) + state.SetTable(-3) } return 1 } -// luaSQLQuery executes a SQL query and returns results to Lua -func luaSQLQuery(state *luajit.State) int { - return sqlOperation(state, true) -} +// sqlExec executes a SQL statement without returning results +func sqlExec(state *luajit.State) int { + // Get required parameters + if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { + state.PushString("sqlite.exec: requires database name and query") + return -1 + } -// luaSQLExec executes a SQL statement without returning results -func luaSQLExec(state *luajit.State) int { - return sqlOperation(state, false) + dbName := state.ToString(1) + query := state.ToString(2) + connID := fmt.Sprintf("temp_%p", &query) + + // Get connection + conn, err := getConnection(dbName, connID) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + return -1 + } + + // For temporary connections, defer release + defer releaseConnection(connID) + + // Check if parameters are provided + hasParams := state.GetTop() >= 3 && !state.IsNil(3) + hasPlaceholders := strings.Contains(query, "?") || strings.Contains(query, ":") + + // Fast path for simple queries with no parameters + if !hasParams || !hasPlaceholders { + if err := sqlitex.ExecScript(conn, query); err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + return -1 + } + state.PushNumber(float64(conn.Changes())) + return 1 + } + + // Create execution options for parameterized query + var execOpts sqlitex.ExecOptions + + // Set up parameters + if state.IsTable(3) { + params, err := state.ToTable(3) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error())) + return -1 + } + + // Check for array-style params + if arr, ok := params[""]; ok { + if arrParams, ok := arr.([]any); ok { + execOpts.Args = arrParams + } else if floatArr, ok := arr.([]float64); ok { + args := make([]any, len(floatArr)) + for i, v := range floatArr { + args[i] = v + } + execOpts.Args = args + } + } else { + // Named parameters + named := make(map[string]any, len(params)) + for k, v := range params { + if len(k) > 0 && k[0] != ':' { + named[":"+k] = v + } else { + named[k] = v + } + } + execOpts.Named = named + } + } else { + // Positional parameters + count := state.GetTop() - 2 + args := make([]any, count) + for i := 0; i < count; i++ { + idx := i + 3 + switch state.GetType(idx) { + case luajit.TypeNumber: + args[i] = state.ToNumber(idx) + case luajit.TypeString: + args[i] = state.ToString(idx) + case luajit.TypeBoolean: + args[i] = state.ToBoolean(idx) + case luajit.TypeNil: + args[i] = nil + default: + val, err := state.ToValue(idx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error())) + return -1 + } + args[i] = val + } + } + execOpts.Args = args + } + + // Count the number of placeholders to validate parameter count + if execOpts.Args != nil { + placeholderCount := strings.Count(query, "?") + if len(execOpts.Args) > placeholderCount { + state.PushString(fmt.Sprintf("sqlite.exec: too many parameters provided (%d) for placeholders (%d)", + len(execOpts.Args), placeholderCount)) + return -1 + } + } + + // Execute with parameters + if err := sqlitex.Execute(conn, query, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + return -1 + } + + // Return affected rows + state.PushNumber(float64(conn.Changes())) + return 1 } // RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { + if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err } - return state.RegisterGoFunction("__sqlite_exec", luaSQLExec) + return state.RegisterGoFunction("__sqlite_exec", sqlExec) }