Switch SQLCipher driver from mutecomm/go-sqlcipher to xeodou/go-sqlcipher
diff --git a/data/data.go b/data/data.go
--- a/data/data.go
+++ b/data/data.go
@@ -20,17 +20,15 @@
 	"fmt"
 	"github.com/30x/apid-core"
 	"github.com/30x/apid-core/api"
-	//"github.com/30x/apid-core/data/wrap"
 	"github.com/30x/apid-core/logger"
 	"github.com/Sirupsen/logrus"
-	_ "github.com/mutecomm/go-sqlcipher"
+	_ "github.com/xeodou/go-sqlcipher"
 	"os"
 	"path"
 	"runtime"
 	"strings"
 	"sync"
 	"time"
-
 )
 
 const (
@@ -41,6 +39,8 @@
 	commonDBID        = "common"
 	commonDBVersion   = "base"
 	defaultTraceLevel = "warn"
+	configSqlPrivKey  = "db_privkey"
+	dbPrivateKey      = "DBPRIVATEKEY"
 )
 
 var log, dbTraceLog apid.LogService
@@ -163,6 +163,12 @@
 
 	config = apid.Config()
 	log = apid.Log().ForModule("data")
+
+	// Allow the SQLCipher key to be supplied through the environment.
+	if privKey, ok := os.LookupEnv(dbPrivateKey); ok {
+		config.SetDefault(configSqlPrivKey, privKey)
+	}
+
 	// we don't want to trace normally
 	config.SetDefault("DATA_TRACE_LOG_LEVEL", defaultTraceLevel)
 	dbTraceLog = apid.Log().ForModule("data_trace")
@@ -262,12 +268,7 @@
 
 	log.Infof("LoadDB: %s", dataPath)
 	source := fmt.Sprintf(config.GetString(configDataSourceKey), dataPath)
-
-	// set DB name
-	dbnameWithDSN := source + fmt.Sprintf("?_pragma_key=x'%s'",
-		"123456")
-
-	db, err := sql.Open("sqlite3", dbnameWithDSN)
+	db, err := sql.Open("sqlite3", source)
 
 	if err != nil {
 		log.Errorf("error loading db: %s", err)
@@ -279,6 +280,12 @@
 		mutex: &sync.Mutex{},
 	}
 
+	if err = SetDBPrivKey(retDb); err != nil {
+		// The handle is unusable without its key; don't leak the pool.
+		db.Close()
+		return
+	}
+
 	err = db.Ping()
 	if err != nil {
 		log.Errorf("error pinging db: %s", err)
@@ -355,3 +362,30 @@
 	}()
 	return stop
 }
+
+// SetDBPrivKey applies the SQLCipher key stored under configSqlPrivKey
+// (which may be seeded from the DBPRIVATEKEY environment variable) to db.
+// It is a no-op when no key is configured.
+//
+// SQLCipher keys are per connection, while database/sql opens connections
+// on demand. When a key is set the pool is therefore capped at a single
+// connection so every statement runs on the keyed connection. NOTE: with a
+// single connection, a goroutine holding a transaction from Begin must not
+// issue non-transactional queries on the same ApidDb, or it will block.
+func SetDBPrivKey(db *ApidDb) error {
+	key := config.GetString(configSqlPrivKey)
+	if key == "" {
+		return nil
+	}
+
+	db.db.SetMaxOpenConns(1)
+
+	// PRAGMA does not accept bound parameters, so embed the key as an SQL
+	// string literal, doubling any single quotes it contains.
+	stmt := "PRAGMA key = '" + strings.Replace(key, "'", "''", -1) + "';"
+	if _, err := db.db.Exec(stmt); err != nil {
+		log.Errorf("error setting PRAGMA key: %s", err)
+		return err
+	}
+	return nil
+}