[ISSUE-65205927] add mutex to begin transaction, to serialize conncurrent writes
diff --git a/data/data.go b/data/data.go index d9d1404..537ffe0 100644 --- a/data/data.go +++ b/data/data.go
@@ -15,6 +15,7 @@ package data import ( + "context" "database/sql" "fmt" "github.com/30x/apid-core" @@ -39,20 +40,112 @@ commonDBID = "common" commonDBVersion = "base" dbOpenMode = "?cache=shared&mode=rwc" - defaultTraceLevel = "warn" + defaultTraceLevel = "warn" ) var log, dbTraceLog apid.LogService var config apid.ConfigService type dbMapInfo struct { - db *sql.DB + db *ApidDb closed chan bool } var dbMap = make(map[string]*dbMapInfo) var dbMapSync sync.RWMutex +type ApidDb struct { + db *sql.DB + mutex *sync.Mutex +} + +func (d *ApidDb) Ping() error { + return d.db.Ping() +} + +func (d *ApidDb) Prepare(query string) (*sql.Stmt, error) { + return d.db.Prepare(query) +} + +func (d *ApidDb) Exec(query string, args ...interface{}) (sql.Result, error) { + return d.db.Exec(query, args...) +} + +func (d *ApidDb) Query(query string, args ...interface{}) (*sql.Rows, error) { + return d.db.Query(query, args...) +} + +func (d *ApidDb) QueryRow(query string, args ...interface{}) *sql.Row { + return d.db.QueryRow(query, args...) +} + +func (d *ApidDb) Begin() (apid.Tx, error) { + d.mutex.Lock() + tx, err := d.db.Begin() + if err != nil { + return nil, err + } + return &Tx{ + tx: tx, + mutex: d.mutex, + }, nil +} + +func (d *ApidDb) Stats() sql.DBStats { + return d.db.Stats() +} + +type Tx struct { + tx *sql.Tx + mutex *sync.Mutex + closed bool +} + +func (tx *Tx) Commit() error { + if !tx.closed { + defer tx.mutex.Unlock() + tx.closed = true + } + return tx.tx.Commit() +} +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.tx.Exec(query, args...) +} +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return tx.tx.ExecContext(ctx, query, args...) +} +func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { + return tx.tx.Prepare(query) +} +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return tx.tx.PrepareContext(ctx, query) +} +func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.tx.Query(query, args...) +} +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return tx.tx.QueryContext(ctx, query, args...) +} +func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.tx.QueryRow(query, args...) +} +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return tx.tx.QueryRowContext(ctx, query, args...) +} +func (tx *Tx) Rollback() error { + if !tx.closed { + defer tx.mutex.Unlock() + tx.closed = true + } + return tx.tx.Rollback() +} +func (tx *Tx) Stmt(stmt *sql.Stmt) *sql.Stmt { + return tx.tx.Stmt(stmt) +} +func (tx *Tx) StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt { + return tx.tx.StmtContext(ctx, stmt) +} + func CreateDataService() apid.DataService { config = apid.Config() log = apid.Log().ForModule("data") @@ -132,7 +225,7 @@ return } -func (d *dataService) dbVersionForID(id, version string) (db *sql.DB, err error) { +func (d *dataService) dbVersionForID(id, version string) (retDb *ApidDb, err error) { var stoplogchan chan bool versionedID := VersionedDBID(id, version) @@ -166,12 +259,18 @@ sql.Register(wrappedDriverName, driver) }() - db, err = sql.Open(wrappedDriverName, source) + db, err := sql.Open(wrappedDriverName, source) + if err != nil { log.Errorf("error loading db: %s", err) return } + retDb = &ApidDb{ + db: db, + mutex: &sync.Mutex{}, + } + err = db.Ping() if err != nil { log.Errorf("error pinging db: %s", err) @@ -199,14 +298,17 @@ db.SetMaxOpenConns(config.GetInt(api.ConfigDBMaxConns)) db.SetMaxIdleConns(config.GetInt(api.ConfigDBIdleConns)) db.SetConnMaxLifetime(time.Duration(config.GetInt(api.ConfigDBConnsTimeout)) * time.Second) - dbInfo := dbMapInfo{db: db, closed: stoplogchan} + dbInfo := dbMapInfo{ + db: retDb, + closed: stoplogchan, + } dbMap[versionedID] = &dbInfo return } func Delete(versionedID string) interface{} { - return func(db *sql.DB) { - err := db.Close() + return func(db *ApidDb) { + err := db.db.Close() if err != nil { log.Errorf("error closing DB: %v", err) }
diff --git a/data/data_suite_test.go b/data/data_suite_test.go index 11a4d84..1953ad0 100644 --- a/data/data_suite_test.go +++ b/data/data_suite_test.go
@@ -18,11 +18,11 @@ . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "testing" "github.com/30x/apid-core" "github.com/30x/apid-core/factory" "io/ioutil" "os" + "testing" ) var tmpDir string
diff --git a/data/data_test.go b/data/data_test.go index b4d20ae..71dfb26 100644 --- a/data/data_test.go +++ b/data/data_test.go
@@ -17,14 +17,12 @@ import ( "fmt" "github.com/30x/apid-core" + "github.com/30x/apid-core/data" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "log" "math/rand" "strconv" "time" - "github.com/30x/apid-core/data" - "database/sql" ) const ( @@ -35,7 +33,7 @@ ) var ( - r *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + r *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) ) var _ = Describe("Data Service", func() { @@ -77,94 +75,95 @@ Expect(err).NotTo(HaveOccurred()) setup(db) id := data.VersionedDBID("release", "version") - sqlDB := db.(*sql.DB) - Expect(sqlDB.Stats().OpenConnections).To(Equal(1)) + Expect(db.Stats().OpenConnections).To(Equal(1)) // run finalizer - data.Delete(id).(func(db *sql.DB))(sqlDB) - Expect(sqlDB.Stats().OpenConnections).To(Equal(0)) + data.Delete(id).(func(db *data.ApidDb))(db.(*data.ApidDb)) + Expect(db.Stats().OpenConnections).To(Equal(0)) Expect(data.DBPath(id)).ShouldNot(BeAnExistingFile()) }) - It("should handle multi-threaded access", func(done Done) { + It("should handle concurrent read & serialized write", func() { db, err := apid.Data().DBForID("test") Expect(err).NotTo(HaveOccurred()) setup(db) - finished := make(chan struct{}) + finished := make(chan bool, count+1) go func() { for i := 0; i < count; i++ { write(db, i) } - finished <- struct{}{} + finished <- true }() - go func() { - for i := 0; i < count; i++ { - go func() { - read(db) - finished <- struct{}{} - }() - time.Sleep(time.Duration(r.Intn(2)) * time.Millisecond) - } - }() + for i := 0; i < count; i++ { + go func() { + read(db) + finished <- true + }() + } for i := 0; i < count+1; i++ { <-finished } + }, 10) - close(done) + It("should handle concurrent write", func() { + db, err := apid.Data().DBForID("test_write") + Expect(err).NotTo(HaveOccurred()) + setup(db) + finished := make(chan bool, count) + + for i := 0; i < count; i++ { + go func() { + write(db, i) + finished <- true + }() + } + + for i := 0; i < count; i++ { + <-finished + } }, 10) }) func setup(db apid.DB) { _, err := db.Exec(setupSql) - if err != nil { - log.Fatal(err) - } + Expect(err).Should(Succeed()) tx, err := db.Begin() - if err != nil { - log.Fatal(err) - } + Expect(err).Should(Succeed()) for i := 0; i < count; i++ { _, err := tx.Exec("INSERT INTO test_2 (counter) VALUES (?);", strconv.Itoa(i)) - if err != nil { - log.Fatalf("filling up test_2 table failed. Exec error=%s", err) - } + Expect(err).Should(Succeed()) } - tx.Commit() + Expect(tx.Commit()).Should(Succeed()) } func read(db apid.DB) { + defer GinkgoRecover() var counter string rows, err := db.Query(`SELECT counter FROM test_2 LIMIT 5`) - if err != nil { - log.Fatalf("test_2 select failed. Exec error=%s", err) - } else { - defer rows.Close() - for rows.Next() { - rows.Scan(&counter) - //fmt.Print("*") - } + Expect(err).Should(Succeed()) + defer rows.Close() + for rows.Next() { + rows.Scan(&counter) } fmt.Print(".") } func write(db apid.DB, i int) { - + defer GinkgoRecover() // DB INSERT as a txn tx, err := db.Begin() + Expect(err).Should(Succeed()) defer tx.Rollback() - if err != nil { - log.Fatalf("Write failed. Exec error=%s", err) - } prep, err := tx.Prepare("INSERT INTO test_1 (counter) VALUES ($1);") - _, err = tx.Stmt(prep).Exec(strconv.Itoa(i)) - if err != nil { - log.Fatalf("Write failed. Exec error=%s", err) - } - prep.Close() - tx.Commit() + Expect(err).Should(Succeed()) + _, err = prep.Exec(strconv.Itoa(i)) + Expect(err).Should(Succeed()) + Expect(prep.Close()).Should(Succeed()) + Expect(tx.Commit()).Should(Succeed()) // DB INSERT directly, not via a txn - db.Exec("INSERT INTO test_1 (counter) VALUES ($?)", i + 10000) + //_, err = db.Exec("INSERT INTO test_1 (counter) VALUES ($?)", i+10000) + //Expect(err).Should(Succeed()) fmt.Print("+") }
diff --git a/data_service.go b/data_service.go index 7a5a3d0..c606e10 100644 --- a/data_service.go +++ b/data_service.go
@@ -15,6 +15,7 @@ package apid import ( + "context" "database/sql" ) @@ -37,9 +38,24 @@ Exec(query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row - Begin() (*sql.Tx, error) - + Begin() (Tx, error) + Stats() sql.DBStats //Close() error //Stats() sql.DBStats //Driver() driver.Driver } + +type Tx interface { + Commit() error + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Rollback() error + Stmt(stmt *sql.Stmt) *sql.Stmt + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +}