Source file
src/database/sql/fakedb_test.go
1
2
3
4
5 package sql
6
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "sync/atomic"
19 "testing"
20 "time"
21 )
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 type fakeDriver struct {
48 mu sync.Mutex
49 openCount int
50 closeCount int
51 waitCh chan struct{}
52 waitingCh chan struct{}
53 dbs map[string]*fakeDB
54 }
55
// fakeConnector is the driver.Connector handed out by
// fakeDriverCtx.OpenConnector.
type fakeConnector struct {
	name string // DSN passed to fdriver.Open by Connect

	waiter func(context.Context) // installed on every connection Connect opens
	closed bool                  // set by Close; a second Close reports an error
}
62
63 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
64 conn, err := fdriver.Open(c.name)
65 conn.(*fakeConn).waiter = c.waiter
66 return conn, err
67 }
68
69 func (c *fakeConnector) Driver() driver.Driver {
70 return fdriver
71 }
72
73 func (c *fakeConnector) Close() error {
74 if c.closed {
75 return errors.New("fakedb: connector is closed")
76 }
77 c.closed = true
78 return nil
79 }
80
81 type fakeDriverCtx struct {
82 fakeDriver
83 }
84
85 var _ driver.DriverContext = &fakeDriverCtx{}
86
87 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
88 return &fakeConnector{name: name}, nil
89 }
90
91 type fakeDB struct {
92 name string
93
94 useRawBytes atomic.Bool
95
96 mu sync.Mutex
97 tables map[string]*table
98 badConn bool
99 allowAny bool
100 }
101
// fakeError is an error with a message and an optional wrapped cause, so that
// errors.Is can see through it (for example to driver.ErrBadConn).
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns only the message; the wrapped error is not included.
func (err fakeError) Error() string { return err.Message }

// Unwrap returns the wrapped cause, which may be nil.
func (err fakeError) Unwrap() error { return err.Wrapped }
114
// table is the storage of a single fakeDB table.
type table struct {
	mu      sync.Mutex
	colname []string // column names, in declaration order
	coltype []string // column types, parallel to colname
	rows    []*row
}

// columnIndex returns the index of the first column called name, or -1 when
// the table has no such column.
func (t *table) columnIndex(name string) int {
	for i := 0; i < len(t.colname); i++ {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row holds one row's values. Rows stored by execInsert are indexed like
// table.colname; rows built for a SELECT result follow its projection.
type row struct {
	cols []any
}
134
// memToucher is implemented by values whose touchMem method writes to their
// own state (fakeConn and rowsCursor both bump a counter). Presumably this is
// so the race detector flags unsynchronized use; the intent is inferred,
// not stated in this file.
type memToucher interface {
	touchMem()
}
139
140 type fakeConn struct {
141 db *fakeDB
142
143 currTx *fakeTx
144
145
146
147 line int64
148
149
150 mu sync.Mutex
151 stmtsMade int
152 stmtsClosed int
153 numPrepare int
154
155
156 bad bool
157 stickyBad bool
158
159 skipDirtySession bool
160
161
162
163 dirtySession bool
164
165
166
167 waiter func(context.Context)
168 }
169
170 func (c *fakeConn) touchMem() {
171 c.line++
172 }
173
174 func (c *fakeConn) incrStat(v *int) {
175 c.mu.Lock()
176 *v++
177 c.mu.Unlock()
178 }
179
180 type fakeTx struct {
181 c *fakeConn
182 }
183
// boundCol is one WHERE condition of a SELECT: the stored value of Column must
// match the argument named by Placeholder ("?" positional, "?name" named).
type boundCol struct {
	Column      string
	Placeholder string
	Ordinal     int // 1-based position among the statement's placeholders
}
189
190 type fakeStmt struct {
191 memToucher
192 c *fakeConn
193 q string
194
195 cmd string
196 table string
197 panic string
198 wait time.Duration
199
200 next *fakeStmt
201
202 closed bool
203
204 colName []string
205 colType []string
206 colValue []any
207 placeholders int
208
209 whereCol []boundCol
210
211 placeholderConverter []driver.ValueConverter
212 }
213
214 var fdriver driver.Driver = &fakeDriver{}
215
216 func init() {
217 Register("test", fdriver)
218 }
219
// contains reports whether y occurs in list.
func contains(list []string, y string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == y {
			return true
		}
	}
	return false
}
228
229 type Dummy struct {
230 driver.Driver
231 }
232
233 func TestDrivers(t *testing.T) {
234 unregisterAllDrivers()
235 Register("test", fdriver)
236 Register("invalid", Dummy{})
237 all := Drivers()
238 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
239 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
240 }
241 }
242
243
// hookOpenErr lets tests inject Open failures: when fn is non-nil and returns
// a non-nil error, fakeDriver.Open returns that error.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open failure hook; nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
254
255
256
257
258
259
260
261 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
262 hookOpenErr.Lock()
263 fn := hookOpenErr.fn
264 hookOpenErr.Unlock()
265 if fn != nil {
266 if err := fn(); err != nil {
267 return nil, err
268 }
269 }
270 parts := strings.Split(dsn, ";")
271 if len(parts) < 1 {
272 return nil, errors.New("fakedb: no database name")
273 }
274 name := parts[0]
275
276 db := d.getDB(name)
277
278 d.mu.Lock()
279 d.openCount++
280 d.mu.Unlock()
281 conn := &fakeConn{db: db}
282
283 if len(parts) >= 2 && parts[1] == "badConn" {
284 conn.bad = true
285 }
286 if d.waitCh != nil {
287 d.waitingCh <- struct{}{}
288 <-d.waitCh
289 d.waitCh = nil
290 d.waitingCh = nil
291 }
292 return conn, nil
293 }
294
295 func (d *fakeDriver) getDB(name string) *fakeDB {
296 d.mu.Lock()
297 defer d.mu.Unlock()
298 if d.dbs == nil {
299 d.dbs = make(map[string]*fakeDB)
300 }
301 db, ok := d.dbs[name]
302 if !ok {
303 db = &fakeDB{name: name}
304 d.dbs[name] = db
305 }
306 return db
307 }
308
309 func (db *fakeDB) wipe() {
310 db.mu.Lock()
311 defer db.mu.Unlock()
312 db.tables = nil
313 }
314
315 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
316 db.mu.Lock()
317 defer db.mu.Unlock()
318 if db.tables == nil {
319 db.tables = make(map[string]*table)
320 }
321 if _, exist := db.tables[name]; exist {
322 return fmt.Errorf("fakedb: table %q already exists", name)
323 }
324 if len(columnNames) != len(columnTypes) {
325 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
326 name, len(columnNames), len(columnTypes))
327 }
328 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
329 return nil
330 }
331
332
333 func (db *fakeDB) table(table string) (*table, bool) {
334 if db.tables == nil {
335 return nil, false
336 }
337 t, ok := db.tables[table]
338 return t, ok
339 }
340
341 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
342 db.mu.Lock()
343 defer db.mu.Unlock()
344 t, ok := db.table(table)
345 if !ok {
346 return
347 }
348 for n, cname := range t.colname {
349 if cname == column {
350 return t.coltype[n], true
351 }
352 }
353 return "", false
354 }
355
356 func (c *fakeConn) isBad() bool {
357 if c.stickyBad {
358 return true
359 } else if c.bad {
360 if c.db == nil {
361 return false
362 }
363
364 c.db.badConn = !c.db.badConn
365 return c.db.badConn
366 } else {
367 return false
368 }
369 }
370
371 func (c *fakeConn) isDirtyAndMark() bool {
372 if c.skipDirtySession {
373 return false
374 }
375 if c.currTx != nil {
376 c.dirtySession = true
377 return false
378 }
379 if c.dirtySession {
380 return true
381 }
382 c.dirtySession = true
383 return false
384 }
385
386 func (c *fakeConn) Begin() (driver.Tx, error) {
387 if c.isBad() {
388 return nil, fakeError{Wrapped: driver.ErrBadConn}
389 }
390 if c.currTx != nil {
391 return nil, errors.New("fakedb: already in a transaction")
392 }
393 c.touchMem()
394 c.currTx = &fakeTx{c: c}
395 return c.currTx, nil
396 }
397
398 var hookPostCloseConn struct {
399 sync.Mutex
400 fn func(*fakeConn, error)
401 }
402
403 func setHookpostCloseConn(fn func(*fakeConn, error)) {
404 hookPostCloseConn.Lock()
405 defer hookPostCloseConn.Unlock()
406 hookPostCloseConn.fn = fn
407 }
408
// testStrictClose, when non-nil, receives an Errorf for every fakeConn.Close
// that fails.
var testStrictClose *testing.T

// setStrictFakeConnClose routes failed fakeConn.Close calls to t as test
// errors; pass nil to turn this off.
func setStrictFakeConnClose(t *testing.T) { testStrictClose = t }
416
417 func (c *fakeConn) ResetSession(ctx context.Context) error {
418 c.dirtySession = false
419 c.currTx = nil
420 if c.isBad() {
421 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
422 }
423 return nil
424 }
425
426 var _ driver.Validator = (*fakeConn)(nil)
427
428 func (c *fakeConn) IsValid() bool {
429 return !c.isBad()
430 }
431
432 func (c *fakeConn) Close() (err error) {
433 drv := fdriver.(*fakeDriver)
434 defer func() {
435 if err != nil && testStrictClose != nil {
436 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
437 }
438 hookPostCloseConn.Lock()
439 fn := hookPostCloseConn.fn
440 hookPostCloseConn.Unlock()
441 if fn != nil {
442 fn(c, err)
443 }
444 if err == nil {
445 drv.mu.Lock()
446 drv.closeCount++
447 drv.mu.Unlock()
448 }
449 }()
450 c.touchMem()
451 if c.currTx != nil {
452 return errors.New("fakedb: can't close fakeConn; in a Transaction")
453 }
454 if c.db == nil {
455 return errors.New("fakedb: can't close fakeConn; already closed")
456 }
457 if c.stmtsMade > c.stmtsClosed {
458 return errors.New("fakedb: can't close; dangling statement(s)")
459 }
460 c.db = nil
461 return nil
462 }
463
// checkSubsetTypes verifies that every argument holds a type the fake driver
// understands:
//   - int64, float64, bool, nil;
//   - []byte, string, time.Time.
//
// When allowAny is true every argument is accepted. The error names the first
// offending argument by ordinal.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
476
477 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
478
479 panic("ExecContext was not called.")
480 }
481
482 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
483
484
485
486
487 err := checkSubsetTypes(c.db.allowAny, args)
488 if err != nil {
489 return nil, err
490 }
491 return nil, driver.ErrSkip
492 }
493
494 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
495
496 panic("QueryContext was not called.")
497 }
498
499 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
500
501
502
503
504 err := checkSubsetTypes(c.db.allowAny, args)
505 if err != nil {
506 return nil, err
507 }
508 return nil, driver.ErrSkip
509 }
510
// errf formats msg with args (fmt.Sprintf-style) into an error whose text is
// prefixed with "fakedb: ".
func errf(msg string, args ...any) error {
	return fmt.Errorf("fakedb: "+msg, args...)
}
514
515
516
517
518 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
519 if len(parts) != 3 {
520 stmt.Close()
521 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
522 }
523 stmt.table = parts[0]
524
525 stmt.colName = strings.Split(parts[1], ",")
526 for n, colspec := range strings.Split(parts[2], ",") {
527 if colspec == "" {
528 continue
529 }
530 nameVal := strings.Split(colspec, "=")
531 if len(nameVal) != 2 {
532 stmt.Close()
533 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
534 }
535 column, value := nameVal[0], nameVal[1]
536 _, ok := c.db.columnType(stmt.table, column)
537 if !ok {
538 stmt.Close()
539 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
540 }
541 if !strings.HasPrefix(value, "?") {
542 stmt.Close()
543 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
544 stmt.table, column)
545 }
546 stmt.placeholders++
547 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
548 }
549 return stmt, nil
550 }
551
552
553 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
554 if len(parts) != 2 {
555 stmt.Close()
556 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
557 }
558 stmt.table = parts[0]
559 for n, colspec := range strings.Split(parts[1], ",") {
560 nameType := strings.Split(colspec, "=")
561 if len(nameType) != 2 {
562 stmt.Close()
563 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
564 }
565 stmt.colName = append(stmt.colName, nameType[0])
566 stmt.colType = append(stmt.colType, nameType[1])
567 }
568 return stmt, nil
569 }
570
571
572 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
573 if len(parts) != 2 {
574 stmt.Close()
575 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
576 }
577 stmt.table = parts[0]
578 for n, colspec := range strings.Split(parts[1], ",") {
579 nameVal := strings.Split(colspec, "=")
580 if len(nameVal) != 2 {
581 stmt.Close()
582 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
583 }
584 column, value := nameVal[0], nameVal[1]
585 ctype, ok := c.db.columnType(stmt.table, column)
586 if !ok {
587 stmt.Close()
588 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
589 }
590 stmt.colName = append(stmt.colName, column)
591
592 if !strings.HasPrefix(value, "?") {
593 var subsetVal any
594
595 switch ctype {
596 case "string":
597 subsetVal = []byte(value)
598 case "blob":
599 subsetVal = []byte(value)
600 case "int32":
601 i, err := strconv.Atoi(value)
602 if err != nil {
603 stmt.Close()
604 return nil, errf("invalid conversion to int32 from %q", value)
605 }
606 subsetVal = int64(i)
607 case "table":
608 c.skipDirtySession = true
609 vparts := strings.Split(value, "!")
610
611 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
612 if err != nil {
613 return nil, err
614 }
615 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
616 substmt.Close()
617 if err != nil {
618 return nil, err
619 }
620 subsetVal = cursor
621 default:
622 stmt.Close()
623 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
624 }
625 stmt.colValue = append(stmt.colValue, subsetVal)
626 } else {
627 stmt.placeholders++
628 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
629 stmt.colValue = append(stmt.colValue, value)
630 }
631 }
632 return stmt, nil
633 }
634
635
636 var hookPrepareBadConn func() bool
637
638 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
639 panic("use PrepareContext")
640 }
641
642 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
643 c.numPrepare++
644 if c.db == nil {
645 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
646 }
647
648 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
649 return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
650 }
651
652 c.touchMem()
653 var firstStmt, prev *fakeStmt
654 for _, query := range strings.Split(query, ";") {
655 parts := strings.Split(query, "|")
656 if len(parts) < 1 {
657 return nil, errf("empty query")
658 }
659 stmt := &fakeStmt{q: query, c: c, memToucher: c}
660 if firstStmt == nil {
661 firstStmt = stmt
662 }
663 if len(parts) >= 3 {
664 switch parts[0] {
665 case "PANIC":
666 stmt.panic = parts[1]
667 parts = parts[2:]
668 case "WAIT":
669 wait, err := time.ParseDuration(parts[1])
670 if err != nil {
671 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
672 }
673 parts = parts[2:]
674 stmt.wait = wait
675 }
676 }
677 cmd := parts[0]
678 stmt.cmd = cmd
679 parts = parts[1:]
680
681 if c.waiter != nil {
682 c.waiter(ctx)
683 if err := ctx.Err(); err != nil {
684 return nil, err
685 }
686 }
687
688 if stmt.wait > 0 {
689 wait := time.NewTimer(stmt.wait)
690 select {
691 case <-wait.C:
692 case <-ctx.Done():
693 wait.Stop()
694 return nil, ctx.Err()
695 }
696 }
697
698 c.incrStat(&c.stmtsMade)
699 var err error
700 switch cmd {
701 case "WIPE":
702
703 case "USE_RAWBYTES":
704 c.db.useRawBytes.Store(true)
705 case "SELECT":
706 stmt, err = c.prepareSelect(stmt, parts)
707 case "CREATE":
708 stmt, err = c.prepareCreate(stmt, parts)
709 case "INSERT":
710 stmt, err = c.prepareInsert(ctx, stmt, parts)
711 case "NOSERT":
712
713
714 stmt, err = c.prepareInsert(ctx, stmt, parts)
715 default:
716 stmt.Close()
717 return nil, errf("unsupported command type %q", cmd)
718 }
719 if err != nil {
720 return nil, err
721 }
722 if prev != nil {
723 prev.next = stmt
724 }
725 prev = stmt
726 }
727 return firstStmt, nil
728 }
729
730 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
731 if s.panic == "ColumnConverter" {
732 panic(s.panic)
733 }
734 if len(s.placeholderConverter) == 0 {
735 return driver.DefaultParameterConverter
736 }
737 return s.placeholderConverter[idx]
738 }
739
740 func (s *fakeStmt) Close() error {
741 if s.panic == "Close" {
742 panic(s.panic)
743 }
744 if s.c == nil {
745 panic("nil conn in fakeStmt.Close")
746 }
747 if s.c.db == nil {
748 panic("in fakeStmt.Close, conn's db is nil (already closed)")
749 }
750 s.touchMem()
751 if !s.closed {
752 s.c.incrStat(&s.c.stmtsClosed)
753 s.closed = true
754 }
755 if s.next != nil {
756 s.next.Close()
757 }
758 return nil
759 }
760
761 var errClosed = errors.New("fakedb: statement has been closed")
762
763
764 var hookExecBadConn func() bool
765
766 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
767 panic("Using ExecContext")
768 }
769
770 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
771
772 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
773 if s.panic == "Exec" {
774 panic(s.panic)
775 }
776 if s.closed {
777 return nil, errClosed
778 }
779
780 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
781 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
782 }
783 if s.c.isDirtyAndMark() {
784 return nil, errFakeConnSessionDirty
785 }
786
787 err := checkSubsetTypes(s.c.db.allowAny, args)
788 if err != nil {
789 return nil, err
790 }
791 s.touchMem()
792
793 if s.wait > 0 {
794 time.Sleep(s.wait)
795 }
796
797 select {
798 default:
799 case <-ctx.Done():
800 return nil, ctx.Err()
801 }
802
803 db := s.c.db
804 switch s.cmd {
805 case "WIPE":
806 db.wipe()
807 return driver.ResultNoRows, nil
808 case "USE_RAWBYTES":
809 s.c.db.useRawBytes.Store(true)
810 return driver.ResultNoRows, nil
811 case "CREATE":
812 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
813 return nil, err
814 }
815 return driver.ResultNoRows, nil
816 case "INSERT":
817 return s.execInsert(args, true)
818 case "NOSERT":
819
820
821 return s.execInsert(args, false)
822 }
823 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
824 }
825
826
827
828
829 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
830 db := s.c.db
831 if len(args) != s.placeholders {
832 panic("error in pkg db; should only get here if size is correct")
833 }
834 db.mu.Lock()
835 t, ok := db.table(s.table)
836 db.mu.Unlock()
837 if !ok {
838 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
839 }
840
841 t.mu.Lock()
842 defer t.mu.Unlock()
843
844 var cols []any
845 if doInsert {
846 cols = make([]any, len(t.colname))
847 }
848 argPos := 0
849 for n, colname := range s.colName {
850 colidx := t.columnIndex(colname)
851 if colidx == -1 {
852 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
853 }
854 var val any
855 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
856 if strvalue == "?" {
857 val = args[argPos].Value
858 } else {
859
860 for _, a := range args {
861 if a.Name == strvalue[1:] {
862 val = a.Value
863 break
864 }
865 }
866 }
867 argPos++
868 } else {
869 val = s.colValue[n]
870 }
871 if doInsert {
872 cols[colidx] = val
873 }
874 }
875
876 if doInsert {
877 t.rows = append(t.rows, &row{cols: cols})
878 }
879 return driver.RowsAffected(1), nil
880 }
881
882
883 var hookQueryBadConn func() bool
884
885 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
886 panic("Use QueryContext")
887 }
888
889 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
890 if s.panic == "Query" {
891 panic(s.panic)
892 }
893 if s.closed {
894 return nil, errClosed
895 }
896
897 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
898 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
899 }
900 if s.c.isDirtyAndMark() {
901 return nil, errFakeConnSessionDirty
902 }
903
904 err := checkSubsetTypes(s.c.db.allowAny, args)
905 if err != nil {
906 return nil, err
907 }
908
909 s.touchMem()
910 db := s.c.db
911 if len(args) != s.placeholders {
912 panic("error in pkg db; should only get here if size is correct")
913 }
914
915 setMRows := make([][]*row, 0, 1)
916 setColumns := make([][]string, 0, 1)
917 setColType := make([][]string, 0, 1)
918
919 for {
920 db.mu.Lock()
921 t, ok := db.table(s.table)
922 db.mu.Unlock()
923 if !ok {
924 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
925 }
926
927 if s.table == "magicquery" {
928 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
929 if args[0].Value == "sleep" {
930 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
931 }
932 }
933 }
934 if s.table == "tx_status" && s.colName[0] == "tx_status" {
935 txStatus := "autocommit"
936 if s.c.currTx != nil {
937 txStatus = "transaction"
938 }
939 cursor := &rowsCursor{
940 db: s.c.db,
941 parentMem: s.c,
942 posRow: -1,
943 rows: [][]*row{
944 {
945 {
946 cols: []any{
947 txStatus,
948 },
949 },
950 },
951 },
952 cols: [][]string{
953 {
954 "tx_status",
955 },
956 },
957 colType: [][]string{
958 {
959 "string",
960 },
961 },
962 errPos: -1,
963 }
964 return cursor, nil
965 }
966
967 t.mu.Lock()
968
969 colIdx := make(map[string]int)
970 for _, name := range s.colName {
971 idx := t.columnIndex(name)
972 if idx == -1 {
973 t.mu.Unlock()
974 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
975 }
976 colIdx[name] = idx
977 }
978
979 mrows := []*row{}
980 rows:
981 for _, trow := range t.rows {
982
983
984
985 for _, wcol := range s.whereCol {
986 idx := t.columnIndex(wcol.Column)
987 if idx == -1 {
988 t.mu.Unlock()
989 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
990 }
991 tcol := trow.cols[idx]
992 if bs, ok := tcol.([]byte); ok {
993
994 tcol = string(bs)
995 }
996 var argValue any
997 if wcol.Placeholder == "?" {
998 argValue = args[wcol.Ordinal-1].Value
999 } else {
1000
1001 for _, a := range args {
1002 if a.Name == wcol.Placeholder[1:] {
1003 argValue = a.Value
1004 break
1005 }
1006 }
1007 }
1008 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
1009 continue rows
1010 }
1011 }
1012 mrow := &row{cols: make([]any, len(s.colName))}
1013 for seli, name := range s.colName {
1014 mrow.cols[seli] = trow.cols[colIdx[name]]
1015 }
1016 mrows = append(mrows, mrow)
1017 }
1018
1019 var colType []string
1020 for _, column := range s.colName {
1021 colType = append(colType, t.coltype[t.columnIndex(column)])
1022 }
1023
1024 t.mu.Unlock()
1025
1026 setMRows = append(setMRows, mrows)
1027 setColumns = append(setColumns, s.colName)
1028 setColType = append(setColType, colType)
1029
1030 if s.next == nil {
1031 break
1032 }
1033 s = s.next
1034 }
1035
1036 cursor := &rowsCursor{
1037 db: s.c.db,
1038 parentMem: s.c,
1039 posRow: -1,
1040 rows: setMRows,
1041 cols: setColumns,
1042 colType: setColType,
1043 errPos: -1,
1044 }
1045 return cursor, nil
1046 }
1047
1048 func (s *fakeStmt) NumInput() int {
1049 if s.panic == "NumInput" {
1050 panic(s.panic)
1051 }
1052 return s.placeholders
1053 }
1054
1055
1056 var hookCommitBadConn func() bool
1057
1058 func (tx *fakeTx) Commit() error {
1059 tx.c.currTx = nil
1060 if hookCommitBadConn != nil && hookCommitBadConn() {
1061 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1062 }
1063 tx.c.touchMem()
1064 return nil
1065 }
1066
1067
1068 var hookRollbackBadConn func() bool
1069
1070 func (tx *fakeTx) Rollback() error {
1071 tx.c.currTx = nil
1072 if hookRollbackBadConn != nil && hookRollbackBadConn() {
1073 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1074 }
1075 tx.c.touchMem()
1076 return nil
1077 }
1078
1079 type rowsCursor struct {
1080 db *fakeDB
1081 parentMem memToucher
1082 cols [][]string
1083 colType [][]string
1084 posSet int
1085 posRow int
1086 rows [][]*row
1087 closed bool
1088
1089
1090 errPos int
1091 err error
1092
1093
1094
1095
1096 bytesClone map[*byte][]byte
1097
1098
1099
1100
1101
1102 line int64
1103
1104
1105 closeErr error
1106 }
1107
1108 func (rc *rowsCursor) touchMem() {
1109 rc.parentMem.touchMem()
1110 rc.line++
1111 }
1112
1113 func (rc *rowsCursor) Close() error {
1114 rc.touchMem()
1115 rc.parentMem.touchMem()
1116 rc.closed = true
1117 return rc.closeErr
1118 }
1119
1120 func (rc *rowsCursor) Columns() []string {
1121 return rc.cols[rc.posSet]
1122 }
1123
1124 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1125 return colTypeToReflectType(rc.colType[rc.posSet][index])
1126 }
1127
1128 var rowsCursorNextHook func(dest []driver.Value) error
1129
1130 func (rc *rowsCursor) Next(dest []driver.Value) error {
1131 if rowsCursorNextHook != nil {
1132 return rowsCursorNextHook(dest)
1133 }
1134
1135 if rc.closed {
1136 return errors.New("fakedb: cursor is closed")
1137 }
1138 rc.touchMem()
1139 rc.posRow++
1140 if rc.posRow == rc.errPos {
1141 return rc.err
1142 }
1143 if rc.posRow >= len(rc.rows[rc.posSet]) {
1144 return io.EOF
1145 }
1146 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1147
1148
1149
1150
1151
1152
1153 dest[i] = v
1154
1155 if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
1156 if rc.bytesClone == nil {
1157 rc.bytesClone = make(map[*byte][]byte)
1158 }
1159 clone, ok := rc.bytesClone[&bs[0]]
1160 if !ok {
1161 clone = make([]byte, len(bs))
1162 copy(clone, bs)
1163 rc.bytesClone[&bs[0]] = clone
1164 }
1165 dest[i] = clone
1166 }
1167 }
1168 return nil
1169 }
1170
1171 func (rc *rowsCursor) HasNextResultSet() bool {
1172 rc.touchMem()
1173 return rc.posSet < len(rc.rows)-1
1174 }
1175
1176 func (rc *rowsCursor) NextResultSet() error {
1177 rc.touchMem()
1178 if rc.HasNextResultSet() {
1179 rc.posSet++
1180 rc.posRow = -1
1181 return nil
1182 }
1183 return io.EOF
1184 }
1185
1186
1187
1188
1189
1190
1191
// fakeDriverString converts values for "string" columns:
//   - strings and byte slices pass through unchanged;
//   - a *string is dereferenced, and a nil *string becomes nil;
//   - anything else is formatted with %v.
type fakeDriverString struct{}

func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}

// anyTypeConverter passes every value through unchanged.
type anyTypeConverter struct{}

func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { return v, nil }

// converterForType returns the ValueConverter used for INSERT placeholders
// bound to a column of fakedb type typ. It panics on unknown types.
func converterForType(typ string) driver.ValueConverter {
	switch typ {
	case "bool":
		return driver.Bool
	case "nullbool":
		return driver.Null{Converter: driver.Bool}
	case "int32":
		return driver.Int32
	case "string":
		return driver.NotNull{Converter: fakeDriverString{}}
	case "nullstring":
		return driver.Null{Converter: fakeDriverString{}}
	case "byte", "int16", "int64", "float64", "datetime":
		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	case "nullbyte", "nullint16", "nullint32", "nullint64", "nullfloat64", "nulldatetime":
		return driver.Null{Converter: driver.DefaultParameterConverter}
	case "any":
		return anyTypeConverter{}
	}
	panic("invalid fakedb column type of " + typ)
}
1250
1251 func colTypeToReflectType(typ string) reflect.Type {
1252 switch typ {
1253 case "bool":
1254 return reflect.TypeFor[bool]()
1255 case "nullbool":
1256 return reflect.TypeFor[NullBool]()
1257 case "int16":
1258 return reflect.TypeFor[int16]()
1259 case "nullint16":
1260 return reflect.TypeFor[NullInt16]()
1261 case "int32":
1262 return reflect.TypeFor[int32]()
1263 case "nullint32":
1264 return reflect.TypeFor[NullInt32]()
1265 case "string":
1266 return reflect.TypeFor[string]()
1267 case "nullstring":
1268 return reflect.TypeFor[NullString]()
1269 case "int64":
1270 return reflect.TypeFor[int64]()
1271 case "nullint64":
1272 return reflect.TypeFor[NullInt64]()
1273 case "float64":
1274 return reflect.TypeFor[float64]()
1275 case "nullfloat64":
1276 return reflect.TypeFor[NullFloat64]()
1277 case "datetime":
1278 return reflect.TypeFor[time.Time]()
1279 case "any":
1280 return reflect.TypeFor[any]()
1281 }
1282 panic("invalid fakedb column type of " + typ)
1283 }
1284
View as plain text