b/68272561 Add generic "sql.rows->[]struct" mapping to apid-core (#27) * [ISSUE-68272561] add StructsFromRows * [ISSUE-68272561] support sql.Null*** * [ISSUE-68272561] rename test table * [ISSUE-68272561] add QueryStructs to db and tx
diff --git a/data/data.go b/data/data.go index 744ba6f..0ae6424 100644 --- a/data/data.go +++ b/data/data.go
@@ -18,14 +18,15 @@ "context" "database/sql" "fmt" + "github.com/Sirupsen/logrus" "github.com/apid/apid-core" "github.com/apid/apid-core/api" "github.com/apid/apid-core/data/wrap" "github.com/apid/apid-core/logger" - "github.com/Sirupsen/logrus" "github.com/mattn/go-sqlite3" "os" "path" + "reflect" "runtime" "strings" "sync" @@ -52,6 +53,7 @@ var dbMap = make(map[string]*dbMapInfo) var dbMapSync sync.RWMutex +var tagFieldMapper = make(map[reflect.Type]map[string]string) type ApidDb struct { db *sql.DB @@ -90,6 +92,16 @@ return d.db.QueryRow(query, args...) } +func (d *ApidDb) QueryStructs(dest interface{}, query string, args ...interface{}) error { + rows, err := d.db.Query(query, args...) + if err != nil { + return err + } + defer rows.Close() + err = StructsFromRows(dest, rows) + return err +} + func (d *ApidDb) Begin() (apid.Tx, error) { d.mutex.Lock() tx, err := d.db.Begin() @@ -158,6 +170,16 @@ return tx.tx.StmtContext(ctx, stmt) } +func (tx *Tx) QueryStructs(dest interface{}, query string, args ...interface{}) error { + rows, err := tx.tx.Query(query, args...) + if err != nil { + return err + } + defer rows.Close() + err = StructsFromRows(dest, rows) + return err +} + func CreateDataService() apid.DataService { config = apid.Config() log = apid.Log().ForModule("data") @@ -358,3 +380,128 @@ }() return stop } + +// StructsFromRows fill the dest slice with the values of according rows. +// Each row is marshaled into a struct. The "db" tag in the struct is used for field mapping. +// It will take care of null value. Supported type mappings from Sqlite3 to Go are: +// text->string; integer->int/int64/sql.NullInt64; float->float/float64/sql.NullFloat64; +// blob->[]byte/string/sql.NullString +func StructsFromRows(dest interface{}, rows *sql.Rows) error { + t := reflect.TypeOf(dest) + if t == nil { + return nil + } + // type of the struct + t = t.Elem().Elem() + //build mapper if not existent + m, ok := tagFieldMapper[t] + if !ok { + m = make(map[string]string) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if key := f.Tag.Get("db"); key != "" { + m[key] = f.Name + } + + } + tagFieldMapper[t] = m + } + + colNames, err := rows.Columns() + if err != nil { + return err + } + colTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + + cols := make([]interface{}, len(colNames)) + slice := reflect.New(reflect.SliceOf(t)).Elem() + + for i := range cols { + switch colTypes[i].DatabaseTypeName() { + case "null": + cols[i] = new(sql.NullString) + case "text": + cols[i] = new(sql.NullString) + case "integer": + cols[i] = new(sql.NullInt64) + case "float": + cols[i] = new(sql.NullFloat64) + case "blob": + cols[i] = new([]byte) + default: + return fmt.Errorf("unsupprted column type: %s", colTypes[i].DatabaseTypeName()) + } + } + for rows.Next() { + v := reflect.New(t).Elem() + err := rows.Scan(cols...) + if err != nil { + return err + } + for i := range cols { + switch c := cols[i].(type) { + case *sql.NullString: + if f := v.FieldByName(m[colNames[i]]); f.IsValid() { + if reflect.TypeOf(*c).AssignableTo(f.Type()) { + f.Set(reflect.ValueOf(c).Elem()) + } else if reflect.TypeOf("").AssignableTo(f.Type()) { + if c.Valid { + f.SetString(c.String) + } + } else { + return fmt.Errorf("cannot convert column type %s to field type %s", + colTypes[i].DatabaseTypeName(), f.Type().String()) + } + + } + case *sql.NullInt64: + if f := v.FieldByName(m[colNames[i]]); f.IsValid() { + if reflect.TypeOf(*c).AssignableTo(f.Type()) { + f.Set(reflect.ValueOf(c).Elem()) + } else if reflect.TypeOf(int64(0)).ConvertibleTo(f.Type()) { + if c.Valid { + f.SetInt(c.Int64) + } + } else { + return fmt.Errorf("cannot convert column type %s to field type %s", + colTypes[i].DatabaseTypeName(), f.Type().String()) + } + + } + case *sql.NullFloat64: + if f := v.FieldByName(m[colNames[i]]); f.IsValid() { + if reflect.TypeOf(*c).AssignableTo(f.Type()) { + f.Set(reflect.ValueOf(c).Elem()) + } else if reflect.TypeOf(float64(0)).ConvertibleTo(f.Type()) { + if c.Valid { + f.SetFloat(c.Float64) + } + } else { + return fmt.Errorf("cannot convert column type %s to field type %s", + colTypes[i].DatabaseTypeName(), f.Type().String()) + } + } + case *[]byte: + if f := v.FieldByName(m[colNames[i]]); f.IsValid() { + if reflect.TypeOf(*c).AssignableTo(f.Type()) { + f.SetBytes(*c) + } else if reflect.TypeOf("").AssignableTo(f.Type()) { + f.SetString(string(*c)) + } else if reflect.TypeOf(sql.NullString{}).AssignableTo(f.Type()) { + f.FieldByName("String").SetString(string(*c)) + f.FieldByName("Valid").SetBool(len(*c) > 0) + } else { + return fmt.Errorf("cannot convert column type %s to field type %s", + colTypes[i].DatabaseTypeName(), f.Type().String()) + } + } + } + } + slice = reflect.Append(slice, v) + } + reflect.ValueOf(dest).Elem().Set(slice) + return nil +}
diff --git a/data/data_test.go b/data/data_test.go index fee0d8d..c7ea93b 100644 --- a/data/data_test.go +++ b/data/data_test.go
@@ -15,6 +15,7 @@ package data_test import ( + "database/sql" "fmt" "github.com/apid/apid-core" "github.com/apid/apid-core/data" @@ -244,6 +245,149 @@ <-finished } }, 10) + + Context("StructsFromRows", func() { + type TestStruct struct { + Id string `db:"id"` + QuotaInterval int64 `db:"quota_interval"` + SignedInt int `db:"signed_int"` + SqlInt sql.NullInt64 `db:"sql_int"` + Ratio float64 `db:"ratio"` + ShortFloat float32 `db:"short_float"` + SqlFloat sql.NullFloat64 `db:"sql_float"` + CreatedAt string `db:"created_at"` + CreatedBy sql.NullString `db:"created_by"` + UpdatedAt []byte `db:"updated_at"` + StringBlob sql.NullString `db:"string_blob"` + NotInDb string `db:"not_in_db"` + NotUsed string + } + var db apid.DB + BeforeEach(func() { + version := time.Now().String() + var err error + db, err = apid.Data().DBVersionForID("test", version) + Ω(err).Should(Succeed()) + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS test_table ( + id text, + quota_interval integer, + signed_int integer, + sql_int integer, + created_at blob, + created_by text, + updated_at blob, + string_blob blob, + ratio float, + short_float float, + sql_float float, + not_used text, + primary key (id) + ); + INSERT INTO "test_table" VALUES( + 'b7e0970c-4677-4b05-8105-5ea59fdcf4e7', + 1, + -1, + -2, + '2017-10-26 22:26:50.153+00:00', + 'haoming', + '2017-10-26 22:26:50.153+00:00', + '2017-10-26 22:26:50.153+00:00', + 0.5, + 0.6, + 0.7, + 'not_used' + ); + INSERT INTO "test_table" VALUES( + 'a7e0970c-4677-4b05-8105-5ea59fdcf4e7', + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL + ); + `) + Ω(err).Should(Succeed()) + }) + + AfterEach(func() { + _, err := db.Exec(` + DROP TABLE IF EXISTS test_table; + `) + Ω(err).Should(Succeed()) + }) + + verifyResult := func(s []TestStruct) { + Ω(len(s)).Should(Equal(2)) + Ω(s[0].Id).Should(Equal("b7e0970c-4677-4b05-8105-5ea59fdcf4e7")) + Ω(s[0].QuotaInterval).Should(Equal(int64(1))) + Ω(s[0].SignedInt).Should(Equal(int(-1))) + Ω(s[0].SqlInt).Should(Equal(sql.NullInt64{-2, true})) + Ω(s[0].Ratio).Should(Equal(float64(0.5))) + Ω(s[0].ShortFloat).Should(Equal(float32(0.6))) + Ω(s[0].SqlFloat).Should(Equal(sql.NullFloat64{0.7, true})) + Ω(s[0].CreatedAt).Should(Equal("2017-10-26 22:26:50.153+00:00")) + Ω(s[0].CreatedBy).Should(Equal(sql.NullString{"haoming", true})) + Ω(s[0].UpdatedAt).Should(Equal([]byte("2017-10-26 22:26:50.153+00:00"))) + Ω(s[0].StringBlob).Should(Equal(sql.NullString{"2017-10-26 22:26:50.153+00:00", true})) + Ω(s[0].NotInDb).Should(BeZero()) + Ω(s[0].NotUsed).Should(BeZero()) + + Ω(s[1].Id).Should(Equal("a7e0970c-4677-4b05-8105-5ea59fdcf4e7")) + Ω(s[1].QuotaInterval).Should(BeZero()) + Ω(s[1].SignedInt).Should(BeZero()) + Ω(s[1].SignedInt).Should(BeZero()) + Ω(s[1].Ratio).Should(BeZero()) + Ω(s[1].ShortFloat).Should(BeZero()) + Ω(s[1].SqlFloat).Should(BeZero()) + Ω(s[1].CreatedAt).Should(BeZero()) + Ω(s[1].CreatedBy.Valid).Should(BeFalse()) + Ω(s[1].UpdatedAt).Should(BeZero()) + Ω(s[1].StringBlob.Valid).Should(BeFalse()) + Ω(s[1].NotInDb).Should(BeZero()) + Ω(s[1].NotUsed).Should(BeZero()) + } + + It("StructsFromRows", func() { + rows, err := db.Query(` + SELECT * from "test_table"; + `) + Ω(err).Should(Succeed()) + defer rows.Close() + s := []TestStruct{} + err = data.StructsFromRows(&s, rows) + Ω(err).Should(Succeed()) + verifyResult(s) + }) + + It("DB.QueryStructs", func() { + s := []TestStruct{} + err := db.QueryStructs(&s, ` + SELECT * from "test_table"; + `) + Ω(err).Should(Succeed()) + verifyResult(s) + }) + + It("Tx.QueryStructs", func() { + s := []TestStruct{} + tx, err := db.Begin() + Ω(err).Should(Succeed()) + err = tx.QueryStructs(&s, ` + SELECT * from "test_table"; + `) + Ω(err).Should(Succeed()) + Ω(tx.Commit()).Should(Succeed()) + verifyResult(s) + }) + + }) }) func setup(db apid.DB) {
diff --git a/data_service.go b/data_service.go index 355566d..0081952 100644 --- a/data_service.go +++ b/data_service.go
@@ -44,6 +44,7 @@ SetConnMaxLifetime(d time.Duration) SetMaxIdleConns(n int) SetMaxOpenConns(n int) + QueryStructs(dest interface{}, query string, args ...interface{}) error //Close() error //Stats() sql.DBStats //Driver() driver.Driver @@ -62,4 +63,5 @@ Rollback() error Stmt(stmt *sql.Stmt) *sql.Stmt StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt + QueryStructs(dest interface{}, query string, args ...interface{}) error }