|  | // Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>. | 
|  | // TODO: add "Gimpl do foo" team? | 
|  | // | 
|  | // Use of this source code is governed by an MIT-style | 
|  | // license that can be found in the LICENSE file. | 
|  | // +build trace | 
|  |  | 
|  | package sqlite3 | 
|  |  | 
|  | /* | 
|  | #ifndef USE_LIBSQLITE3 | 
|  | #include <sqlite3-binding.h> | 
|  | #else | 
|  | #include <sqlite3.h> | 
|  | #endif | 
|  | #include <stdlib.h> | 
|  |  | 
|  | void stepTrampoline(sqlite3_context*, int, sqlite3_value**); | 
|  | void doneTrampoline(sqlite3_context*); | 
|  | int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x); | 
|  | */ | 
|  | import "C" | 
|  |  | 
|  | import ( | 
|  | "errors" | 
|  | "fmt" | 
|  | "reflect" | 
|  | "strings" | 
|  | "sync" | 
|  | "unsafe" | 
|  | ) | 
|  |  | 
// Trace... constants identify the possible events causing callback invocation.
// Values are same as the corresponding SQLite Trace Event Codes.
//
// They are used as bit flags: TraceConfig.EventMask is a bitwise OR of them
// (SetTrace combines masks with '|', and the callback trampoline tests an
// event against the mask with '&').
const (
	// TraceStmt (SQLITE_TRACE_STMT) events fill TraceInfo.StmtHandle,
	// TraceInfo.StmtOrTrigger and, when requested, TraceInfo.ExpandedSQL.
	TraceStmt    = C.SQLITE_TRACE_STMT
	// TraceProfile (SQLITE_TRACE_PROFILE) events fill TraceInfo.StmtHandle,
	// TraceInfo.RunTimeNanosec and TraceInfo.DBError.
	TraceProfile = C.SQLITE_TRACE_PROFILE
	// TraceRow (SQLITE_TRACE_ROW) events fill TraceInfo.StmtHandle only.
	TraceRow     = C.SQLITE_TRACE_ROW
	// TraceClose (SQLITE_TRACE_CLOSE) carries the connection handle itself;
	// the driver always enables it internally so it can drop the connection's
	// trace configuration.
	TraceClose   = C.SQLITE_TRACE_CLOSE
)
|  |  | 
// TraceInfo describes a single trace event reported through SQLite's
// sqlite3_trace_v2 hook. The driver's callback trampoline fills it and passes
// it by value to TraceConfig.Callback. Which fields carry data depends on
// EventCode, as described on each field; unfilled fields keep their zero value.
type TraceInfo struct {
	// Pack together the shorter fields, to keep the struct smaller.
	// On a 64-bit machine there would be padding
	// between EventCode and ConnHandle; having AutoCommit here is "free".
	//
	// EventCode is the SQLite trace event code (one of the Trace... constants).
	// AutoCommit is true when sqlite3_get_autocommit returned non-zero for the
	// connection at the time of the event.
	// ConnHandle is the address of the SQLite connection (*sqlite3) as a uintptr.
	EventCode  uint32
	AutoCommit bool
	ConnHandle uintptr

	// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
	// identifier for a prepared statement (the P pointer of the event,
	// i.e. the sqlite3_stmt address, as a uintptr).
	StmtHandle uintptr

	// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
	// (1) either the unexpanded SQL text of the prepared statement, or
	//     an SQL comment that indicates the invocation of a trigger;
	// (2) expanded SQL, if requested and if (1) is not an SQL comment.
	StmtOrTrigger string
	ExpandedSQL   string // only if requested (TraceConfig.WantExpandedSQL = true)

	// Filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
	// estimated number of nanoseconds that the prepared statement took to run.
	RunTimeNanosec int64

	// DBError is the connection's error state (primary code, extended code,
	// message) sampled during the event. It is filled for TraceProfile events,
	// and for TraceStmt events when expanded SQL was requested but
	// sqlite3_expanded_sql returned NULL; otherwise it is the zero Error.
	DBError Error
}
|  |  | 
// TraceUserCallback gives the signature for a trace function
// provided by the user (Go application programmer).
// SQLite 3.14 documentation (as of September 2, 2016)
// for SQL Trace Hook = sqlite3_trace_v2():
// The integer return value from the callback is currently ignored,
// though this may change in future releases. Callback implementations
// should return zero to ensure future compatibility.
//
// The driver invokes the callback synchronously from inside SQLite's trace
// hook, only for events selected in TraceConfig.EventMask, and hands its
// return value back to SQLite as the hook's result.
type TraceUserCallback func(TraceInfo) int
|  |  | 
// TraceConfig selects which trace events are reported for a connection and
// where they are delivered. It is passed to SQLiteConn.SetTrace, which keeps
// its own copy.
type TraceConfig struct {
	// Callback receives each selected event. If nil, selected events are
	// consumed without being reported and 0 is returned to SQLite.
	Callback        TraceUserCallback
	// EventMask is a bitwise OR of the Trace... constants whose events are
	// passed to Callback. SetTrace additionally enables TraceClose with SQLite
	// for internal cleanup, but Callback only sees TraceClose if it is set here.
	EventMask       C.uint
	// WantExpandedSQL requests TraceInfo.ExpandedSQL for TraceStmt events.
	// SetTrace clears it in its stored copy unless EventMask includes TraceStmt.
	WantExpandedSQL bool
}
|  |  | 
|  | func fillDBError(dbErr *Error, db *C.sqlite3) { | 
|  | // See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016) | 
|  | dbErr.Code = ErrNo(C.sqlite3_errcode(db)) | 
|  | dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db)) | 
|  | dbErr.err = C.GoString(C.sqlite3_errmsg(db)) | 
|  | } | 
|  |  | 
|  | func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) { | 
|  | if pStmt == nil { | 
|  | panic("No SQLite statement pointer in P arg of trace_v2 callback") | 
|  | } | 
|  |  | 
|  | expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt)) | 
|  | if expSQLiteCStr == nil { | 
|  | fillDBError(&info.DBError, db) | 
|  | return | 
|  | } | 
|  | info.ExpandedSQL = C.GoString(expSQLiteCStr) | 
|  | } | 
|  |  | 
|  | //export traceCallbackTrampoline | 
|  | func traceCallbackTrampoline( | 
|  | traceEventCode C.uint, | 
|  | // Parameter named 'C' in SQLite docs = Context given at registration: | 
|  | ctx unsafe.Pointer, | 
|  | // Parameter named 'P' in SQLite docs (Primary event data?): | 
|  | p unsafe.Pointer, | 
|  | // Parameter named 'X' in SQLite docs (eXtra event data?): | 
|  | xValue unsafe.Pointer) C.int { | 
|  |  | 
|  | if ctx == nil { | 
|  | panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode)) | 
|  | } | 
|  |  | 
|  | contextDB := (*C.sqlite3)(ctx) | 
|  | connHandle := uintptr(ctx) | 
|  |  | 
|  | var traceConf TraceConfig | 
|  | var found bool | 
|  | if traceEventCode == TraceClose { | 
|  | // clean up traceMap: 'pop' means get and delete | 
|  | traceConf, found = popTraceMapping(connHandle) | 
|  | } else { | 
|  | traceConf, found = lookupTraceMapping(connHandle) | 
|  | } | 
|  |  | 
|  | if !found { | 
|  | panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)", | 
|  | connHandle, traceEventCode)) | 
|  | } | 
|  |  | 
|  | var info TraceInfo | 
|  |  | 
|  | info.EventCode = uint32(traceEventCode) | 
|  | info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0) | 
|  | info.ConnHandle = connHandle | 
|  |  | 
|  | switch traceEventCode { | 
|  | case TraceStmt: | 
|  | info.StmtHandle = uintptr(p) | 
|  |  | 
|  | var xStr string | 
|  | if xValue != nil { | 
|  | xStr = C.GoString((*C.char)(xValue)) | 
|  | } | 
|  | info.StmtOrTrigger = xStr | 
|  | if !strings.HasPrefix(xStr, "--") { | 
|  | // Not SQL comment, therefore the current event | 
|  | // is not related to a trigger. | 
|  | // The user might want to receive the expanded SQL; | 
|  | // let's check: | 
|  | if traceConf.WantExpandedSQL { | 
|  | fillExpandedSQL(&info, contextDB, p) | 
|  | } | 
|  | } | 
|  |  | 
|  | case TraceProfile: | 
|  | info.StmtHandle = uintptr(p) | 
|  |  | 
|  | if xValue == nil { | 
|  | panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event") | 
|  | } | 
|  |  | 
|  | info.RunTimeNanosec = *(*int64)(xValue) | 
|  |  | 
|  | // sample the error //TODO: is it safe? is it useful? | 
|  | fillDBError(&info.DBError, contextDB) | 
|  |  | 
|  | case TraceRow: | 
|  | info.StmtHandle = uintptr(p) | 
|  |  | 
|  | case TraceClose: | 
|  | handle := uintptr(p) | 
|  | if handle != info.ConnHandle { | 
|  | panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.", | 
|  | handle, info.ConnHandle)) | 
|  | } | 
|  |  | 
|  | default: | 
|  | // Pass unsupported events to the user callback (if configured); | 
|  | // let the user callback decide whether to panic or ignore them. | 
|  | } | 
|  |  | 
|  | // Do not execute user callback when the event was not requested by user! | 
|  | // Remember that the Close event is always selected when | 
|  | // registering this callback trampoline with SQLite --- for cleanup. | 
|  | // In the future there may be more events forced to "selected" in SQLite | 
|  | // for the driver's needs. | 
|  | if traceConf.EventMask&traceEventCode == 0 { | 
|  | return 0 | 
|  | } | 
|  |  | 
|  | r := 0 | 
|  | if traceConf.Callback != nil { | 
|  | r = traceConf.Callback(info) | 
|  | } | 
|  | return C.int(r) | 
|  | } | 
|  |  | 
|  | type traceMapEntry struct { | 
|  | config TraceConfig | 
|  | } | 
|  |  | 
|  | var traceMapLock sync.Mutex | 
|  | var traceMap = make(map[uintptr]traceMapEntry) | 
|  |  | 
|  | func addTraceMapping(connHandle uintptr, traceConf TraceConfig) { | 
|  | traceMapLock.Lock() | 
|  | defer traceMapLock.Unlock() | 
|  |  | 
|  | oldEntryCopy, found := traceMap[connHandle] | 
|  | if found { | 
|  | panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).", | 
|  | traceConf, connHandle, oldEntryCopy.config)) | 
|  | } | 
|  | traceMap[connHandle] = traceMapEntry{config: traceConf} | 
|  | fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle) | 
|  | } | 
|  |  | 
|  | func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) { | 
|  | traceMapLock.Lock() | 
|  | defer traceMapLock.Unlock() | 
|  |  | 
|  | entryCopy, found := traceMap[connHandle] | 
|  | return entryCopy.config, found | 
|  | } | 
|  |  | 
|  | // 'pop' = get and delete from map before returning the value to the caller | 
|  | func popTraceMapping(connHandle uintptr) (TraceConfig, bool) { | 
|  | traceMapLock.Lock() | 
|  | defer traceMapLock.Unlock() | 
|  |  | 
|  | entryCopy, found := traceMap[connHandle] | 
|  | if found { | 
|  | delete(traceMap, connHandle) | 
|  | fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config) | 
|  | } | 
|  | return entryCopy.config, found | 
|  | } | 
|  |  | 
|  | // RegisterAggregator makes a Go type available as a SQLite aggregation function. | 
|  | // | 
|  | // Because aggregation is incremental, it's implemented in Go with a | 
|  | // type that has 2 methods: func Step(values) accumulates one row of | 
|  | // data into the accumulator, and func Done() ret finalizes and | 
|  | // returns the aggregate value. "values" and "ret" may be any type | 
|  | // supported by RegisterFunc. | 
|  | // | 
|  | // RegisterAggregator takes as implementation a constructor function | 
|  | // that constructs an instance of the aggregator type each time an | 
|  | // aggregation begins. The constructor must return a pointer to a | 
|  | // type, or an interface that implements Step() and Done(). | 
|  | // | 
|  | // The constructor function and the Step/Done methods may optionally | 
|  | // return an error in addition to their other return values. | 
|  | // | 
|  | // See _example/go_custom_funcs for a detailed example. | 
|  | func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { | 
|  | var ai aggInfo | 
|  | ai.constructor = reflect.ValueOf(impl) | 
|  | t := ai.constructor.Type() | 
|  | if t.Kind() != reflect.Func { | 
|  | return errors.New("non-function passed to RegisterAggregator") | 
|  | } | 
|  | if t.NumOut() != 1 && t.NumOut() != 2 { | 
|  | return errors.New("SQLite aggregator constructors must return 1 or 2 values") | 
|  | } | 
|  | if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | 
|  | return errors.New("Second return value of SQLite function must be error") | 
|  | } | 
|  | if t.NumIn() != 0 { | 
|  | return errors.New("SQLite aggregator constructors must not have arguments") | 
|  | } | 
|  |  | 
|  | agg := t.Out(0) | 
|  | switch agg.Kind() { | 
|  | case reflect.Ptr, reflect.Interface: | 
|  | default: | 
|  | return errors.New("SQlite aggregator constructor must return a pointer object") | 
|  | } | 
|  | stepFn, found := agg.MethodByName("Step") | 
|  | if !found { | 
|  | return errors.New("SQlite aggregator doesn't have a Step() function") | 
|  | } | 
|  | step := stepFn.Type | 
|  | if step.NumOut() != 0 && step.NumOut() != 1 { | 
|  | return errors.New("SQlite aggregator Step() function must return 0 or 1 values") | 
|  | } | 
|  | if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | 
|  | return errors.New("type of SQlite aggregator Step() return value must be error") | 
|  | } | 
|  |  | 
|  | stepNArgs := step.NumIn() | 
|  | start := 0 | 
|  | if agg.Kind() == reflect.Ptr { | 
|  | // Skip over the method receiver | 
|  | stepNArgs-- | 
|  | start++ | 
|  | } | 
|  | if step.IsVariadic() { | 
|  | stepNArgs-- | 
|  | } | 
|  | for i := start; i < start+stepNArgs; i++ { | 
|  | conv, err := callbackArg(step.In(i)) | 
|  | if err != nil { | 
|  | return err | 
|  | } | 
|  | ai.stepArgConverters = append(ai.stepArgConverters, conv) | 
|  | } | 
|  | if step.IsVariadic() { | 
|  | conv, err := callbackArg(t.In(start + stepNArgs).Elem()) | 
|  | if err != nil { | 
|  | return err | 
|  | } | 
|  | ai.stepVariadicConverter = conv | 
|  | // Pass -1 to sqlite so that it allows any number of | 
|  | // arguments. The call helper verifies that the minimum number | 
|  | // of arguments is present for variadic functions. | 
|  | stepNArgs = -1 | 
|  | } | 
|  |  | 
|  | doneFn, found := agg.MethodByName("Done") | 
|  | if !found { | 
|  | return errors.New("SQlite aggregator doesn't have a Done() function") | 
|  | } | 
|  | done := doneFn.Type | 
|  | doneNArgs := done.NumIn() | 
|  | if agg.Kind() == reflect.Ptr { | 
|  | // Skip over the method receiver | 
|  | doneNArgs-- | 
|  | } | 
|  | if doneNArgs != 0 { | 
|  | return errors.New("SQlite aggregator Done() function must have no arguments") | 
|  | } | 
|  | if done.NumOut() != 1 && done.NumOut() != 2 { | 
|  | return errors.New("SQLite aggregator Done() function must return 1 or 2 values") | 
|  | } | 
|  | if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | 
|  | return errors.New("second return value of SQLite aggregator Done() function must be error") | 
|  | } | 
|  |  | 
|  | conv, err := callbackRet(done.Out(0)) | 
|  | if err != nil { | 
|  | return err | 
|  | } | 
|  | ai.doneRetConverter = conv | 
|  | ai.active = make(map[int64]reflect.Value) | 
|  | ai.next = 1 | 
|  |  | 
|  | // ai must outlast the database connection, or we'll have dangling pointers. | 
|  | c.aggregators = append(c.aggregators, &ai) | 
|  |  | 
|  | cname := C.CString(name) | 
|  | defer C.free(unsafe.Pointer(cname)) | 
|  | opts := C.SQLITE_UTF8 | 
|  | if pure { | 
|  | opts |= C.SQLITE_DETERMINISTIC | 
|  | } | 
|  | rv := sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) | 
|  | if rv != C.SQLITE_OK { | 
|  | return c.lastError() | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // SetTrace installs or removes the trace callback for the given database connection. | 
|  | // It's not named 'RegisterTrace' because only one callback can be kept and called. | 
|  | // Calling SetTrace a second time on same database connection | 
|  | // overrides (cancels) any prior callback and all its settings: | 
|  | // event mask, etc. | 
|  | func (c *SQLiteConn) SetTrace(requested *TraceConfig) error { | 
|  | connHandle := uintptr(unsafe.Pointer(c.db)) | 
|  |  | 
|  | _, _ = popTraceMapping(connHandle) | 
|  |  | 
|  | if requested == nil { | 
|  | // The traceMap entry was deleted already by popTraceMapping(): | 
|  | // can disable all events now, no need to watch for TraceClose. | 
|  | err := c.setSQLiteTrace(0) | 
|  | return err | 
|  | } | 
|  |  | 
|  | reqCopy := *requested | 
|  |  | 
|  | // Disable potentially expensive operations | 
|  | // if their result will not be used. We are doing this | 
|  | // just in case the caller provided nonsensical input. | 
|  | if reqCopy.EventMask&TraceStmt == 0 { | 
|  | reqCopy.WantExpandedSQL = false | 
|  | } | 
|  |  | 
|  | addTraceMapping(connHandle, reqCopy) | 
|  |  | 
|  | // The callback trampoline function does cleanup on Close event, | 
|  | // regardless of the presence or absence of the user callback. | 
|  | // Therefore it needs the Close event to be selected: | 
|  | actualEventMask := uint(reqCopy.EventMask | TraceClose) | 
|  | err := c.setSQLiteTrace(actualEventMask) | 
|  | return err | 
|  | } | 
|  |  | 
|  | func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error { | 
|  | rv := C.sqlite3_trace_v2(c.db, | 
|  | C.uint(sqliteEventMask), | 
|  | (*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)), | 
|  | unsafe.Pointer(c.db)) // Fourth arg is same as first: we are | 
|  | // passing the database connection handle as callback context. | 
|  |  | 
|  | if rv != C.SQLITE_OK { | 
|  | return c.lastError() | 
|  | } | 
|  | return nil | 
|  | } |