Merge pull request #13 from 30x/APIRT-4751
APIRT-4751: open SQLite databases in shared-cache mode and serialize transactions
diff --git a/.travis.yml b/.travis.yml
index a902a1c..0fc0167 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,7 @@
language: go
go:
- - 1.7.x
+ - 1.8.3
before_install:
- sudo add-apt-repository ppa:masterminds/glide -y
diff --git a/data/data.go b/data/data.go
index d5c893c..537ffe0 100644
--- a/data/data.go
+++ b/data/data.go
@@ -15,6 +15,7 @@
package data
import (
+ "context"
"database/sql"
"fmt"
"github.com/30x/apid-core"
@@ -38,21 +39,113 @@
statCollectionInterval = 10
commonDBID = "common"
commonDBVersion = "base"
-
- defaultTraceLevel = "warn"
+ dbOpenMode = "?cache=shared&mode=rwc"
+ defaultTraceLevel = "warn"
)
var log, dbTraceLog apid.LogService
var config apid.ConfigService
+// dbMapInfo is the dbMap entry for one versioned database ID: the open
+// handle plus the stoplogchan channel created alongside it (presumably used
+// to stop periodic stats logging — confirm in dbVersionForID).
type dbMapInfo struct {
- db *sql.DB
+ db *ApidDb
closed chan bool
}
var dbMap = make(map[string]*dbMapInfo)
var dbMapSync sync.RWMutex
+// ApidDb wraps a *sql.DB and serializes transactions: Begin locks mutex and
+// the returned Tx unlocks it on its first Commit or Rollback. Statements run
+// directly through Exec/Query/QueryRow/Prepare do not take the mutex.
+type ApidDb struct {
+ db *sql.DB
+ mutex *sync.Mutex
+}
+
+// Ping verifies a connection to the underlying database is alive.
+func (d *ApidDb) Ping() error {
+ return d.db.Ping()
+}
+
+// Prepare creates a prepared statement on the underlying database.
+func (d *ApidDb) Prepare(query string) (*sql.Stmt, error) {
+ return d.db.Prepare(query)
+}
+
+// Exec executes a statement that returns no rows.
+func (d *ApidDb) Exec(query string, args ...interface{}) (sql.Result, error) {
+ return d.db.Exec(query, args...)
+}
+
+// Query executes a query that returns rows.
+func (d *ApidDb) Query(query string, args ...interface{}) (*sql.Rows, error) {
+ return d.db.Query(query, args...)
+}
+
+// QueryRow executes a query expected to return at most one row.
+func (d *ApidDb) QueryRow(query string, args ...interface{}) *sql.Row {
+ return d.db.QueryRow(query, args...)
+}
+
+// Begin locks the transaction mutex and starts a transaction. The mutex stays
+// locked until the returned Tx is committed or rolled back. If the
+// transaction cannot be started the mutex is unlocked before the error is
+// returned; otherwise every later Begin on this DB would block forever.
+func (d *ApidDb) Begin() (apid.Tx, error) {
+ d.mutex.Lock()
+ tx, err := d.db.Begin()
+ if err != nil {
+ d.mutex.Unlock()
+ return nil, err
+ }
+ return &Tx{
+ tx: tx,
+ mutex: d.mutex,
+ }, nil
+}
+
+// Stats returns the connection-pool statistics of the underlying database.
+func (d *ApidDb) Stats() sql.DBStats {
+ return d.db.Stats()
+}
+
+// Tx wraps a *sql.Tx started by ApidDb.Begin. It owns the ApidDb mutex that
+// Begin locked and unlocks it exactly once, after the first Commit or
+// Rollback has returned from the underlying transaction.
+type Tx struct {
+ tx *sql.Tx
+ mutex *sync.Mutex
+ closed bool
+}
+
+// release unlocks mutex on its first call only; later calls do nothing, so
+// a deferred Rollback after a successful Commit never unlocks twice.
+func (t *Tx) release() {
+ if t.closed {
+ return
+ }
+ t.closed = true
+ t.mutex.Unlock()
+}
+
+// Commit commits the underlying transaction and then releases the mutex.
+func (t *Tx) Commit() error {
+ defer t.release()
+ return t.tx.Commit()
+}
+
+// Rollback aborts the underlying transaction and then releases the mutex.
+// After Commit it only returns the error from (*sql.Tx).Rollback
+// (sql.ErrTxDone) and leaves the mutex alone.
+func (t *Tx) Rollback() error {
+ defer t.release()
+ return t.tx.Rollback()
+}
+
+// Exec executes a statement that returns no rows within the transaction.
+func (t *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
+ return t.tx.Exec(query, args...)
+}
+
+// ExecContext is Exec with a context.
+func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ return t.tx.ExecContext(ctx, query, args...)
+}
+
+// Query executes a query that returns rows within the transaction.
+func (t *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
+ return t.tx.Query(query, args...)
+}
+
+// QueryContext is Query with a context.
+func (t *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
+ return t.tx.QueryContext(ctx, query, args...)
+}
+
+// QueryRow executes a query expected to return at most one row.
+func (t *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
+ return t.tx.QueryRow(query, args...)
+}
+
+// QueryRowContext is QueryRow with a context.
+func (t *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ return t.tx.QueryRowContext(ctx, query, args...)
+}
+
+// Prepare creates a prepared statement bound to the transaction.
+func (t *Tx) Prepare(query string) (*sql.Stmt, error) {
+ return t.tx.Prepare(query)
+}
+
+// PrepareContext is Prepare with a context.
+func (t *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
+ return t.tx.PrepareContext(ctx, query)
+}
+
+// Stmt returns a transaction-specific copy of an existing statement.
+func (t *Tx) Stmt(stmt *sql.Stmt) *sql.Stmt {
+ return t.tx.Stmt(stmt)
+}
+
+// StmtContext is Stmt with a context.
+func (t *Tx) StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt {
+ return t.tx.StmtContext(ctx, stmt)
+}
+
func CreateDataService() apid.DataService {
config = apid.Config()
log = apid.Log().ForModule("data")
@@ -132,7 +225,7 @@
return
}
-func (d *dataService) dbVersionForID(id, version string) (db *sql.DB, err error) {
+func (d *dataService) dbVersionForID(id, version string) (retDb *ApidDb, err error) {
var stoplogchan chan bool
versionedID := VersionedDBID(id, version)
@@ -155,7 +248,7 @@
log.Infof("LoadDB: %s", dataPath)
source := fmt.Sprintf(config.GetString(configDataSourceKey), dataPath)
-
+ source += dbOpenMode
wrappedDriverName := "dd:" + config.GetString(configDataDriverKey)
driver := wrap.NewDriver(&sqlite3.SQLiteDriver{}, dbTraceLog)
func() {
@@ -166,12 +259,18 @@
sql.Register(wrappedDriverName, driver)
}()
- db, err = sql.Open(wrappedDriverName, source)
+ db, err := sql.Open(wrappedDriverName, source)
+
if err != nil {
log.Errorf("error loading db: %s", err)
return
}
+ retDb = &ApidDb{
+ db: db,
+ mutex: &sync.Mutex{},
+ }
+
err = db.Ping()
if err != nil {
log.Errorf("error pinging db: %s", err)
@@ -199,14 +298,17 @@
db.SetMaxOpenConns(config.GetInt(api.ConfigDBMaxConns))
db.SetMaxIdleConns(config.GetInt(api.ConfigDBIdleConns))
db.SetConnMaxLifetime(time.Duration(config.GetInt(api.ConfigDBConnsTimeout)) * time.Second)
- dbInfo := dbMapInfo{db: db, closed: stoplogchan}
+ dbInfo := dbMapInfo{
+ db: retDb,
+ closed: stoplogchan,
+ }
dbMap[versionedID] = &dbInfo
return
}
func Delete(versionedID string) interface{} {
- return func(db *sql.DB) {
- err := db.Close()
+ return func(db *ApidDb) {
+ err := db.db.Close()
if err != nil {
log.Errorf("error closing DB: %v", err)
}
diff --git a/data/data_suite_test.go b/data/data_suite_test.go
index 11a4d84..1953ad0 100644
--- a/data/data_suite_test.go
+++ b/data/data_suite_test.go
@@ -18,11 +18,11 @@
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
- "testing"
"github.com/30x/apid-core"
"github.com/30x/apid-core/factory"
"io/ioutil"
"os"
+ "testing"
)
var tmpDir string
diff --git a/data/data_test.go b/data/data_test.go
index 26dd52e..71dfb26 100644
--- a/data/data_test.go
+++ b/data/data_test.go
@@ -17,25 +17,23 @@
import (
"fmt"
"github.com/30x/apid-core"
+ "github.com/30x/apid-core/data"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
- "log"
"math/rand"
"strconv"
"time"
- "github.com/30x/apid-core/data"
- "database/sql"
)
const (
- count = 2000
+ count = 5000
setupSql = `
CREATE TABLE test_1 (id INTEGER PRIMARY KEY, counter TEXT);
CREATE TABLE test_2 (id INTEGER PRIMARY KEY, counter TEXT);`
)
var (
- r *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
+ r *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
var _ = Describe("Data Service", func() {
@@ -77,90 +75,95 @@
Expect(err).NotTo(HaveOccurred())
setup(db)
id := data.VersionedDBID("release", "version")
- sqlDB := db.(*sql.DB)
- Expect(sqlDB.Stats().OpenConnections).To(Equal(1))
+ Expect(db.Stats().OpenConnections).To(Equal(1))
// run finalizer
- data.Delete(id).(func(db *sql.DB))(sqlDB)
- Expect(sqlDB.Stats().OpenConnections).To(Equal(0))
+ data.Delete(id).(func(db *data.ApidDb))(db.(*data.ApidDb))
+ Expect(db.Stats().OpenConnections).To(Equal(0))
Expect(data.DBPath(id)).ShouldNot(BeAnExistingFile())
})
- It("should handle multi-threaded access", func(done Done) {
+ // Taking done makes the spec asynchronous so the trailing 10s timeout is
+ // enforced; Ginkgo ignores it for a plain func() body, and a deadlock on
+ // the transaction mutex would hang the suite instead of failing the spec.
+ It("should handle concurrent read & serialized write", func(done Done) {
+ // One goroutine runs count sequential write transactions while count
+ // reader goroutines query concurrently; finished is buffered for all
+ // count+1 completion signals.
db, err := apid.Data().DBForID("test")
Expect(err).NotTo(HaveOccurred())
setup(db)
- finished := make(chan struct{})
+ finished := make(chan bool, count+1)
go func() {
for i := 0; i < count; i++ {
write(db, i)
}
- finished <- struct{}{}
+ finished <- true
}()
- go func() {
- for i := 0; i < count; i++ {
- go func() {
- read(db)
- finished <- struct{}{}
- }()
- time.Sleep(time.Duration(r.Intn(2)) * time.Millisecond)
- }
- }()
+ for i := 0; i < count; i++ {
+ go func() {
+ read(db)
+ finished <- true
+ }()
+ }
for i := 0; i < count+1; i++ {
<-finished
}
+ close(done)
+ }, 10)
- close(done)
+ // Taking done makes the spec asynchronous so the trailing 10s timeout is
+ // enforced instead of being ignored.
+ It("should handle concurrent write", func(done Done) {
+ db, err := apid.Data().DBForID("test_write")
+ Expect(err).NotTo(HaveOccurred())
+ setup(db)
+ finished := make(chan bool, count)
+
+ for i := 0; i < count; i++ {
+ // Pass i by value: before Go 1.22 all iterations share one i, so
+ // capturing it in the closure races with the loop increment and the
+ // writers would not see their own index.
+ go func(n int) {
+ write(db, n)
+ finished <- true
+ }(i)
+ }
+
+ // Wait for every writer goroutine to signal completion.
+ for i := 0; i < count; i++ {
+ <-finished
+ }
+ close(done)
}, 10)
})
+// setup creates the test tables from setupSql and fills test_2 with count
+// rows inside a single transaction, failing the current spec on any error.
func setup(db apid.DB) {
_, err := db.Exec(setupSql)
- if err != nil {
- log.Fatal(err)
- }
+ Expect(err).NotTo(HaveOccurred())
tx, err := db.Begin()
- if err != nil {
- log.Fatal(err)
- }
+ Expect(err).NotTo(HaveOccurred())
+ // Ends the transaction if an Expect below fails (with data.ApidDb this also
+ // unlocks its transaction mutex); after a successful Commit it does nothing
+ // useful and its sql.ErrTxDone result is deliberately ignored.
+ defer tx.Rollback()
for i := 0; i < count; i++ {
_, err := tx.Exec("INSERT INTO test_2 (counter) VALUES (?);", strconv.Itoa(i))
- if err != nil {
- log.Fatalf("filling up test_2 table failed. Exec error=%s", err)
- }
+ Expect(err).NotTo(HaveOccurred())
}
- tx.Commit()
+ Expect(tx.Commit()).Should(Succeed())
}
+// read queries up to five rows from test_2 and fails the current spec on any
+// query, scan or iteration error. The specs call it on its own goroutine,
+// hence the deferred GinkgoRecover.
func read(db apid.DB) {
+ defer GinkgoRecover()
var counter string
rows, err := db.Query(`SELECT counter FROM test_2 LIMIT 5`)
- if err != nil {
- log.Fatalf("test_2 select failed. Exec error=%s", err)
- } else {
- defer rows.Close()
- for rows.Next() {
- rows.Scan(&counter)
- //fmt.Print("*")
- }
+ Expect(err).NotTo(HaveOccurred())
+ defer rows.Close()
+ for rows.Next() {
+ Expect(rows.Scan(&counter)).To(Succeed())
}
+ // rows.Next returns false both at the end and on error; Err tells them apart.
+ Expect(rows.Err()).NotTo(HaveOccurred())
fmt.Print(".")
}
+// write inserts strconv.Itoa(i) into test_1 through a prepared statement in
+// its own transaction, failing the current spec on any error. The specs call
+// it from goroutines, hence the deferred GinkgoRecover.
func write(db apid.DB, i int) {
+ defer GinkgoRecover()
+ // DB INSERT as a txn
tx, err := db.Begin()
+ Expect(err).NotTo(HaveOccurred())
+ // Ends the transaction if an Expect below fails; after Commit it does nothing.
defer tx.Rollback()
- if err != nil {
- log.Fatalf("Write failed. Exec error=%s", err)
- }
prep, err := tx.Prepare("INSERT INTO test_1 (counter) VALUES ($1);")
- _, err = tx.Stmt(prep).Exec(strconv.Itoa(i))
- if err != nil {
- log.Fatalf("Write failed. Exec error=%s", err)
- }
- prep.Close()
- tx.Commit()
+ Expect(err).NotTo(HaveOccurred())
+ _, err = prep.Exec(strconv.Itoa(i))
+ // Close the statement before asserting so it is not leaked when Exec fails.
+ closeErr := prep.Close()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(closeErr).NotTo(HaveOccurred())
+ Expect(tx.Commit()).Should(Succeed())
fmt.Print("+")
}
diff --git a/data_service.go b/data_service.go
index 7a5a3d0..c606e10 100644
--- a/data_service.go
+++ b/data_service.go
@@ -15,6 +15,7 @@
package apid
import (
+ "context"
"database/sql"
)
@@ -37,9 +38,24 @@
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
- Begin() (*sql.Tx, error)
-
+ Begin() (Tx, error)
+ Stats() sql.DBStats
//Close() error
//Stats() sql.DBStats
//Driver() driver.Driver
}
+
+// Tx is the transaction handle returned by DB.Begin. Each method has the same
+// signature as the corresponding *sql.Tx method, so an implementation can
+// wrap a *sql.Tx and add behaviour around Commit and Rollback.
+type Tx interface {
+ // Ending the transaction.
+ Commit() error
+ Rollback() error
+
+ // Executing statements, with and without a context.
+ Exec(query string, args ...interface{}) (sql.Result, error)
+ ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
+ Query(query string, args ...interface{}) (*sql.Rows, error)
+ QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
+ QueryRow(query string, args ...interface{}) *sql.Row
+ QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
+
+ // Preparing statements and binding existing ones to this transaction.
+ Prepare(query string) (*sql.Stmt, error)
+ PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
+ Stmt(stmt *sql.Stmt) *sql.Stmt
+ StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
+}