126 lines
3.1 KiB
Go
126 lines
3.1 KiB
Go
package simplesql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
// DB is the client for this package. All SQL actions are accessible from here.
|
|
type DB struct {
|
|
db *sqlx.DB
|
|
|
|
stmts stmts
|
|
}
|
|
|
|
// stmts is a map of named statements
|
|
type stmts map[string]*sqlx.NamedStmt
|
|
|
|
// New will return a new DB. Call Close when you are done with it.
|
|
// initSQL is run before queries are prepared. This is useful for fresh databases that do not have tables
|
|
// needed for the statements to prepare successfully.
|
|
func New(ctx context.Context, driverName string, dataSourceName string, queries embed.FS, initSQL ...string) (*DB, error) {
|
|
// Open connection to the DB
|
|
conn, err := sqlx.Connect(driverName, dataSourceName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating DB connection: %w", err)
|
|
}
|
|
|
|
// Execute the init SQL
|
|
tx, err := conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
|
}
|
|
for _, init := range initSQL {
|
|
if _, err := tx.Exec(init); err != nil {
|
|
return nil, fmt.Errorf("executing init SQL: %w", err)
|
|
}
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("committing init SQL: %w", err)
|
|
}
|
|
|
|
// Prepare all statements
|
|
stmts, err := prepare(ctx, conn, queries)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("preparing statements: %w", err)
|
|
}
|
|
|
|
return &DB{
|
|
db: conn,
|
|
stmts: stmts,
|
|
}, nil
|
|
}
|
|
|
|
// Close will close any statements and the DB connection.
|
|
func (db *DB) Close() error {
|
|
for _, stmt := range db.stmts {
|
|
_ = stmt.Close()
|
|
}
|
|
|
|
return db.db.Close()
|
|
}
|
|
|
|
func (db *DB) Get(queryName string, dest interface{}, arg interface{}) error {
|
|
return db.stmts[queryName].Get(dest, arg)
|
|
}
|
|
|
|
func (db *DB) Exec(queryName string, args ...interface{}) (sql.Result, error) {
|
|
return db.stmts[queryName].Exec(args)
|
|
}
|
|
|
|
func (db *DB) ExecUnprepared(query string, args ...interface{}) (sql.Result, error) {
|
|
return db.db.Exec(query, args)
|
|
}
|
|
|
|
// prepare will create a stmts from an embed.FS. The file name is the key.
|
|
func prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (stmts, error) {
|
|
// Read entries from sqlq dir
|
|
entries, err := queries.ReadDir("sqlq")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading sqlq dir: %w", err)
|
|
}
|
|
|
|
// Get file content per entry and map to file name as prepared statement
|
|
s := stmts{}
|
|
for _, entry := range entries {
|
|
// Get file content
|
|
fp := filepath.Join("sqlq", entry.Name())
|
|
qb, err := queries.ReadFile(fp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file %q from sqlq: %w", fp, err)
|
|
}
|
|
|
|
// Cleanse input query
|
|
qs := cleanseQuery(string(qb))
|
|
|
|
// Prepare statement
|
|
stmt, err := db.PrepareNamedContext(ctx, qs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("preparing statement %q: %w", qs, err)
|
|
}
|
|
|
|
// Map to file name, without the extension, which is assumed to be .sql
|
|
s[strings.TrimSuffix(entry.Name(), ".sql")] = stmt
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// cleanseQuery removes whole-line SQL comments from s. Every line is trimmed
// of surrounding whitespace, and lines that begin with "--" after trimming are
// dropped. The surviving lines are rejoined with "\n". Comments that follow
// SQL on the same line are left in place.
func cleanseQuery(s string) string {
	var b strings.Builder
	wrote := false
	for _, raw := range strings.Split(s, "\n") {
		trimmed := strings.TrimSpace(raw)
		if strings.HasPrefix(trimmed, "--") {
			continue
		}
		if wrote {
			b.WriteByte('\n')
		}
		b.WriteString(trimmed)
		wrote = true
	}
	return b.String()
}
|