126 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			126 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package simplesql
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"database/sql"
 | |
| 	"embed"
 | |
| 	"fmt"
 | |
| 	"path/filepath"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/jmoiron/sqlx"
 | |
| )
 | |
| 
 | |
| // DB is the client for this package.  All SQL actions are accessible from here.
 | |
| type DB struct {
 | |
| 	db *sqlx.DB
 | |
| 
 | |
| 	stmts stmts
 | |
| }
 | |
| 
 | |
| // stmts is a map of named statements
 | |
| type stmts map[string]*sqlx.NamedStmt
 | |
| 
 | |
| // New will return a new DB.  Call Close when you are done with it.
 | |
| // initSQL is run before queries are prepared.  This is useful for fresh databases that do not have tables
 | |
| // needed for the statements to prepare successfully.
 | |
| func New(ctx context.Context, driverName string, dataSourceName string, queries embed.FS, initSQL ...string) (*DB, error) {
 | |
| 	// Open connection to the DB
 | |
| 	conn, err := sqlx.Connect(driverName, dataSourceName)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("creating DB connection: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Execute the init SQL
 | |
| 	tx, err := conn.BeginTx(ctx, nil)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("beginning transaction: %w", err)
 | |
| 	}
 | |
| 	for _, init := range initSQL {
 | |
| 		if _, err := tx.Exec(init); err != nil {
 | |
| 			return nil, fmt.Errorf("executing init SQL: %w", err)
 | |
| 		}
 | |
| 	}
 | |
| 	if err := tx.Commit(); err != nil {
 | |
| 		return nil, fmt.Errorf("committing init SQL: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Prepare all statements
 | |
| 	stmts, err := prepare(ctx, conn, queries)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("preparing statements: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return &DB{
 | |
| 		db:    conn,
 | |
| 		stmts: stmts,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // Close will close any statements and the DB connection.
 | |
| func (db *DB) Close() error {
 | |
| 	for _, stmt := range db.stmts {
 | |
| 		_ = stmt.Close()
 | |
| 	}
 | |
| 
 | |
| 	return db.db.Close()
 | |
| }
 | |
| 
 | |
| func (db *DB) Get(queryName string, dest interface{}, arg interface{}) error {
 | |
| 	return db.stmts[queryName].Get(dest, arg)
 | |
| }
 | |
| 
 | |
| func (db *DB) Exec(queryName string, args ...interface{}) (sql.Result, error) {
 | |
| 	return db.stmts[queryName].Exec(args)
 | |
| }
 | |
| 
 | |
| func (db *DB) ExecUnprepared(query string, args ...interface{}) (sql.Result, error) {
 | |
| 	return db.db.Exec(query, args)
 | |
| }
 | |
| 
 | |
| // prepare will create a stmts from an embed.FS.  The file name is the key.
 | |
| func prepare(ctx context.Context, db *sqlx.DB, queries embed.FS) (stmts, error) {
 | |
| 	// Read entries from sqlq dir
 | |
| 	entries, err := queries.ReadDir("sqlq")
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("reading sqlq dir: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Get file content per entry and map to file name as prepared statement
 | |
| 	s := stmts{}
 | |
| 	for _, entry := range entries {
 | |
| 		// Get file content
 | |
| 		fp := filepath.Join("sqlq", entry.Name())
 | |
| 		qb, err := queries.ReadFile(fp)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("reading file %q from sqlq: %w", fp, err)
 | |
| 		}
 | |
| 
 | |
| 		// Cleanse input query
 | |
| 		qs := cleanseQuery(string(qb))
 | |
| 
 | |
| 		// Prepare statement
 | |
| 		stmt, err := db.PrepareNamedContext(ctx, qs)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("preparing statement %q: %w", qs, err)
 | |
| 		}
 | |
| 
 | |
| 		// Map to file name, without the extension, which is assumed to be .sql
 | |
| 		s[strings.TrimSuffix(entry.Name(), ".sql")] = stmt
 | |
| 	}
 | |
| 
 | |
| 	return s, nil
 | |
| }
 | |
| 
 | |
// cleanseQuery trims surrounding whitespace from every line of s and drops
// any line that, once trimmed, begins with "--" (a full-line SQL comment).
// The surviving lines are rejoined with "\n".  Comments that follow SQL on
// the same line are left in place.
func cleanseQuery(s string) string {
	lines := strings.Split(s, "\n")

	// Filter in place: kept never grows past the element being read.
	kept := lines[:0]
	for _, raw := range lines {
		trimmed := strings.TrimSpace(raw)
		if strings.HasPrefix(trimmed, "--") {
			continue
		}
		kept = append(kept, trimmed)
	}

	return strings.Join(kept, "\n")
}
 |