Merge pull request #13 from 30x/APIRT-4751

APIRT-4751: open SQLite in shared-cache mode and serialize write transactions
diff --git a/.travis.yml b/.travis.yml
index a902a1c..0fc0167 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,7 @@
 language: go
 
 go:
-  - 1.7.x
+  - 1.8.3
 
 before_install:
   - sudo add-apt-repository ppa:masterminds/glide -y
diff --git a/data/data.go b/data/data.go
index d5c893c..537ffe0 100644
--- a/data/data.go
+++ b/data/data.go
@@ -15,6 +15,7 @@
 package data
 
 import (
+	"context"
 	"database/sql"
 	"fmt"
 	"github.com/30x/apid-core"
@@ -38,21 +39,113 @@
 	statCollectionInterval = 10
 	commonDBID             = "common"
 	commonDBVersion        = "base"
-
-	defaultTraceLevel = "warn"
+	dbOpenMode             = "?cache=shared&mode=rwc"
+	defaultTraceLevel      = "warn"
 )
 
 var log, dbTraceLog apid.LogService
 var config apid.ConfigService
 
 type dbMapInfo struct {
-	db     *sql.DB
+	db     *ApidDb // handle wrapped and stored by dbVersionForID
 	closed chan bool
 }
 
 var dbMap = make(map[string]*dbMapInfo)
 var dbMapSync sync.RWMutex
 
+// ApidDb wraps *sql.DB; transactions opened via Begin are serialized by mutex.
+type ApidDb struct {
+	db    *sql.DB
+	mutex *sync.Mutex // held from Begin until the returned Tx commits or rolls back
+}
+
+// Ping verifies that a connection to the underlying *sql.DB is alive.
+func (d *ApidDb) Ping() error { return d.db.Ping() }
+
+// Prepare creates a prepared statement on the underlying *sql.DB.
+func (d *ApidDb) Prepare(query string) (*sql.Stmt, error) { return d.db.Prepare(query) }
+
+// Exec runs query on the underlying *sql.DB; it does not take mutex.
+func (d *ApidDb) Exec(query string, args ...interface{}) (sql.Result, error) {
+	return d.db.Exec(query, args...)
+}
+
+// Query runs query on the underlying *sql.DB; it does not take mutex.
+func (d *ApidDb) Query(query string, args ...interface{}) (*sql.Rows, error) {
+	return d.db.Query(query, args...)
+}
+
+// QueryRow runs query on the underlying *sql.DB; it does not take mutex.
+func (d *ApidDb) QueryRow(query string, args ...interface{}) *sql.Row {
+	return d.db.QueryRow(query, args...)
+}
+
+// Begin takes mutex, then starts a transaction whose Commit/Rollback releases it.
+func (d *ApidDb) Begin() (apid.Tx, error) {
+	d.mutex.Lock()
+	tx, err := d.db.Begin()
+	if err != nil {
+		d.mutex.Unlock() // no Tx was created, so nothing else would ever unlock
+		return nil, err
+	}
+	return &Tx{tx: tx, mutex: d.mutex}, nil
+}
+
+// Stats returns connection-pool statistics of the underlying *sql.DB.
+func (d *ApidDb) Stats() sql.DBStats { return d.db.Stats() }
+
+// Tx wraps *sql.Tx and holds the ApidDb mutex until its first Commit or
+// Rollback; every other method forwards unchanged to the wrapped *sql.Tx.
+type Tx struct {
+	tx     *sql.Tx
+	mutex  *sync.Mutex
+	closed bool // set once this Tx has unlocked mutex
+}
+// release unlocks mutex on its first call only; later calls are no-ops.
+func (tx *Tx) release() {
+	if !tx.closed {
+		tx.closed = true
+		tx.mutex.Unlock()
+	}
+}
+// Commit commits the wrapped transaction, then releases mutex if still held.
+func (tx *Tx) Commit() error {
+	defer tx.release()
+	return tx.tx.Commit()
+}
+// Rollback aborts the wrapped transaction, then releases mutex if still held.
+func (tx *Tx) Rollback() error {
+	defer tx.release()
+	return tx.tx.Rollback()
+}
+func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
+	return tx.tx.Exec(query, args...)
+}
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+	return tx.tx.ExecContext(ctx, query, args...)
+}
+func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { return tx.tx.Prepare(query) }
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
+	return tx.tx.PrepareContext(ctx, query)
+}
+func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
+	return tx.tx.Query(query, args...)
+}
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
+	return tx.tx.QueryContext(ctx, query, args...)
+}
+func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
+	return tx.tx.QueryRow(query, args...)
+}
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+	return tx.tx.QueryRowContext(ctx, query, args...)
+}
+func (tx *Tx) Stmt(stmt *sql.Stmt) *sql.Stmt { return tx.tx.Stmt(stmt) }
+func (tx *Tx) StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt {
+	return tx.tx.StmtContext(ctx, stmt)
+}
+
 func CreateDataService() apid.DataService {
 	config = apid.Config()
 	log = apid.Log().ForModule("data")
@@ -132,7 +225,7 @@
 	return
 }
 
-func (d *dataService) dbVersionForID(id, version string) (db *sql.DB, err error) {
+func (d *dataService) dbVersionForID(id, version string) (retDb *ApidDb, err error) {
 
 	var stoplogchan chan bool
 	versionedID := VersionedDBID(id, version)
@@ -155,7 +248,7 @@
 
 	log.Infof("LoadDB: %s", dataPath)
 	source := fmt.Sprintf(config.GetString(configDataSourceKey), dataPath)
-
+	source += dbOpenMode
 	wrappedDriverName := "dd:" + config.GetString(configDataDriverKey)
 	driver := wrap.NewDriver(&sqlite3.SQLiteDriver{}, dbTraceLog)
 	func() {
@@ -166,12 +259,18 @@
 		sql.Register(wrappedDriverName, driver)
 	}()
 
-	db, err = sql.Open(wrappedDriverName, source)
+	db, err := sql.Open(wrappedDriverName, source)
+
 	if err != nil {
 		log.Errorf("error loading db: %s", err)
 		return
 	}
 
+	retDb = &ApidDb{
+		db:    db,
+		mutex: &sync.Mutex{},
+	}
+
 	err = db.Ping()
 	if err != nil {
 		log.Errorf("error pinging db: %s", err)
@@ -199,14 +298,17 @@
 	db.SetMaxOpenConns(config.GetInt(api.ConfigDBMaxConns))
 	db.SetMaxIdleConns(config.GetInt(api.ConfigDBIdleConns))
 	db.SetConnMaxLifetime(time.Duration(config.GetInt(api.ConfigDBConnsTimeout)) * time.Second)
-	dbInfo := dbMapInfo{db: db, closed: stoplogchan}
+	dbInfo := dbMapInfo{
+		db:     retDb,
+		closed: stoplogchan,
+	}
 	dbMap[versionedID] = &dbInfo
 	return
 }
 
 func Delete(versionedID string) interface{} {
-	return func(db *sql.DB) {
-		err := db.Close()
+	return func(db *ApidDb) {
+		err := db.db.Close()
 		if err != nil {
 			log.Errorf("error closing DB: %v", err)
 		}
diff --git a/data/data_suite_test.go b/data/data_suite_test.go
index 11a4d84..1953ad0 100644
--- a/data/data_suite_test.go
+++ b/data/data_suite_test.go
@@ -18,11 +18,11 @@
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
 
-	"testing"
 	"github.com/30x/apid-core"
 	"github.com/30x/apid-core/factory"
 	"io/ioutil"
 	"os"
+	"testing"
 )
 
 var tmpDir string
diff --git a/data/data_test.go b/data/data_test.go
index 26dd52e..71dfb26 100644
--- a/data/data_test.go
+++ b/data/data_test.go
@@ -17,25 +17,23 @@
 import (
 	"fmt"
 	"github.com/30x/apid-core"
+	"github.com/30x/apid-core/data"
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
-	"log"
 	"math/rand"
 	"strconv"
 	"time"
-	"github.com/30x/apid-core/data"
-	"database/sql"
 )
 
 const (
-	count    = 2000
+	count    = 5000 // rows setup preloads into test_2; also reader/writer goroutines per spec
 	setupSql = `
 		CREATE TABLE test_1 (id INTEGER PRIMARY KEY, counter TEXT);
 		CREATE TABLE test_2 (id INTEGER PRIMARY KEY, counter TEXT);`
 )
 
 var (
-	r      *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
+	r *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
 )
 
 var _ = Describe("Data Service", func() {
@@ -77,90 +75,108 @@
 		Expect(err).NotTo(HaveOccurred())
 		setup(db)
 		id := data.VersionedDBID("release", "version")
-		sqlDB := db.(*sql.DB)
-		Expect(sqlDB.Stats().OpenConnections).To(Equal(1))
+		Expect(db.Stats().OpenConnections).To(Equal(1))
 		// run finalizer
-		data.Delete(id).(func(db *sql.DB))(sqlDB)
-		Expect(sqlDB.Stats().OpenConnections).To(Equal(0))
+		data.Delete(id).(func(db *data.ApidDb))(db.(*data.ApidDb))
+		Expect(db.Stats().OpenConnections).To(Equal(0))
 		Expect(data.DBPath(id)).ShouldNot(BeAnExistingFile())
 	})
 
-	It("should handle multi-threaded access", func(done Done) {
+	// Both concurrency specs take done: Ginkgo v1 enforces the trailing timeout
+	// only for async specs, so a deadlock fails the spec instead of hanging.
+	It("should handle concurrent read & serialized write", func(done Done) {
 		db, err := apid.Data().DBForID("test")
 		Expect(err).NotTo(HaveOccurred())
 		setup(db)
-		finished := make(chan struct{})
+		finished := make(chan bool, count+1) // 1 signal from the writer + 1 per reader
 
 		go func() {
 			for i := 0; i < count; i++ {
 				write(db, i)
 			}
-			finished <- struct{}{}
+			finished <- true
 		}()
 
-		go func() {
-			for i := 0; i < count; i++ {
-				go func() {
-					read(db)
-					finished <- struct{}{}
-				}()
-				time.Sleep(time.Duration(r.Intn(2)) * time.Millisecond)
-			}
-		}()
+		for i := 0; i < count; i++ {
+			go func() {
+				read(db)
+				finished <- true
+			}()
+		}
 
 		for i := 0; i < count+1; i++ {
 			<-finished
 		}
+		close(done)
+	}, 10)
 
-		close(done)
+	It("should handle concurrent write", func(done Done) {
+		db, err := apid.Data().DBForID("test_write")
+		Expect(err).NotTo(HaveOccurred())
+		setup(db)
+		finished := make(chan bool, count) // 1 signal per writer goroutine
+
+		for i := 0; i < count; i++ {
+			// pass i by value: the closure would otherwise share the loop variable
+			go func(i int) {
+				write(db, i)
+				finished <- true
+			}(i)
+		}
+
+		for i := 0; i < count; i++ {
+			<-finished
+		}
+		close(done)
 	}, 10)
 })
 
+// setup creates tables test_1 and test_2, then inserts count rows into test_2
+// within a single transaction.
 func setup(db apid.DB) {
 	_, err := db.Exec(setupSql)
-	if err != nil {
-		log.Fatal(err)
-	}
+	Expect(err).Should(Succeed())
 	tx, err := db.Begin()
-	if err != nil {
-		log.Fatal(err)
-	}
+	Expect(err).Should(Succeed())
+	// Roll back if an Expect below fails (for data.ApidDb this also releases its
+	// write mutex); after a successful Commit this only returns an ignored error.
+	defer tx.Rollback()
 	for i := 0; i < count; i++ {
 		_, err := tx.Exec("INSERT INTO test_2 (counter) VALUES (?);", strconv.Itoa(i))
-		if err != nil {
-			log.Fatalf("filling up test_2 table failed. Exec error=%s", err)
-		}
+		Expect(err).Should(Succeed())
 	}
-	tx.Commit()
+	Expect(tx.Commit()).Should(Succeed())
 }
 
+// read runs a small SELECT on test_2 and prints "." on success. It is called
+// from spec goroutines, so it defers GinkgoRecover to report failed Expects.
 func read(db apid.DB) {
+	defer GinkgoRecover()
 	var counter string
 	rows, err := db.Query(`SELECT counter FROM test_2 LIMIT 5`)
-	if err != nil {
-		log.Fatalf("test_2 select failed. Exec error=%s", err)
-	} else {
-		defer rows.Close()
-		for rows.Next() {
-			rows.Scan(&counter)
-			//fmt.Print("*")
-		}
+	Expect(err).Should(Succeed())
+	defer rows.Close()
+	for rows.Next() {
+		Expect(rows.Scan(&counter)).Should(Succeed())
 	}
+	Expect(rows.Err()).Should(Succeed()) // surfaces errors that ended iteration early
 	fmt.Print(".")
 }
 
+// write inserts i (as text) into test_1 in its own transaction and prints "+"
+// on success. It is called from spec goroutines, so it defers GinkgoRecover.
 func write(db apid.DB, i int) {
+	defer GinkgoRecover()
+	// DB INSERT as a txn
 	tx, err := db.Begin()
+	Expect(err).Should(Succeed())
+	// Roll back if an Expect below fails; after Commit it only returns an ignored error.
 	defer tx.Rollback()
-	if err != nil {
-		log.Fatalf("Write failed. Exec error=%s", err)
-	}
 	prep, err := tx.Prepare("INSERT INTO test_1 (counter) VALUES ($1);")
-	_, err = tx.Stmt(prep).Exec(strconv.Itoa(i))
-	if err != nil {
-		log.Fatalf("Write failed. Exec error=%s", err)
-	}
-	prep.Close()
-	tx.Commit()
+	Expect(err).Should(Succeed())
+	_, err = prep.Exec(strconv.Itoa(i))
+	Expect(err).Should(Succeed())
+	Expect(prep.Close()).Should(Succeed())
+	Expect(tx.Commit()).Should(Succeed())
 	fmt.Print("+")
 }
diff --git a/data_service.go b/data_service.go
index 7a5a3d0..c606e10 100644
--- a/data_service.go
+++ b/data_service.go
@@ -15,6 +15,7 @@
 package apid
 
 import (
+	"context"
 	"database/sql"
 )
 
@@ -37,9 +38,27 @@
 	Exec(query string, args ...interface{}) (sql.Result, error)
 	Query(query string, args ...interface{}) (*sql.Rows, error)
 	QueryRow(query string, args ...interface{}) *sql.Row
-	Begin() (*sql.Tx, error)
-
+	Begin() (Tx, error)
+	Stats() sql.DBStats
 	//Close() error
-	//Stats() sql.DBStats
 	//Driver() driver.Driver
 }
+
+// Tx is the transaction handle returned by DB.Begin. Its method set mirrors
+// *sql.Tx (Go 1.8+), so a *sql.Tx satisfies it. Always finish a Tx with
+// Commit or Rollback: the data package's implementation holds a per-database
+// write lock from Begin until one of them is called.
+type Tx interface {
+	Commit() error
+	Exec(query string, args ...interface{}) (sql.Result, error)
+	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
+	Prepare(query string) (*sql.Stmt, error)
+	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
+	Query(query string, args ...interface{}) (*sql.Rows, error)
+	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
+	QueryRow(query string, args ...interface{}) *sql.Row
+	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
+	Rollback() error
+	Stmt(stmt *sql.Stmt) *sql.Stmt
+	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
+}