235 lines
5.8 KiB
Go
235 lines
5.8 KiB
Go
package sql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// PostgresDriver implements the Driver interface for PostgreSQL on top of
// pgx's connection pool. It has no fields, so the zero value is ready to use.
type PostgresDriver struct{}

// Name reports the identifier this driver is known by.
func (d *PostgresDriver) Name() string {
	const driverName = "postgres"
	return driverName
}
|
|
|
|
func (d *PostgresDriver) Open(dsn string) (Connection, error) {
|
|
config, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to parse config: %w", err)
|
|
}
|
|
|
|
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to create pool: %w", err)
|
|
}
|
|
|
|
return &PostgresConnection{pool: pool}, nil
|
|
}
|
|
|
|
// PostgresConnection implements the Connection interface
|
|
type PostgresConnection struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func (c *PostgresConnection) Close() error {
|
|
c.pool.Close()
|
|
return nil
|
|
}
|
|
|
|
func (c *PostgresConnection) Ping(ctx context.Context) error {
|
|
return c.pool.Ping(ctx)
|
|
}
|
|
|
|
func (c *PostgresConnection) Begin(ctx context.Context) (Transaction, error) {
|
|
tx, err := c.pool.Begin(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to begin transaction: %w", err)
|
|
}
|
|
return &PostgresTransaction{tx: tx}, nil
|
|
}
|
|
|
|
func (c *PostgresConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
|
rows, err := c.pool.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: query failed: %w", err)
|
|
}
|
|
return &PostgresRows{rows: rows}, nil
|
|
}
|
|
|
|
func (c *PostgresConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
|
row := c.pool.QueryRow(ctx, query, args...)
|
|
return &PostgresRow{row: row}
|
|
}
|
|
|
|
func (c *PostgresConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
|
tag, err := c.pool.Exec(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: exec failed: %w", err)
|
|
}
|
|
return &PostgresResult{tag: tag}, nil
|
|
}
|
|
|
|
func (c *PostgresConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
|
// pgx doesn't have explicit prepared statements like database/sql
|
|
// We'll store the query and use it with the pool
|
|
return &PostgresStatement{pool: c.pool, query: query}, nil
|
|
}
|
|
|
|
// PostgresTransaction implements the Transaction interface
|
|
type PostgresTransaction struct {
|
|
tx pgx.Tx
|
|
}
|
|
|
|
func (t *PostgresTransaction) Commit() error {
|
|
return t.tx.Commit(context.Background())
|
|
}
|
|
|
|
func (t *PostgresTransaction) Rollback() error {
|
|
return t.tx.Rollback(context.Background())
|
|
}
|
|
|
|
func (t *PostgresTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
|
rows, err := t.tx.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: transaction query failed: %w", err)
|
|
}
|
|
return &PostgresRows{rows: rows}, nil
|
|
}
|
|
|
|
func (t *PostgresTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
|
row := t.tx.QueryRow(ctx, query, args...)
|
|
return &PostgresRow{row: row}
|
|
}
|
|
|
|
func (t *PostgresTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
|
tag, err := t.tx.Exec(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: transaction exec failed: %w", err)
|
|
}
|
|
return &PostgresResult{tag: tag}, nil
|
|
}
|
|
|
|
func (t *PostgresTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
|
return &PostgresStatement{tx: t.tx, query: query}, nil
|
|
}
|
|
|
|
// PostgresRows implements the Rows interface
|
|
type PostgresRows struct {
|
|
rows pgx.Rows
|
|
}
|
|
|
|
func (r *PostgresRows) Next() bool {
|
|
return r.rows.Next()
|
|
}
|
|
|
|
func (r *PostgresRows) Scan(dest ...any) error {
|
|
return r.rows.Scan(dest...)
|
|
}
|
|
|
|
func (r *PostgresRows) Columns() ([]string, error) {
|
|
fields := r.rows.FieldDescriptions()
|
|
columns := make([]string, len(fields))
|
|
for i, field := range fields {
|
|
columns[i] = field.Name
|
|
}
|
|
return columns, nil
|
|
}
|
|
|
|
func (r *PostgresRows) Close() error {
|
|
r.rows.Close()
|
|
return nil
|
|
}
|
|
|
|
func (r *PostgresRows) Err() error {
|
|
return r.rows.Err()
|
|
}
|
|
|
|
// PostgresRow implements the Row interface
|
|
type PostgresRow struct {
|
|
row pgx.Row
|
|
}
|
|
|
|
func (r *PostgresRow) Scan(dest ...any) error {
|
|
return r.row.Scan(dest...)
|
|
}
|
|
|
|
// PostgresResult implements the Result interface
|
|
type PostgresResult struct {
|
|
tag pgconn.CommandTag
|
|
}
|
|
|
|
func (r *PostgresResult) LastInsertId() (int64, error) {
|
|
// PostgreSQL doesn't have AUTO_INCREMENT like MySQL
|
|
// Users should use RETURNING clause or sequences
|
|
return 0, fmt.Errorf("postgres: LastInsertId not supported, use RETURNING clause")
|
|
}
|
|
|
|
func (r *PostgresResult) RowsAffected() (int64, error) {
|
|
return r.tag.RowsAffected(), nil
|
|
}
|
|
|
|
// PostgresStatement implements the Statement interface
|
|
type PostgresStatement struct {
|
|
pool *pgxpool.Pool
|
|
tx pgx.Tx
|
|
query string
|
|
}
|
|
|
|
func (s *PostgresStatement) Close() error {
|
|
// pgx doesn't require explicit statement cleanup
|
|
return nil
|
|
}
|
|
|
|
func (s *PostgresStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
|
var rows pgx.Rows
|
|
var err error
|
|
|
|
if s.tx != nil {
|
|
rows, err = s.tx.Query(ctx, s.query, args...)
|
|
} else {
|
|
rows, err = s.pool.Query(ctx, s.query, args...)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: statement query failed: %w", err)
|
|
}
|
|
return &PostgresRows{rows: rows}, nil
|
|
}
|
|
|
|
func (s *PostgresStatement) QueryRow(ctx context.Context, args ...any) Row {
|
|
var row pgx.Row
|
|
|
|
if s.tx != nil {
|
|
row = s.tx.QueryRow(ctx, s.query, args...)
|
|
} else {
|
|
row = s.pool.QueryRow(ctx, s.query, args...)
|
|
}
|
|
|
|
return &PostgresRow{row: row}
|
|
}
|
|
|
|
func (s *PostgresStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
|
var tag pgconn.CommandTag
|
|
var err error
|
|
|
|
if s.tx != nil {
|
|
tag, err = s.tx.Exec(ctx, s.query, args...)
|
|
} else {
|
|
tag, err = s.pool.Exec(ctx, s.query, args...)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: statement exec failed: %w", err)
|
|
}
|
|
return &PostgresResult{tag: tag}, nil
|
|
}
|
|
|
|
func init() {
|
|
// Register PostgreSQL driver on import
|
|
RegisterDriver("postgres", &PostgresDriver{})
|
|
}
|