b/68272561 Add generic "sql.rows->[]struct" mapping to apid-core (#27)
* [ISSUE-68272561] add StructsFromRows
* [ISSUE-68272561] support sql.Null***
* [ISSUE-68272561] rename test table
* [ISSUE-68272561] add QueryStructs to db and tx
diff --git a/data/data.go b/data/data.go
index 744ba6f..0ae6424 100644
--- a/data/data.go
+++ b/data/data.go
@@ -18,14 +18,15 @@
"context"
"database/sql"
"fmt"
+ "github.com/Sirupsen/logrus"
"github.com/apid/apid-core"
"github.com/apid/apid-core/api"
"github.com/apid/apid-core/data/wrap"
"github.com/apid/apid-core/logger"
- "github.com/Sirupsen/logrus"
"github.com/mattn/go-sqlite3"
"os"
"path"
+ "reflect"
"runtime"
"strings"
"sync"
@@ -52,6 +53,7 @@
var dbMap = make(map[string]*dbMapInfo)
var dbMapSync sync.RWMutex
+var tagFieldMapper = make(map[reflect.Type]map[string]string)
type ApidDb struct {
db *sql.DB
@@ -90,6 +92,16 @@
return d.db.QueryRow(query, args...)
}
+func (d *ApidDb) QueryStructs(dest interface{}, query string, args ...interface{}) error {
+ rows, err := d.db.Query(query, args...)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+ err = StructsFromRows(dest, rows)
+ return err
+}
+
func (d *ApidDb) Begin() (apid.Tx, error) {
d.mutex.Lock()
tx, err := d.db.Begin()
@@ -158,6 +170,16 @@
return tx.tx.StmtContext(ctx, stmt)
}
+func (tx *Tx) QueryStructs(dest interface{}, query string, args ...interface{}) error {
+ rows, err := tx.tx.Query(query, args...)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+ err = StructsFromRows(dest, rows)
+ return err
+}
+
func CreateDataService() apid.DataService {
config = apid.Config()
log = apid.Log().ForModule("data")
@@ -358,3 +380,128 @@
}()
return stop
}
+
+// StructsFromRows fill the dest slice with the values of according rows.
+// Each row is marshaled into a struct. The "db" tag in the struct is used for field mapping.
+// It will take care of null value. Supported type mappings from Sqlite3 to Go are:
+// text->string; integer->int/int64/sql.NullInt64; float->float/float64/sql.NullFloat64;
+// blob->[]byte/string/sql.NullString
+func StructsFromRows(dest interface{}, rows *sql.Rows) error {
+ t := reflect.TypeOf(dest)
+ if t == nil {
+ return nil
+ }
+ // type of the struct
+ t = t.Elem().Elem()
+ //build mapper if not existent
+ m, ok := tagFieldMapper[t]
+ if !ok {
+ m = make(map[string]string)
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ if key := f.Tag.Get("db"); key != "" {
+ m[key] = f.Name
+ }
+
+ }
+ tagFieldMapper[t] = m
+ }
+
+ colNames, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+ colTypes, err := rows.ColumnTypes()
+ if err != nil {
+ return err
+ }
+
+ cols := make([]interface{}, len(colNames))
+ slice := reflect.New(reflect.SliceOf(t)).Elem()
+
+ for i := range cols {
+ switch colTypes[i].DatabaseTypeName() {
+ case "null":
+ cols[i] = new(sql.NullString)
+ case "text":
+ cols[i] = new(sql.NullString)
+ case "integer":
+ cols[i] = new(sql.NullInt64)
+ case "float":
+ cols[i] = new(sql.NullFloat64)
+ case "blob":
+ cols[i] = new([]byte)
+ default:
+ return fmt.Errorf("unsupprted column type: %s", colTypes[i].DatabaseTypeName())
+ }
+ }
+ for rows.Next() {
+ v := reflect.New(t).Elem()
+ err := rows.Scan(cols...)
+ if err != nil {
+ return err
+ }
+ for i := range cols {
+ switch c := cols[i].(type) {
+ case *sql.NullString:
+ if f := v.FieldByName(m[colNames[i]]); f.IsValid() {
+ if reflect.TypeOf(*c).AssignableTo(f.Type()) {
+ f.Set(reflect.ValueOf(c).Elem())
+ } else if reflect.TypeOf("").AssignableTo(f.Type()) {
+ if c.Valid {
+ f.SetString(c.String)
+ }
+ } else {
+ return fmt.Errorf("cannot convert column type %s to field type %s",
+ colTypes[i].DatabaseTypeName(), f.Type().String())
+ }
+
+ }
+ case *sql.NullInt64:
+ if f := v.FieldByName(m[colNames[i]]); f.IsValid() {
+ if reflect.TypeOf(*c).AssignableTo(f.Type()) {
+ f.Set(reflect.ValueOf(c).Elem())
+ } else if reflect.TypeOf(int64(0)).ConvertibleTo(f.Type()) {
+ if c.Valid {
+ f.SetInt(c.Int64)
+ }
+ } else {
+ return fmt.Errorf("cannot convert column type %s to field type %s",
+ colTypes[i].DatabaseTypeName(), f.Type().String())
+ }
+
+ }
+ case *sql.NullFloat64:
+ if f := v.FieldByName(m[colNames[i]]); f.IsValid() {
+ if reflect.TypeOf(*c).AssignableTo(f.Type()) {
+ f.Set(reflect.ValueOf(c).Elem())
+ } else if reflect.TypeOf(float64(0)).ConvertibleTo(f.Type()) {
+ if c.Valid {
+ f.SetFloat(c.Float64)
+ }
+ } else {
+ return fmt.Errorf("cannot convert column type %s to field type %s",
+ colTypes[i].DatabaseTypeName(), f.Type().String())
+ }
+ }
+ case *[]byte:
+ if f := v.FieldByName(m[colNames[i]]); f.IsValid() {
+ if reflect.TypeOf(*c).AssignableTo(f.Type()) {
+ f.SetBytes(*c)
+ } else if reflect.TypeOf("").AssignableTo(f.Type()) {
+ f.SetString(string(*c))
+ } else if reflect.TypeOf(sql.NullString{}).AssignableTo(f.Type()) {
+ f.FieldByName("String").SetString(string(*c))
+ f.FieldByName("Valid").SetBool(len(*c) > 0)
+ } else {
+ return fmt.Errorf("cannot convert column type %s to field type %s",
+ colTypes[i].DatabaseTypeName(), f.Type().String())
+ }
+ }
+ }
+ }
+ slice = reflect.Append(slice, v)
+ }
+ reflect.ValueOf(dest).Elem().Set(slice)
+ return nil
+}
diff --git a/data/data_test.go b/data/data_test.go
index fee0d8d..c7ea93b 100644
--- a/data/data_test.go
+++ b/data/data_test.go
@@ -15,6 +15,7 @@
package data_test
import (
+ "database/sql"
"fmt"
"github.com/apid/apid-core"
"github.com/apid/apid-core/data"
@@ -244,6 +245,149 @@
<-finished
}
}, 10)
+
+ Context("StructsFromRows", func() {
+ type TestStruct struct {
+ Id string `db:"id"`
+ QuotaInterval int64 `db:"quota_interval"`
+ SignedInt int `db:"signed_int"`
+ SqlInt sql.NullInt64 `db:"sql_int"`
+ Ratio float64 `db:"ratio"`
+ ShortFloat float32 `db:"short_float"`
+ SqlFloat sql.NullFloat64 `db:"sql_float"`
+ CreatedAt string `db:"created_at"`
+ CreatedBy sql.NullString `db:"created_by"`
+ UpdatedAt []byte `db:"updated_at"`
+ StringBlob sql.NullString `db:"string_blob"`
+ NotInDb string `db:"not_in_db"`
+ NotUsed string
+ }
+ var db apid.DB
+ BeforeEach(func() {
+ version := time.Now().String()
+ var err error
+ db, err = apid.Data().DBVersionForID("test", version)
+ Ω(err).Should(Succeed())
+ _, err = db.Exec(`
+ CREATE TABLE IF NOT EXISTS test_table (
+ id text,
+ quota_interval integer,
+ signed_int integer,
+ sql_int integer,
+ created_at blob,
+ created_by text,
+ updated_at blob,
+ string_blob blob,
+ ratio float,
+ short_float float,
+ sql_float float,
+ not_used text,
+ primary key (id)
+ );
+ INSERT INTO "test_table" VALUES(
+ 'b7e0970c-4677-4b05-8105-5ea59fdcf4e7',
+ 1,
+ -1,
+ -2,
+ '2017-10-26 22:26:50.153+00:00',
+ 'haoming',
+ '2017-10-26 22:26:50.153+00:00',
+ '2017-10-26 22:26:50.153+00:00',
+ 0.5,
+ 0.6,
+ 0.7,
+ 'not_used'
+ );
+ INSERT INTO "test_table" VALUES(
+ 'a7e0970c-4677-4b05-8105-5ea59fdcf4e7',
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL
+ );
+ `)
+ Ω(err).Should(Succeed())
+ })
+
+ AfterEach(func() {
+ _, err := db.Exec(`
+ DROP TABLE IF EXISTS test_table;
+ `)
+ Ω(err).Should(Succeed())
+ })
+
+ verifyResult := func(s []TestStruct) {
+ Ω(len(s)).Should(Equal(2))
+ Ω(s[0].Id).Should(Equal("b7e0970c-4677-4b05-8105-5ea59fdcf4e7"))
+ Ω(s[0].QuotaInterval).Should(Equal(int64(1)))
+ Ω(s[0].SignedInt).Should(Equal(int(-1)))
+ Ω(s[0].SqlInt).Should(Equal(sql.NullInt64{-2, true}))
+ Ω(s[0].Ratio).Should(Equal(float64(0.5)))
+ Ω(s[0].ShortFloat).Should(Equal(float32(0.6)))
+ Ω(s[0].SqlFloat).Should(Equal(sql.NullFloat64{0.7, true}))
+ Ω(s[0].CreatedAt).Should(Equal("2017-10-26 22:26:50.153+00:00"))
+ Ω(s[0].CreatedBy).Should(Equal(sql.NullString{"haoming", true}))
+ Ω(s[0].UpdatedAt).Should(Equal([]byte("2017-10-26 22:26:50.153+00:00")))
+ Ω(s[0].StringBlob).Should(Equal(sql.NullString{"2017-10-26 22:26:50.153+00:00", true}))
+ Ω(s[0].NotInDb).Should(BeZero())
+ Ω(s[0].NotUsed).Should(BeZero())
+
+ Ω(s[1].Id).Should(Equal("a7e0970c-4677-4b05-8105-5ea59fdcf4e7"))
+ Ω(s[1].QuotaInterval).Should(BeZero())
+ Ω(s[1].SignedInt).Should(BeZero())
+ Ω(s[1].SignedInt).Should(BeZero())
+ Ω(s[1].Ratio).Should(BeZero())
+ Ω(s[1].ShortFloat).Should(BeZero())
+ Ω(s[1].SqlFloat).Should(BeZero())
+ Ω(s[1].CreatedAt).Should(BeZero())
+ Ω(s[1].CreatedBy.Valid).Should(BeFalse())
+ Ω(s[1].UpdatedAt).Should(BeZero())
+ Ω(s[1].StringBlob.Valid).Should(BeFalse())
+ Ω(s[1].NotInDb).Should(BeZero())
+ Ω(s[1].NotUsed).Should(BeZero())
+ }
+
+ It("StructsFromRows", func() {
+ rows, err := db.Query(`
+ SELECT * from "test_table";
+ `)
+ Ω(err).Should(Succeed())
+ defer rows.Close()
+ s := []TestStruct{}
+ err = data.StructsFromRows(&s, rows)
+ Ω(err).Should(Succeed())
+ verifyResult(s)
+ })
+
+ It("DB.QueryStructs", func() {
+ s := []TestStruct{}
+ err := db.QueryStructs(&s, `
+ SELECT * from "test_table";
+ `)
+ Ω(err).Should(Succeed())
+ verifyResult(s)
+ })
+
+ It("Tx.QueryStructs", func() {
+ s := []TestStruct{}
+ tx, err := db.Begin()
+ Ω(err).Should(Succeed())
+ err = tx.QueryStructs(&s, `
+ SELECT * from "test_table";
+ `)
+ Ω(err).Should(Succeed())
+ Ω(tx.Commit()).Should(Succeed())
+ verifyResult(s)
+ })
+
+ })
})
func setup(db apid.DB) {
diff --git a/data_service.go b/data_service.go
index 355566d..0081952 100644
--- a/data_service.go
+++ b/data_service.go
@@ -44,6 +44,7 @@
SetConnMaxLifetime(d time.Duration)
SetMaxIdleConns(n int)
SetMaxOpenConns(n int)
+ QueryStructs(dest interface{}, query string, args ...interface{}) error
//Close() error
//Stats() sql.DBStats
//Driver() driver.Driver
@@ -62,4 +63,5 @@
Rollback() error
Stmt(stmt *sql.Stmt) *sql.Stmt
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
+ QueryStructs(dest interface{}, query string, args ...interface{}) error
}