diff --git a/go.mod b/go.mod index 960489a..18cfee9 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module dk go 1.25.0 require ( + git.sharkk.net/Sharkk/Sashimi v1.0.1 git.sharkk.net/Sharkk/Sushi v1.2.0 github.com/valyala/fasthttp v1.65.0 - zombiezen.com/go/sqlite v1.4.2 ) require ( @@ -24,4 +24,5 @@ require ( modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect modernc.org/sqlite v1.37.1 // indirect + zombiezen.com/go/sqlite v1.4.2 // indirect ) diff --git a/go.sum b/go.sum index a31e7bc..cd607f6 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +git.sharkk.net/Sharkk/Sashimi v1.0.1 h1:5YMmxnCgcsyasg5J91AS5FVzzJoDZ17I3J4hlJyyMR4= +git.sharkk.net/Sharkk/Sashimi v1.0.1/go.mod h1:wTMnO6jo34LIjpDJ0qToq14RbwP6Uf4HtdWDmqxrdAM= git.sharkk.net/Sharkk/Sushi v1.2.0 h1:RwOCZmgaOqtkmuK2Z7/esdLbhSXJZphsOsWEHni4Sss= git.sharkk.net/Sharkk/Sushi v1.2.0/go.mod h1:S84ACGkuZ+BKzBO4lb5WQnm5aw9+l7VSO2T1bjzxL3o= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= diff --git a/internal/database/wrapper.go b/internal/database/wrapper.go index e53da86..79b0a41 100644 --- a/internal/database/wrapper.go +++ b/internal/database/wrapper.go @@ -2,380 +2,63 @@ package database import ( "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "unicode" - "zombiezen.com/go/sqlite" - "zombiezen.com/go/sqlite/sqlitex" + sashimi "git.sharkk.net/Sharkk/Sashimi" ) -var placeholderRegex = regexp.MustCompile(`%[sd]`) +var db *sashimi.DB -var db *sqlite.Conn - -// Init initializes the database connection with WAL mode and cache settings +// Init initializes the database connection func Init(dbPath string) error { - conn, err := sqlite.OpenConn(dbPath, sqlite.OpenReadWrite|sqlite.OpenCreate) + var err error + db, err = sashimi.New(dbPath) if err != nil { - return fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("failed to initialize database: %w", err) } - - // Enable WAL mode - if err := sqlitex.Execute(conn, "PRAGMA journal_mode=WAL", nil); err != nil { - conn.Close() - return fmt.Errorf("failed to enable WAL mode: %w", err) - } - - // Set generous cache size (64MB) - if err := sqlitex.Execute(conn, "PRAGMA cache_size=-65536", nil); err != nil { - conn.Close() - return fmt.Errorf("failed to set cache size: %w", err) - } - - // Enable foreign keys - if err := sqlitex.Execute(conn, "PRAGMA foreign_keys=ON", nil); err != nil { - conn.Close() - return fmt.Errorf("failed to enable foreign keys: %w", err) - } - - db = conn return nil } // DB returns the global database connection -func DB() *sqlite.Conn { +func DB() *sashimi.DB { if db == nil { panic("database not initialized - call Init() first") } return db } -// Scan scans a SQLite statement result into a struct using field names -func Scan(stmt *sqlite.Stmt, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to struct") +// Close closes the database connection +func Close() error { + if db != nil { + return db.Close() } - - elem := v.Elem() - typ := elem.Type() - - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - columnName := toSnakeCase(field.Name) - - fieldValue := elem.Field(i) - if !fieldValue.CanSet() { - continue - } - - // Find column index by name - colIndex := -1 - for j := 0; j < stmt.ColumnCount(); j++ { - if stmt.ColumnName(j) == columnName { - colIndex = j - break - } - } - - if colIndex == -1 { - continue // Column not found - } - - switch fieldValue.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetInt(stmt.ColumnInt64(colIndex)) - case reflect.String: - fieldValue.SetString(stmt.ColumnText(colIndex)) - case reflect.Float32, reflect.Float64: - fieldValue.SetFloat(stmt.ColumnFloat(colIndex)) - case reflect.Bool: - fieldValue.SetBool(stmt.ColumnInt(colIndex) != 0) - } - } - return nil } -// Query executes a query with fmt-style placeholders and automatically binds parameters -func Query(query string, args ...any) (*sqlite.Stmt, error) { - // Replace fmt placeholders with SQLite placeholders - convertedQuery, paramTypes := convertPlaceholders(query) - - stmt, err := DB().Prepare(convertedQuery) - if err != nil { - return nil, err - } - - // Bind parameters with correct types - for i, arg := range args { - if i >= len(paramTypes) { - break - } - - switch paramTypes[i] { - case "s": // string - if s, ok := arg.(string); ok { - stmt.BindText(i+1, s) - } else { - stmt.BindText(i+1, fmt.Sprintf("%v", arg)) - } - case "d": // integer - switch v := arg.(type) { - case int: - stmt.BindInt64(i+1, int64(v)) - case int32: - stmt.BindInt64(i+1, int64(v)) - case int64: - stmt.BindInt64(i+1, v) - case float64: - stmt.BindInt64(i+1, int64(v)) - default: - if i64, err := strconv.ParseInt(fmt.Sprintf("%v", arg), 10, 64); err == nil { - stmt.BindInt64(i+1, i64) - } else { - stmt.BindInt64(i+1, 0) - } - } - } - } - - return stmt, nil +// Wrapper functions for convenience +func Query(query string, args ...any) (*sashimi.Stmt, error) { + return db.Query(query, args...) } -// Get executes a query and returns the first row func Get(dest any, query string, args ...any) error { - stmt, err := Query(query, args...) - if err != nil { - return err - } - defer stmt.Finalize() - - hasRow, err := stmt.Step() - if err != nil { - return err - } - if !hasRow { - return fmt.Errorf("no rows found") - } - - return Scan(stmt, dest) + return db.Get(dest, query, args...) } -// Select executes a query and scans all rows into a slice func Select(dest any, query string, args ...any) error { - destValue := reflect.ValueOf(dest) - if destValue.Kind() != reflect.Ptr || destValue.Elem().Kind() != reflect.Slice { - return fmt.Errorf("dest must be a pointer to slice") - } - - sliceValue := destValue.Elem() - elemType := sliceValue.Type().Elem() - - // Ensure element type is a pointer to struct - if elemType.Kind() != reflect.Ptr || elemType.Elem().Kind() != reflect.Struct { - return fmt.Errorf("slice elements must be pointers to structs") - } - - stmt, err := Query(query, args...) - if err != nil { - return err - } - defer stmt.Finalize() - - for { - hasRow, err := stmt.Step() - if err != nil { - return err - } - if !hasRow { - break - } - - // Create new instance of the element type - newElem := reflect.New(elemType.Elem()) - if err := Scan(stmt, newElem.Interface()); err != nil { - return err - } - - sliceValue.Set(reflect.Append(sliceValue, newElem)) - } - - return nil + return db.Select(dest, query, args...) } -// Exec executes a statement with fmt-style placeholders func Exec(query string, args ...any) error { - convertedQuery, paramTypes := convertPlaceholders(query) - - sqlArgs := make([]any, len(args)) - for i, arg := range args { - if i < len(paramTypes) && paramTypes[i] == "d" { - // Convert to int64 for integer parameters - switch v := arg.(type) { - case int: - sqlArgs[i] = int64(v) - case int32: - sqlArgs[i] = int64(v) - case int64: - sqlArgs[i] = v - default: - sqlArgs[i] = arg - } - } else { - sqlArgs[i] = arg - } - } - - return sqlitex.Execute(DB(), convertedQuery, &sqlitex.ExecOptions{ - Args: sqlArgs, - }) + return db.Exec(query, args...) } -// Update updates specific fields in the database func Update(tableName string, fields map[string]any, whereField string, whereValue any) error { - if len(fields) == 0 { - return nil // No changes - } - - // Build UPDATE query - setParts := make([]string, 0, len(fields)) - args := make([]any, 0, len(fields)+1) - - for field, value := range fields { - setParts = append(setParts, field+" = ?") - args = append(args, value) - } - - args = append(args, whereValue) - - query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = ?", - tableName, strings.Join(setParts, ", "), whereField) - - return sqlitex.Execute(DB(), query, &sqlitex.ExecOptions{ - Args: args, - }) + return db.Update(tableName, fields, whereField, whereValue) } -// Insert inserts a struct into the database func Insert(tableName string, obj any, excludeFields ...string) (int64, error) { - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Pointer { - v = v.Elem() - } - t := v.Type() - - exclude := make(map[string]bool) - for _, field := range excludeFields { - exclude[toSnakeCase(field)] = true - } - - var columns []string - var placeholders []string - var args []any - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - columnName := toSnakeCase(field.Name) - if exclude[columnName] { - continue - } - - columns = append(columns, columnName) - placeholders = append(placeholders, "?") - args = append(args, v.Field(i).Interface()) - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", - tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", ")) - - stmt, err := DB().Prepare(query) - if err != nil { - return 0, err - } - defer stmt.Finalize() - - // Bind parameters - for i, arg := range args { - switch v := arg.(type) { - case string: - stmt.BindText(i+1, v) - case int, int32, int64: - stmt.BindInt64(i+1, reflect.ValueOf(v).Int()) - case float32, float64: - stmt.BindFloat(i+1, reflect.ValueOf(v).Float()) - default: - stmt.BindText(i+1, fmt.Sprintf("%v", v)) - } - } - - _, err = stmt.Step() - if err != nil { - return 0, err - } - - return DB().LastInsertRowID(), nil + return db.Insert(tableName, obj, excludeFields...) } -// Transaction executes multiple operations atomically func Transaction(fn func() error) error { - conn := DB() - - // Begin transaction - if err := sqlitex.Execute(conn, "BEGIN", nil); err != nil { - return err - } - - // Execute operations - err := fn() - - if err != nil { - // Rollback on error - sqlitex.Execute(conn, "ROLLBACK", nil) - return err - } - - // Commit on success - return sqlitex.Execute(conn, "COMMIT", nil) -} - -func convertPlaceholders(query string) (string, []string) { - var paramTypes []string - - convertedQuery := placeholderRegex.ReplaceAllStringFunc(query, func(match string) string { - paramTypes = append(paramTypes, match[1:]) // Remove % prefix - return "?" - }) - - return convertedQuery, paramTypes -} - -// toSnakeCase converts PascalCase to snake_case -func toSnakeCase(s string) string { - var result strings.Builder - runes := []rune(s) - - for i, r := range runes { - if i > 0 { - prev := runes[i-1] - - // Add underscore before digit if previous char was letter - if unicode.IsDigit(r) && unicode.IsLetter(prev) { - result.WriteByte('_') - } - // Add underscore before uppercase letter - if unicode.IsUpper(r) { - // Don't add if previous was also uppercase (unless end of acronym) - if !unicode.IsUpper(prev) || - (i+1 < len(runes) && unicode.IsLower(runes[i+1])) { - result.WriteByte('_') - } - } - } - result.WriteRune(unicode.ToLower(r)) - } - return result.String() + return db.Transaction(fn) } diff --git a/main.go b/main.go index 9250e28..a03b600 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" @@ -21,6 +22,8 @@ import ( "git.sharkk.net/Sharkk/Sushi/csrf" "git.sharkk.net/Sharkk/Sushi/session" "git.sharkk.net/Sharkk/Sushi/timing" + + sashimi "git.sharkk.net/Sharkk/Sashimi" ) func main() { @@ -36,17 +39,90 @@ func main() { case "serve": flag.CommandLine.Parse(os.Args[2:]) startServer(port) + case "migrate": + handleMigrationCommand() default: fmt.Fprintf(os.Stderr, "Unknown command: %s\n", os.Args[1]) fmt.Fprintln(os.Stderr, "Available commands:") - fmt.Fprintln(os.Stderr, " serve - Start the server") - fmt.Fprintln(os.Stderr, " (no command) - Start the server") + fmt.Fprintln(os.Stderr, " serve - Start the server") + fmt.Fprintln(os.Stderr, " migrate - Run pending migrations") + fmt.Fprintln(os.Stderr, " migrate new - Create a new migration") + fmt.Fprintln(os.Stderr, " migrate status - Show migration status") + fmt.Fprintln(os.Stderr, " (no command) - Start the server") + os.Exit(1) + } +} + +func getDBPath() (string, error) { + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get current working directory: %w", err) + } + return filepath.Join(cwd, "data", "dk.db"), nil +} + +func getMigrationsDir() (string, error) { + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get current working directory: %w", err) + } + return filepath.Join(cwd, "sql"), nil +} + +func initDatabase() error { + dbPath, err := getDBPath() + if err != nil { + return err + } + return database.Init(dbPath) +} + +func handleMigrationCommand() { + if err := initDatabase(); err != nil { + log.Fatalf("Failed to initialize database: %v", err) + } + defer database.Close() + + migrationsDir, err := getMigrationsDir() + if err != nil { + log.Fatalf("Failed to get migrations directory: %v", err) + } + + migrator := sashimi.NewMigrator(database.DB(), migrationsDir) + + if len(os.Args) < 3 { + if err := migrator.Run(); err != nil { + log.Fatalf("Migration failed: %v", err) + } + return + } + + subcommand := os.Args[2] + switch subcommand { + case "new": + if len(os.Args) < 4 { + log.Fatal("Usage: migrate new ") + } + migrationName := strings.Join(os.Args[3:], " ") + if err := migrator.CreateNew(migrationName); err != nil { + log.Fatalf("Failed to create migration: %v", err) + } + case "status": + if err := migrator.Status(); err != nil { + log.Fatalf("Failed to get migration status: %v", err) + } + default: + fmt.Fprintf(os.Stderr, "Unknown migration subcommand: %s\n", subcommand) + fmt.Fprintln(os.Stderr, "Available subcommands:") + fmt.Fprintln(os.Stderr, " (none) - Run pending migrations") + fmt.Fprintln(os.Stderr, " new - Create a new migration") + fmt.Fprintln(os.Stderr, " status - Show migration status") os.Exit(1) } } func startServer(port string) { - fmt.Println("Starting Dragon Knight server...") + fmt.Println("Dragon Knight is starting!") if err := start(port); err != nil { log.Fatalf("Server failed: %v", err) } @@ -58,11 +134,10 @@ func start(port string) error { return fmt.Errorf("failed to get current working directory: %w", err) } - err = database.Init(filepath.Join(cwd, "data/dk.db")) - if err != nil { - log.Fatal("Failed to initialize database:", err) + if err := initDatabase(); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) } - defer database.DB().Close() + defer database.Close() control.Init(filepath.Join(cwd, "data/control.json")) defer control.Save() @@ -116,7 +191,7 @@ func start(port string) error { }() <-c - log.Println("\nReceived shutdown signal, shutting down gracefully...") + log.Println("\nShutting down! Beginning cleanup...") log.Println("Saving sessions...") sushi.SaveSessions() diff --git a/sql/1_create_database.sql b/sql/1_create_database.sql index 627c315..fb4487c 100644 --- a/sql/1_create_database.sql +++ b/sql/1_create_database.sql @@ -1,3 +1,6 @@ +-- Migration 1: create database +-- Created: 2025-08-22 15:00:10 + DROP TABLE IF EXISTS babble; CREATE TABLE babble ( `id` INTEGER PRIMARY KEY AUTOINCREMENT, @@ -426,3 +429,5 @@ CREATE TABLE fight_logs ( `name` TEXT NOT NULL DEFAULT '', `created` INTEGER NOT NULL DEFAULT (unixepoch()) ); + +