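Switch the SQLCipher driver to github.com/xeodou/go-sqlcipher and key the database from the db_privkey config value (defaulting to the DBPRIVATEKEY environment variable)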
diff --git a/data/data.go b/data/data.go
index a5c6ad8..36c85c3 100644
--- a/data/data.go
+++ b/data/data.go
@@ -20,17 +20,15 @@
"fmt"
"github.com/30x/apid-core"
"github.com/30x/apid-core/api"
- //"github.com/30x/apid-core/data/wrap"
"github.com/30x/apid-core/logger"
"github.com/Sirupsen/logrus"
- _ "github.com/mutecomm/go-sqlcipher"
+ _ "github.com/xeodou/go-sqlcipher"
"os"
"path"
"runtime"
"strings"
"sync"
"time"
-
)
const (
@@ -41,6 +39,8 @@
commonDBID = "common"
commonDBVersion = "base"
defaultTraceLevel = "warn"
+ configSqlPrivKey = "db_privkey" // config key whose value SetDBPrivKey sends as the SQLCipher PRAGMA key
+ dbPrivateKey = "DBPRIVATEKEY" // env var whose value, if set, becomes the default for configSqlPrivKey
)
var log, dbTraceLog apid.LogService
@@ -76,22 +76,27 @@
}
+// Prepare creates a prepared statement for query on the underlying *sql.DB.
func (d *ApidDb) Prepare(query string) (*sql.Stmt, error) {
return d.db.Prepare(query)
}
+// Exec executes query with args on the underlying *sql.DB.
func (d *ApidDb) Exec(query string, args ...interface{}) (sql.Result, error) {
return d.db.Exec(query, args...)
}
+// Query runs query with args on the underlying *sql.DB and returns its rows.
func (d *ApidDb) Query(query string, args ...interface{}) (*sql.Rows, error) {
return d.db.Query(query, args...)
}
+// QueryRow runs query with args on the underlying *sql.DB and returns a *sql.Row.
func (d *ApidDb) QueryRow(query string, args ...interface{}) *sql.Row {
return d.db.QueryRow(query, args...)
}
+// Begin locks d.mutex and then starts a transaction on the underlying *sql.DB.
func (d *ApidDb) Begin() (apid.Tx, error) {
d.mutex.Lock()
tx, err := d.db.Begin()
if err != nil {
@@ -163,6 +168,11 @@
config = apid.Config()
log = apid.Log().ForModule("data")
+ // If DBPRIVATEKEY is set in the environment, use it as the db_privkey default.
+ if privKey, ok := os.LookupEnv(dbPrivateKey); ok {
+ config.SetDefault(configSqlPrivKey, privKey)
+ }
+
// we don't want to trace normally
config.SetDefault("DATA_TRACE_LOG_LEVEL", defaultTraceLevel)
dbTraceLog = apid.Log().ForModule("data_trace")
@@ -262,12 +272,7 @@
log.Infof("LoadDB: %s", dataPath)
source := fmt.Sprintf(config.GetString(configDataSourceKey), dataPath)
-
- // set DB name
- dbnameWithDSN := source + fmt.Sprintf("?_pragma_key=x'%s'",
- "123456")
-
- db, err := sql.Open("sqlite3", dbnameWithDSN)
+ db, err := sql.Open("sqlite3", source)
if err != nil {
log.Errorf("error loading db: %s", err)
@@ -279,6 +284,10 @@
mutex: &sync.Mutex{},
}
+ // Apply the configured SQLCipher key (if any) before pinging the database.
+ if err = SetDBPrivKey(retDb); err != nil {
+ return
+ }
err = db.Ping()
if err != nil {
log.Errorf("error pinging db: %s", err)
@@ -355,3 +364,27 @@
}()
return stop
}
+
+// SetDBPrivKey keys db with the SQLCipher passphrase stored under the
+// configSqlPrivKey ("db_privkey") config key. It does nothing when that value
+// is empty; otherwise any error from the PRAGMA is logged and returned.
+//
+// NOTE(review): PRAGMA key is a per-connection setting and this statement
+// runs on whichever pooled connection database/sql hands out, so connections
+// the pool opens later are not keyed by this call. Consider keying every new
+// connection (e.g. from a driver connect hook) instead.
+func SetDBPrivKey(db *ApidDb) error {
+ key := config.GetString(configSqlPrivKey)
+ if key == "" {
+ return nil
+ }
+ // SQLite escapes a single quote inside a string literal by doubling it;
+ // without this, a key containing ' would end the literal early.
+ stmt := "PRAGMA key = '" + strings.Replace(key, "'", "''", -1) + "';"
+ if _, err := db.Exec(stmt); err != nil {
+ log.Errorf("error setting PRAGMA key: %s", err)
+ return err
+ }
+ return nil
+}
+