Expose apid.DB interface instead of sql.DB for more flexibility
diff --git a/data/data.go b/data/data.go index ff0f6fb..f6a03b5 100644 --- a/data/data.go +++ b/data/data.go
@@ -24,7 +24,7 @@ var log, dbTraceLog apid.LogService var config apid.ConfigService -var dbMap = make(map[string]*sql.DB) +var dbMap = make(map[string]apid.DB) var dbMapSync sync.RWMutex func CreateDataService() apid.DataService { @@ -45,11 +45,11 @@ type dataService struct { } -func (d *dataService) DB() (*sql.DB, error) { +func (d *dataService) DB() (apid.DB, error) { return d.DBForID(commonDBID) } -func (d *dataService) DBForID(id string) (db *sql.DB, err error) { +func (d *dataService) DBForID(id string) (db apid.DB, err error) { dbMapSync.RLock() db = dbMap[id]
diff --git a/data/data_test.go b/data/data_test.go index 2c7db21..5581c5c 100644 --- a/data/data_test.go +++ b/data/data_test.go
@@ -1,7 +1,6 @@ package data_test import ( - "database/sql" "github.com/30x/apid" "github.com/30x/apid/factory" . "github.com/onsi/ginkgo" @@ -16,7 +15,7 @@ ) const ( - count = 3000 + count = 2000 setupSql = ` CREATE TABLE IF NOT EXISTS test_1 (id INTEGER PRIMARY KEY, counter TEXT); CREATE TABLE IF NOT EXISTS test_2 (id INTEGER PRIMARY KEY, counter TEXT); @@ -46,6 +45,25 @@ os.RemoveAll(tmpDir) }) + It("should be able to open a new datbase", func () { + db, err := apid.Data().DBForID("test") + Expect(err).NotTo(HaveOccurred()) + setup(db) + + var prod string + rows, err := db.Query(`SELECT counter FROM test_2 LIMIT 5`) + Expect(err).NotTo(HaveOccurred()) + defer rows.Close() + var count = 0 + for rows.Next() { + count++ + rows.Scan(&prod) + } + Expect(count).To(Equal(5)) + + //db, err := apid.Data().DBForID("test", "someid") + }) + It("should handle multi-threaded access", func(done Done) { db, err := apid.Data().DBForID("test") Expect(err).NotTo(HaveOccurred()) @@ -83,7 +101,7 @@ time.Sleep(time.Duration(r.Intn(1)) * time.Millisecond) } -func setup(db *sql.DB) { +func setup(db apid.DB) { _, err := db.Exec(setupSql) if err != nil { log.Fatal(err) @@ -101,7 +119,7 @@ tx.Commit() } -func read(db *sql.DB, i int) { +func read(db apid.DB, i int) { var prod string rows, err := db.Query(`SELECT counter FROM test_2 LIMIT 5`) if err != nil { @@ -116,7 +134,7 @@ fmt.Print(".") } -func write(db *sql.DB, i int) { +func write(db apid.DB, i int) { tx, err := db.Begin() defer tx.Rollback() if err != nil {
diff --git a/data_service.go b/data_service.go index fff6b63..6cc6ea6 100644 --- a/data_service.go +++ b/data_service.go
@@ -1,8 +1,23 @@ package apid -import "database/sql" +import ( + "database/sql" +) type DataService interface { - DB() (*sql.DB, error) - DBForID(id string) (db *sql.DB, err error) + DB() (DB, error) + DBForID(id string) (db DB, err error) +} + +type DB interface { + Ping() (error) + Prepare(query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + Begin() (*sql.Tx, error) + + //Close() error + //Stats() sql.DBStats + //Driver() driver.Driver }