diff --git a/go.mod b/go.mod index 451d293..b423127 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,6 @@ module git.simplesystems.tech/simplesystems/simple-sql go 1.17 -require ( - github.com/jmoiron/sqlx v1.3.4 // indirect - github.com/mattn/go-sqlite3 v1.14.10 // indirect -) +require github.com/jmoiron/sqlx v1.3.4 + +require github.com/mattn/go-sqlite3 v1.14.10 // indirect diff --git a/simplesql.go b/simplesql.go index 484dbb0..c14ecb9 100644 --- a/simplesql.go +++ b/simplesql.go @@ -2,6 +2,7 @@ package simplesql import ( "context" + "database/sql" "embed" "fmt" "path/filepath" @@ -10,19 +11,75 @@ import ( "github.com/jmoiron/sqlx" ) -// Stmts is a map of named statements -type Stmts map[string]*sqlx.NamedStmt +// DB is the client for this package. All SQL actions are accessible from here. +type DB struct { + db *sqlx.DB -// Close will close all prepared statements. -// This should be called after all statements are no longer needed. -func (s *Stmts) Close() { - for _, stmt := range *s { - _ = stmt.Close() - } + stmts stmts } -// Prepare will create a Stmts from an embed.FS. The file name is the key. -func Prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (Stmts, error) { +// stmts is a map of named statements +type stmts map[string]*sqlx.NamedStmt + +// New will return a new DB. Call Close when you are done with it. +// initSQL is run before queries are prepared. This is useful for fresh databases that do not have tables +// needed for the statements to prepare successfully. +func New(ctx context.Context, driverName string, dataSourceName string, queries embed.FS, initSQL ...string) (*DB, error) { + // Open connection to the DB + conn, err := sqlx.Connect(driverName, dataSourceName) + if err != nil { + return nil, fmt.Errorf("creating DB connection: %w", err) + } + + // Execute the init SQL + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("beginning transaction: %w", err) + } + for _, init := range initSQL { + if _, err := tx.Exec(init); err != nil { + return nil, fmt.Errorf("executing init SQL: %w", err) + } + } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("committing init SQL: %w", err) + } + + // Prepare all statements + stmts, err := prepare(ctx, conn, queries) + if err != nil { + return nil, fmt.Errorf("preparing statements: %w", err) + } + + return &DB{ + db: conn, + stmts: stmts, + }, nil +} + +// Close will close any statements and the DB connection. +func (db *DB) Close() error { + for _, stmt := range db.stmts { + _ = stmt.Close() + } + + return db.db.Close() +} + +func (db *DB) Get(queryName string, dest interface{}, arg interface{}) error { + return db.stmts[queryName].Get(dest, arg) +} + +func (db *DB) Exec(queryName string, args ...interface{}) (sql.Result, error) { + return db.stmts[queryName].Exec(args) +} + +func (db *DB) ExecUnprepared(query string, args ...interface{}) (sql.Result, error) { + return db.db.Exec(query, args) +} + +// prepare will create a stmts from an embed.FS. The file name is the key. +func prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (stmts, error) { // Read entries from sqlq dir entries, err := queries.ReadDir("sqlq") if err != nil { @@ -30,7 +87,7 @@ func Prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (Stmts, error) } // Get file content per entry and map to file name as prepared statement - s := Stmts{} + s := stmts{} for _, entry := range entries { // Get file content fp := filepath.Join("sqlq", entry.Name()) diff --git a/simplesql_test/simplesql_test.go b/simplesql_test/simplesql_test.go index a48ff8d..c04924b 100644 --- a/simplesql_test/simplesql_test.go +++ b/simplesql_test/simplesql_test.go @@ -6,7 +6,6 @@ import ( "testing" ssql "git.simplesystems.tech/simplesystems/simple-sql" - "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" ) @@ -18,42 +17,20 @@ type user struct { } func TestNew(t *testing.T) { + // SQL to create a test table and users + initSQL := `CREATE TABLE users(id INT, username VARCHAR);INSERT INTO users (id, username) VALUES (1, 'u1'), (2, 'u2');` + // Create in-memory db for test - db, err := sqlx.Open("sqlite3", "file:temp.db?mode=memory") + db, err := ssql.New(context.TODO(), "sqlite3", "file:temp.db?mode=memory", sqlQueries, initSQL) if err != nil { t.Fatalf("opening db: %v", err) } defer db.Close() - // Create table - if _, err := db.Exec("CREATE TABLE users(id INT, username VARCHAR)"); err != nil { - t.Fatalf("creating table: %v", err) - } - - // Add users - if _, err := db.Exec("INSERT INTO users (id, username) VALUES (1, 'u1'), (2, 'u2')"); err != nil { - t.Fatalf("adding users: %v", err) - } - - // Prepare statements - stmts, err := ssql.Prepare(context.TODO(), db, sqlQueries) - if err != nil { - t.Fatalf("preparing statements: %v", err) - } - defer stmts.Close() - - if len(stmts) != 1 { - t.Errorf("expecting 1 statement, but got %d", len(stmts)) - } - - if _, ok := stmts["get_user"]; !ok { - t.Error("expecting get_user to be in stmts, but it is not") - } - // Get a user from the table, using the prepared statement var res user - if err := stmts["get_user"].Get(&res, map[string]interface{}{"user_id": "2"}); err != nil { - t.Fatalf("getting users: %v", err) + if err := db.Get("get_user", &res, map[string]interface{}{"user_id": "2"}); err != nil { + t.Fatalf("getting user: %v", err) } if res.Username != "u2" {