Rewrite to create a connection and prepare queries in one call. Stmts is no longer exposed to the user. An init SQL is added to initialize a database before preparing queries.
This commit is contained in:
parent
46c0128d67
commit
7c30c3c72b
7
go.mod
7
go.mod
@ -2,7 +2,6 @@ module git.simplesystems.tech/simplesystems/simple-sql
|
|||||||
|
|
||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require (
|
require github.com/jmoiron/sqlx v1.3.4
|
||||||
github.com/jmoiron/sqlx v1.3.4 // indirect
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.10 // indirect
|
require github.com/mattn/go-sqlite3 v1.14.10 // indirect
|
||||||
)
|
|
||||||
|
79
simplesql.go
79
simplesql.go
@ -2,6 +2,7 @@ package simplesql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -10,19 +11,75 @@ import (
|
|||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Stmts is a map of named statements
|
// DB is the client for this package. All SQL actions are accessible from here.
|
||||||
type Stmts map[string]*sqlx.NamedStmt
|
type DB struct {
|
||||||
|
db *sqlx.DB
|
||||||
|
|
||||||
// Close will close all prepared statements.
|
stmts stmts
|
||||||
// This should be called after all statements are no longer needed.
|
|
||||||
func (s *Stmts) Close() {
|
|
||||||
for _, stmt := range *s {
|
|
||||||
_ = stmt.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare will create a Stmts from an embed.FS. The file name is the key.
|
// stmts is a map of named statements
|
||||||
func Prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (Stmts, error) {
|
type stmts map[string]*sqlx.NamedStmt
|
||||||
|
|
||||||
|
// New will return a new DB. Call Close when you are done with it.
|
||||||
|
// initSQL is run before queries are prepared. This is useful for fresh databases that do not have tables
|
||||||
|
// needed for the statements to prepare successfully.
|
||||||
|
func New(ctx context.Context, driverName string, dataSourceName string, queries embed.FS, initSQL ...string) (*DB, error) {
|
||||||
|
// Open connection to the DB
|
||||||
|
conn, err := sqlx.Connect(driverName, dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating DB connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the init SQL
|
||||||
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||||
|
}
|
||||||
|
for _, init := range initSQL {
|
||||||
|
if _, err := tx.Exec(init); err != nil {
|
||||||
|
return nil, fmt.Errorf("executing init SQL: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, fmt.Errorf("committing init SQL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare all statements
|
||||||
|
stmts, err := prepare(ctx, conn, queries)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("preparing statements: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DB{
|
||||||
|
db: conn,
|
||||||
|
stmts: stmts,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close will close any statements and the DB connection.
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
for _, stmt := range db.stmts {
|
||||||
|
_ = stmt.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Get(queryName string, dest interface{}, arg interface{}) error {
|
||||||
|
return db.stmts[queryName].Get(dest, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Exec(queryName string, args ...interface{}) (sql.Result, error) {
|
||||||
|
return db.stmts[queryName].Exec(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ExecUnprepared(query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
return db.db.Exec(query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare will create a stmts from an embed.FS. The file name is the key.
|
||||||
|
func prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (stmts, error) {
|
||||||
// Read entries from sqlq dir
|
// Read entries from sqlq dir
|
||||||
entries, err := queries.ReadDir("sqlq")
|
entries, err := queries.ReadDir("sqlq")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -30,7 +87,7 @@ func Prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (Stmts, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get file content per entry and map to file name as prepared statement
|
// Get file content per entry and map to file name as prepared statement
|
||||||
s := Stmts{}
|
s := stmts{}
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
// Get file content
|
// Get file content
|
||||||
fp := filepath.Join("sqlq", entry.Name())
|
fp := filepath.Join("sqlq", entry.Name())
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
ssql "git.simplesystems.tech/simplesystems/simple-sql"
|
ssql "git.simplesystems.tech/simplesystems/simple-sql"
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,42 +17,20 @@ type user struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
|
// SQL to create a test table and users
|
||||||
|
initSQL := `CREATE TABLE users(id INT, username VARCHAR);INSERT INTO users (id, username) VALUES (1, 'u1'), (2, 'u2');`
|
||||||
|
|
||||||
// Create in-memory db for test
|
// Create in-memory db for test
|
||||||
db, err := sqlx.Open("sqlite3", "file:temp.db?mode=memory")
|
db, err := ssql.New(context.TODO(), "sqlite3", "file:temp.db?mode=memory", sqlQueries, initSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("opening db: %v", err)
|
t.Fatalf("opening db: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Create table
|
|
||||||
if _, err := db.Exec("CREATE TABLE users(id INT, username VARCHAR)"); err != nil {
|
|
||||||
t.Fatalf("creating table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add users
|
|
||||||
if _, err := db.Exec("INSERT INTO users (id, username) VALUES (1, 'u1'), (2, 'u2')"); err != nil {
|
|
||||||
t.Fatalf("adding users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare statements
|
|
||||||
stmts, err := ssql.Prepare(context.TODO(), db, sqlQueries)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("preparing statements: %v", err)
|
|
||||||
}
|
|
||||||
defer stmts.Close()
|
|
||||||
|
|
||||||
if len(stmts) != 1 {
|
|
||||||
t.Errorf("expecting 1 statement, but got %d", len(stmts))
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := stmts["get_user"]; !ok {
|
|
||||||
t.Error("expecting get_user to be in stmts, but it is not")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get a user from the table, using the prepared statement
|
// Get a user from the table, using the prepared statement
|
||||||
var res user
|
var res user
|
||||||
if err := stmts["get_user"].Get(&res, map[string]interface{}{"user_id": "2"}); err != nil {
|
if err := db.Get("get_user", &res, map[string]interface{}{"user_id": "2"}); err != nil {
|
||||||
t.Fatalf("getting users: %v", err)
|
t.Fatalf("getting user: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Username != "u2" {
|
if res.Username != "u2" {
|
||||||
|
Loading…
Reference in New Issue
Block a user