package database
import (
	"os"
	"path/filepath"
	"testing"
)
func TestOpen(t *testing.T) {
|
|
// Create a temporary database file
|
|
tempFile := "test.db"
|
|
defer os.Remove(tempFile)
|
|
|
|
db, err := Open(tempFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
if db == nil {
|
|
t.Fatal("Database instance is nil")
|
|
}
|
|
}
func TestExec(t *testing.T) {
|
|
tempFile := "test_exec.db"
|
|
defer os.Remove(tempFile)
|
|
|
|
db, err := Open(tempFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Test table creation
|
|
err = db.Exec(`CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
// Test data insertion
|
|
err = db.Exec(`INSERT INTO test_table (name) VALUES (?)`, "test_name")
|
|
if err != nil {
|
|
t.Fatalf("Failed to insert data: %v", err)
|
|
}
|
|
}
func TestQueryRow(t *testing.T) {
|
|
tempFile := "test_query.db"
|
|
defer os.Remove(tempFile)
|
|
|
|
db, err := Open(tempFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Setup test data
|
|
err = db.Exec(`CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT, value INTEGER)`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
err = db.Exec(`INSERT INTO test_table (name, value) VALUES (?, ?)`, "test", 42)
|
|
if err != nil {
|
|
t.Fatalf("Failed to insert data: %v", err)
|
|
}
|
|
|
|
// Test query
|
|
row, err := db.QueryRow("SELECT name, value FROM test_table WHERE id = ?", 1)
|
|
if err != nil {
|
|
t.Fatalf("Failed to query row: %v", err)
|
|
}
|
|
|
|
if row == nil {
|
|
t.Fatal("Row is nil")
|
|
}
|
|
defer row.Close()
|
|
|
|
name := row.Text(0)
|
|
value := row.Int(1)
|
|
|
|
if name != "test" {
|
|
t.Errorf("Expected name 'test', got '%s'", name)
|
|
}
|
|
|
|
if value != 42 {
|
|
t.Errorf("Expected value 42, got %d", value)
|
|
}
|
|
}
func TestQuery(t *testing.T) {
|
|
tempFile := "test_query_all.db"
|
|
defer os.Remove(tempFile)
|
|
|
|
db, err := Open(tempFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Setup test data
|
|
err = db.Exec(`CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
names := []string{"test1", "test2", "test3"}
|
|
for _, name := range names {
|
|
err = db.Exec(`INSERT INTO test_table (name) VALUES (?)`, name)
|
|
if err != nil {
|
|
t.Fatalf("Failed to insert data: %v", err)
|
|
}
|
|
}
|
|
|
|
// Test query with callback
|
|
var results []string
|
|
err = db.Query("SELECT name FROM test_table ORDER BY id", func(row *Row) error {
|
|
results = append(results, row.Text(0))
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("Failed to query: %v", err)
|
|
}
|
|
|
|
if len(results) != 3 {
|
|
t.Errorf("Expected 3 results, got %d", len(results))
|
|
}
|
|
|
|
for i, expected := range names {
|
|
if i < len(results) && results[i] != expected {
|
|
t.Errorf("Expected result[%d] = '%s', got '%s'", i, expected, results[i])
|
|
}
|
|
}
|
|
}
func TestTransaction(t *testing.T) {
|
|
tempFile := "test_transaction.db"
|
|
defer os.Remove(tempFile)
|
|
|
|
db, err := Open(tempFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Setup
|
|
err = db.Exec(`CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
// Test successful transaction
|
|
err = db.Transaction(func(txDB *DB) error {
|
|
err := txDB.Exec(`INSERT INTO test_table (name) VALUES (?)`, "tx_test1")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return txDB.Exec(`INSERT INTO test_table (name) VALUES (?)`, "tx_test2")
|
|
})
|
|
|
|
if err != nil {
|
|
t.Fatalf("Transaction failed: %v", err)
|
|
}
|
|
|
|
// Verify data was committed
|
|
var count int
|
|
row, err := db.QueryRow("SELECT COUNT(*) FROM test_table")
|
|
if err != nil {
|
|
t.Fatalf("Failed to count rows: %v", err)
|
|
}
|
|
if row != nil {
|
|
count = row.Int(0)
|
|
row.Close()
|
|
}
|
|
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 rows, got %d", count)
|
|
}
|
|
} |