package simplesql import ( "context" "database/sql" "embed" "fmt" "path/filepath" "strings" "github.com/jmoiron/sqlx" ) // DB is the client for this package. All SQL actions are accessible from here. type DB struct { db *sqlx.DB stmts stmts } // stmts is a map of named statements type stmts map[string]*sqlx.NamedStmt // New will return a new DB. Call Close when you are done with it. // initSQL is run before queries are prepared. This is useful for fresh databases that do not have tables // needed for the statements to prepare successfully. func New(ctx context.Context, driverName string, dataSourceName string, queries embed.FS, initSQL ...string) (*DB, error) { // Open connection to the DB conn, err := sqlx.Connect(driverName, dataSourceName) if err != nil { return nil, fmt.Errorf("creating DB connection: %w", err) } // Execute the init SQL tx, err := conn.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("beginning transaction: %w", err) } for _, init := range initSQL { if _, err := tx.Exec(init); err != nil { return nil, fmt.Errorf("executing init SQL: %w", err) } } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("committing init SQL: %w", err) } // Prepare all statements stmts, err := prepare(ctx, conn, queries) if err != nil { return nil, fmt.Errorf("preparing statements: %w", err) } return &DB{ db: conn, stmts: stmts, }, nil } // Close will close any statements and the DB connection. func (db *DB) Close() error { for _, stmt := range db.stmts { _ = stmt.Close() } return db.db.Close() } func (db *DB) Get(queryName string, dest interface{}, arg interface{}) error { return db.stmts[queryName].Get(dest, arg) } func (db *DB) Exec(queryName string, args ...interface{}) (sql.Result, error) { return db.stmts[queryName].Exec(args) } func (db *DB) ExecUnprepared(query string, args ...interface{}) (sql.Result, error) { return db.db.Exec(query, args) } // prepare will create a stmts from an embed.FS. The file name is the key. func prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (stmts, error) { // Read entries from sqlq dir entries, err := queries.ReadDir("sqlq") if err != nil { return nil, fmt.Errorf("reading sqlq dir: %w", err) } // Get file content per entry and map to file name as prepared statement s := stmts{} for _, entry := range entries { // Get file content fp := filepath.Join("sqlq", entry.Name()) qb, err := queries.ReadFile(fp) if err != nil { return nil, fmt.Errorf("reading file %q from sqlq: %w", fp, err) } // Cleanse input query qs := cleanseQuery(string(qb)) // Prepare statement stmt, err := db.PrepareNamedContext(ctx, qs) if err != nil { return nil, fmt.Errorf("preparing statement %q: %w", qs, err) } // Map to file name, without the extension, which is assumed to be .sql s[strings.TrimSuffix(entry.Name(), ".sql")] = stmt } return s, nil } func cleanseQuery(s string) string { var cleansed []string for _, line := range strings.Split(s, "\n") { line := strings.TrimSpace(line) if !strings.HasPrefix(line, "--") { cleansed = append(cleansed, line) } } return strings.Join(cleansed, "\n") }