move database wrapper to its own library, add migration functionality
This commit is contained in:
parent
888fd70a6f
commit
38a39790cf
3
go.mod
3
go.mod
@ -3,9 +3,9 @@ module dk
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
git.sharkk.net/Sharkk/Sashimi v1.0.1
|
||||
git.sharkk.net/Sharkk/Sushi v1.2.0
|
||||
github.com/valyala/fasthttp v1.65.0
|
||||
zombiezen.com/go/sqlite v1.4.2
|
||||
)
|
||||
|
||||
require (
|
||||
@ -24,4 +24,5 @@ require (
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.37.1 // indirect
|
||||
zombiezen.com/go/sqlite v1.4.2 // indirect
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@ -1,3 +1,5 @@
|
||||
git.sharkk.net/Sharkk/Sashimi v1.0.1 h1:5YMmxnCgcsyasg5J91AS5FVzzJoDZ17I3J4hlJyyMR4=
|
||||
git.sharkk.net/Sharkk/Sashimi v1.0.1/go.mod h1:wTMnO6jo34LIjpDJ0qToq14RbwP6Uf4HtdWDmqxrdAM=
|
||||
git.sharkk.net/Sharkk/Sushi v1.2.0 h1:RwOCZmgaOqtkmuK2Z7/esdLbhSXJZphsOsWEHni4Sss=
|
||||
git.sharkk.net/Sharkk/Sushi v1.2.0/go.mod h1:S84ACGkuZ+BKzBO4lb5WQnm5aw9+l7VSO2T1bjzxL3o=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
|
@ -2,380 +2,63 @@ package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
sashimi "git.sharkk.net/Sharkk/Sashimi"
|
||||
)
|
||||
|
||||
var placeholderRegex = regexp.MustCompile(`%[sd]`)
|
||||
var db *sashimi.DB
|
||||
|
||||
var db *sqlite.Conn
|
||||
|
||||
// Init initializes the database connection with WAL mode and cache settings
|
||||
// Init initializes the database connection
|
||||
func Init(dbPath string) error {
|
||||
conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||
var err error
|
||||
db, err = sashimi.New(dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
// Enable WAL mode
|
||||
if err := sqlitex.Execute(conn, "PRAGMA journal_mode=WAL", nil); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to enable WAL mode: %w", err)
|
||||
}
|
||||
|
||||
// Set generous cache size (64MB)
|
||||
if err := sqlitex.Execute(conn, "PRAGMA cache_size=-65536", nil); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to set cache size: %w", err)
|
||||
}
|
||||
|
||||
// Enable foreign keys
|
||||
if err := sqlitex.Execute(conn, "PRAGMA foreign_keys=ON", nil); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
|
||||
db = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
// DB returns the global database connection
|
||||
func DB() *sqlite.Conn {
|
||||
func DB() *sashimi.DB {
|
||||
if db == nil {
|
||||
panic("database not initialized - call Init() first")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Scan scans a SQLite statement result into a struct using field names
|
||||
func Scan(stmt *sqlite.Stmt, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to struct")
|
||||
// Close closes the database connection
|
||||
func Close() error {
|
||||
if db != nil {
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
typ := elem.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
|
||||
fieldValue := elem.Field(i)
|
||||
if !fieldValue.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find column index by name
|
||||
colIndex := -1
|
||||
for j := 0; j < stmt.ColumnCount(); j++ {
|
||||
if stmt.ColumnName(j) == columnName {
|
||||
colIndex = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if colIndex == -1 {
|
||||
continue // Column not found
|
||||
}
|
||||
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetInt(stmt.ColumnInt64(colIndex))
|
||||
case reflect.String:
|
||||
fieldValue.SetString(stmt.ColumnText(colIndex))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fieldValue.SetFloat(stmt.ColumnFloat(colIndex))
|
||||
case reflect.Bool:
|
||||
fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query executes a query with fmt-style placeholders and automatically binds parameters
|
||||
func Query(query string, args ...any) (*sqlite.Stmt, error) {
|
||||
// Replace fmt placeholders with SQLite placeholders
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
stmt, err := DB().Prepare(convertedQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Bind parameters with correct types
|
||||
for i, arg := range args {
|
||||
if i >= len(paramTypes) {
|
||||
break
|
||||
}
|
||||
|
||||
switch paramTypes[i] {
|
||||
case "s": // string
|
||||
if s, ok := arg.(string); ok {
|
||||
stmt.BindText(i+1, s)
|
||||
} else {
|
||||
stmt.BindText(i+1, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
case "d": // integer
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
case int32:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
case int64:
|
||||
stmt.BindInt64(i+1, v)
|
||||
case float64:
|
||||
stmt.BindInt64(i+1, int64(v))
|
||||
default:
|
||||
if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil {
|
||||
stmt.BindInt64(i+1, i64)
|
||||
} else {
|
||||
stmt.BindInt64(i+1, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
// Wrapper functions for convenience
|
||||
func Query(query string, args ...any) (*sashimi.Stmt, error) {
|
||||
return db.Query(query, args...)
|
||||
}
|
||||
|
||||
// Get executes a query and returns the first row
|
||||
func Get(dest any, query string, args ...any) error {
|
||||
stmt, err := Query(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasRow {
|
||||
return fmt.Errorf("no rows found")
|
||||
}
|
||||
|
||||
return Scan(stmt, dest)
|
||||
return db.Get(dest, query, args...)
|
||||
}
|
||||
|
||||
// Select executes a query and scans all rows into a slice
|
||||
func Select(dest any, query string, args ...any) error {
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() != reflect.Ptr || destValue.Elem().Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be a pointer to slice")
|
||||
}
|
||||
|
||||
sliceValue := destValue.Elem()
|
||||
elemType := sliceValue.Type().Elem()
|
||||
|
||||
// Ensure element type is a pointer to struct
|
||||
if elemType.Kind() != reflect.Ptr || elemType.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("slice elements must be pointers to structs")
|
||||
}
|
||||
|
||||
stmt, err := Query(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
for {
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasRow {
|
||||
break
|
||||
}
|
||||
|
||||
// Create new instance of the element type
|
||||
newElem := reflect.New(elemType.Elem())
|
||||
if err := Scan(stmt, newElem.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sliceValue.Set(reflect.Append(sliceValue, newElem))
|
||||
}
|
||||
|
||||
return nil
|
||||
return db.Select(dest, query, args...)
|
||||
}
|
||||
|
||||
// Exec executes a statement with fmt-style placeholders
|
||||
func Exec(query string, args ...any) error {
|
||||
convertedQuery, paramTypes := convertPlaceholders(query)
|
||||
|
||||
sqlArgs := make([]any, len(args))
|
||||
for i, arg := range args {
|
||||
if i < len(paramTypes) && paramTypes[i] == "d" {
|
||||
// Convert to int64 for integer parameters
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
sqlArgs[i] = int64(v)
|
||||
case int32:
|
||||
sqlArgs[i] = int64(v)
|
||||
case int64:
|
||||
sqlArgs[i] = v
|
||||
default:
|
||||
sqlArgs[i] = arg
|
||||
}
|
||||
} else {
|
||||
sqlArgs[i] = arg
|
||||
}
|
||||
}
|
||||
|
||||
return sqlitex.Execute(DB(), convertedQuery, &sqlitex.ExecOptions{
|
||||
Args: sqlArgs,
|
||||
})
|
||||
return db.Exec(query, args...)
|
||||
}
|
||||
|
||||
// Update updates specific fields in the database
|
||||
func Update(tableName string, fields map[string]any, whereField string, whereValue any) error {
|
||||
if len(fields) == 0 {
|
||||
return nil // No changes
|
||||
}
|
||||
|
||||
// Build UPDATE query
|
||||
setParts := make([]string, 0, len(fields))
|
||||
args := make([]any, 0, len(fields)+1)
|
||||
|
||||
for field, value := range fields {
|
||||
setParts = append(setParts, field+" = ?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
args = append(args, whereValue)
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?",
|
||||
tableName, strings.Join(setParts, ", "), whereField)
|
||||
|
||||
return sqlitex.Execute(DB(), query, &sqlitex.ExecOptions{
|
||||
Args: args,
|
||||
})
|
||||
return db.Update(tableName, fields, whereField, whereValue)
|
||||
}
|
||||
|
||||
// Insert inserts a struct into the database
|
||||
func Insert(tableName string, obj any, excludeFields ...string) (int64, error) {
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
exclude := make(map[string]bool)
|
||||
for _, field := range excludeFields {
|
||||
exclude[toSnakeCase(field)] = true
|
||||
}
|
||||
|
||||
var columns []string
|
||||
var placeholders []string
|
||||
var args []any
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
columnName := toSnakeCase(field.Name)
|
||||
if exclude[columnName] {
|
||||
continue
|
||||
}
|
||||
|
||||
columns = append(columns, columnName)
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v.Field(i).Interface())
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
||||
|
||||
stmt, err := DB().Prepare(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
// Bind parameters
|
||||
for i, arg := range args {
|
||||
switch v := arg.(type) {
|
||||
case string:
|
||||
stmt.BindText(i+1, v)
|
||||
case int, int32, int64:
|
||||
stmt.BindInt64(i+1, reflect.ValueOf(v).Int())
|
||||
case float32, float64:
|
||||
stmt.BindFloat(i+1, reflect.ValueOf(v).Float())
|
||||
default:
|
||||
stmt.BindText(i+1, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Step()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return DB().LastInsertRowID(), nil
|
||||
return db.Insert(tableName, obj, excludeFields...)
|
||||
}
|
||||
|
||||
// Transaction executes multiple operations atomically
|
||||
func Transaction(fn func() error) error {
|
||||
conn := DB()
|
||||
|
||||
// Begin transaction
|
||||
if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute operations
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
// Rollback on error
|
||||
sqlitex.Execute(conn, "ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit on success
|
||||
return sqlitex.Execute(conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
func convertPlaceholders(query string) (string, []string) {
|
||||
var paramTypes []string
|
||||
|
||||
convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string {
|
||||
paramTypes = append(paramTypes, match[1:]) // Remove % prefix
|
||||
return "?"
|
||||
})
|
||||
|
||||
return convertedQuery, paramTypes
|
||||
}
|
||||
|
||||
// toSnakeCase converts PascalCase to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
runes := []rune(s)
|
||||
|
||||
for i, r := range runes {
|
||||
if i > 0 {
|
||||
prev := runes[i-1]
|
||||
|
||||
// Add underscore before digit if previous char was letter
|
||||
if unicode.IsDigit(r) && unicode.IsLetter(prev) {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
// Add underscore before uppercase letter
|
||||
if unicode.IsUpper(r) {
|
||||
// Don't add if previous was also uppercase (unless end of acronym)
|
||||
if !unicode.IsUpper(prev) ||
|
||||
(i+1 < len(runes) && unicode.IsLower(runes[i+1])) {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
}
|
||||
return result.String()
|
||||
return db.Transaction(fn)
|
||||
}
|
||||
|
87
main.go
87
main.go
@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@ -21,6 +22,8 @@ import (
|
||||
"git.sharkk.net/Sharkk/Sushi/csrf"
|
||||
"git.sharkk.net/Sharkk/Sushi/session"
|
||||
"git.sharkk.net/Sharkk/Sushi/timing"
|
||||
|
||||
sashimi "git.sharkk.net/Sharkk/Sashimi"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -36,17 +39,90 @@ func main() {
|
||||
case "serve":
|
||||
flag.CommandLine.Parse(os.Args[2:])
|
||||
startServer(port)
|
||||
case "migrate":
|
||||
handleMigrationCommand()
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", os.Args[1])
|
||||
fmt.Fprintln(os.Stderr, "Available commands:")
|
||||
fmt.Fprintln(os.Stderr, " serve - Start the server")
|
||||
fmt.Fprintln(os.Stderr, " migrate - Run pending migrations")
|
||||
fmt.Fprintln(os.Stderr, " migrate new - Create a new migration")
|
||||
fmt.Fprintln(os.Stderr, " migrate status - Show migration status")
|
||||
fmt.Fprintln(os.Stderr, " (no command) - Start the server")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func getDBPath() (string, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current working directory: %w", err)
|
||||
}
|
||||
return filepath.Join(cwd, "data", "dk.db"), nil
|
||||
}
|
||||
|
||||
func getMigrationsDir() (string, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current working directory: %w", err)
|
||||
}
|
||||
return filepath.Join(cwd, "sql"), nil
|
||||
}
|
||||
|
||||
func initDatabase() error {
|
||||
dbPath, err := getDBPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return database.Init(dbPath)
|
||||
}
|
||||
|
||||
func handleMigrationCommand() {
|
||||
if err := initDatabase(); err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
migrationsDir, err := getMigrationsDir()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get migrations directory: %v", err)
|
||||
}
|
||||
|
||||
migrator := sashimi.NewMigrator(database.DB(), migrationsDir)
|
||||
|
||||
if len(os.Args) < 3 {
|
||||
if err := migrator.Run(); err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
subcommand := os.Args[2]
|
||||
switch subcommand {
|
||||
case "new":
|
||||
if len(os.Args) < 4 {
|
||||
log.Fatal("Usage: migrate new <migration_name>")
|
||||
}
|
||||
migrationName := strings.Join(os.Args[3:], " ")
|
||||
if err := migrator.CreateNew(migrationName); err != nil {
|
||||
log.Fatalf("Failed to create migration: %v", err)
|
||||
}
|
||||
case "status":
|
||||
if err := migrator.Status(); err != nil {
|
||||
log.Fatalf("Failed to get migration status: %v", err)
|
||||
}
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown migration subcommand: %s\n", subcommand)
|
||||
fmt.Fprintln(os.Stderr, "Available subcommands:")
|
||||
fmt.Fprintln(os.Stderr, " (none) - Run pending migrations")
|
||||
fmt.Fprintln(os.Stderr, " new - Create a new migration")
|
||||
fmt.Fprintln(os.Stderr, " status - Show migration status")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func startServer(port string) {
|
||||
fmt.Println("Starting Dragon Knight server...")
|
||||
fmt.Println("Dragon Knight is starting!")
|
||||
if err := start(port); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
@ -58,11 +134,10 @@ func start(port string) error {
|
||||
return fmt.Errorf("failed to get current working directory: %w", err)
|
||||
}
|
||||
|
||||
err = database.Init(filepath.Join(cwd, "data/dk.db"))
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize database:", err)
|
||||
if err := initDatabase(); err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
defer database.DB().Close()
|
||||
defer database.Close()
|
||||
|
||||
control.Init(filepath.Join(cwd, "data/control.json"))
|
||||
defer control.Save()
|
||||
@ -116,7 +191,7 @@ func start(port string) error {
|
||||
}()
|
||||
|
||||
<-c
|
||||
log.Println("\nReceived shutdown signal, shutting down gracefully...")
|
||||
log.Println("\nShutting down! Beginning cleanup...")
|
||||
|
||||
log.Println("Saving sessions...")
|
||||
sushi.SaveSessions()
|
||||
|
@ -1,3 +1,6 @@
|
||||
-- Migration 1: create database
|
||||
-- Created: 2025-08-22 15:00:10
|
||||
|
||||
DROP TABLE IF EXISTS babble;
|
||||
CREATE TABLE babble (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@ -426,3 +429,5 @@ CREATE TABLE fight_logs (
|
||||
`name` TEXT NOT NULL DEFAULT '',
|
||||
`created` INTEGER NOT NULL DEFAULT (unixepoch())
|
||||
);
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user