1
2
3
4
5 package sql
6
7 import (
8 "bytes"
9 "context"
10 "database/sql/driver"
11 "errors"
12 "fmt"
13 "internal/race"
14 "internal/testenv"
15 "math/rand"
16 "reflect"
17 "runtime"
18 "slices"
19 "strings"
20 "sync"
21 "sync/atomic"
22 "testing"
23 "testing/synctest"
24 "time"
25 "uuid"
26 )
27
// requireFeature is an option for testDatabase naming a driver feature a test
// needs. Driver configurations whose feature list does not contain the name are
// skipped for that test.
type requireFeature string
29
30
31
32
33
34
35
36
37
38
39
40 func testDatabase(t *testing.T, f func(t *testing.T, db *DB), opts ...any) {
41 var require []string
42 for _, o := range opts {
43 switch o := o.(type) {
44 case requireFeature:
45 require = append(require, string(o))
46 default:
47 t.Fatalf("unrecognized option %T", o)
48 }
49 }
50 Test:
51 for _, test := range []struct {
52 name string
53 connector driver.Connector
54 features []string
55 }{
56 {
57
58 name: "basic",
59 connector: &basicConnector{name: fakeDBName},
60 },
61 {
62
63
64
65 name: "default",
66 connector: &fakeConnector{name: fakeDBName},
67 features: []string{
68 "ConnBeginTx",
69 "NamedValue",
70 "Validator",
71 },
72 },
73 {
74 name: "scancols",
75 connector: &rowsColumnScannerConnector{name: fakeDBName},
76 features: []string{
77 "ConnBeginTx",
78 "NamedValue",
79 "Validator",
80 "ScanColumn",
81 },
82 },
83 } {
84 for _, req := range require {
85 if !slices.Contains(test.features, req) {
86 continue Test
87 }
88 }
89 t.Run(test.name, func(t *testing.T) {
90 synctest.Test(t, func(t *testing.T) {
91 db := OpenDB(test.connector)
92 if _, err := db.Exec("WIPE"); err != nil {
93 t.Fatalf("exec wipe: %v", err)
94 }
95 t.Cleanup(func() {
96 closeDB(t, db)
97 })
98 f(t, db)
99 })
100 })
101 }
102 }
103
104 func init() {
105 type dbConn struct {
106 db *DB
107 c *driverConn
108 }
109 freedFrom := make(map[dbConn]string)
110 var mu sync.Mutex
111 getFreedFrom := func(c dbConn) string {
112 mu.Lock()
113 defer mu.Unlock()
114 return freedFrom[c]
115 }
116 setFreedFrom := func(c dbConn, s string) {
117 mu.Lock()
118 defer mu.Unlock()
119 freedFrom[c] = s
120 }
121 putConnHook = func(db *DB, c *driverConn) {
122 if slices.Contains(db.freeConn, c) {
123
124
125
126 println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
127 panic("double free of conn.")
128 }
129 setFreedFrom(dbConn{db, c}, stack())
130 }
131 }
132
// fakeDBName is the connector name and DSN used when opening the fake driver in
// these tests.
const fakeDBName = "foo"
134
// chrisBirthday is the bdate value that populate(..., "people") stores for the
// "Chris" row. The QueryRow tests compare against it.
var chrisBirthday = time.Unix(123456789, 0)
136
137 func newTestDB(t testing.TB, name string) *DB {
138 return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
139 }
140
141 func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
142 fc.name = fakeDBName
143 db := OpenDB(fc)
144 if _, err := db.Exec("WIPE"); err != nil {
145 t.Fatalf("exec wipe: %v", err)
146 }
147 t.Cleanup(func() {
148 closeDB(t, db)
149 })
150 if name != "" {
151 populate(t, db, name)
152 }
153 return db
154 }
155
156 func populate(t testing.TB, db *DB, name string) {
157 t.Helper()
158 switch name {
159 case "people":
160 exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
161 exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
162 exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
163 exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
164 case "magicquery":
165
166 exec(t, db, "CREATE|magicquery|op=string,millis=int32")
167 exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
168 case "tx_status":
169
170 exec(t, db, "CREATE|tx_status|tx_status=string")
171 exec(t, db, "INSERT|tx_status|tx_status=invalid")
172 default:
173 t.Fatalf("unknown database name %q", name)
174 }
175 }
176
177
178 type basicConn struct {
179 driver.Conn
180 }
181
182 func (c *basicConn) getFakeConn() *fakeConn {
183 return c.Conn.(*fakeConn)
184 }
185
186 func (c *basicConn) Prepare(query string) (driver.Stmt, error) {
187 stmt, err := c.Conn.(*fakeConn).PrepareContext(context.Background(), query)
188 if err != nil {
189 return nil, err
190 }
191 return &basicStmt{fc: c.Conn.(*fakeConn), Stmt: stmt}, nil
192 }
193
194 func (c *basicConn) Begin() (driver.Tx, error) {
195 return c.Conn.(*fakeConn).Begin()
196 }
197
198 func (c *basicConn) Close() error {
199 return c.Conn.(*fakeConn).Close()
200 }
201
202 type basicStmt struct {
203 fc *fakeConn
204 driver.Stmt
205 }
206
207 func (s *basicStmt) valuesToNamedValues(args []driver.Value) ([]driver.NamedValue, error) {
208 nv := make([]driver.NamedValue, len(args))
209 for i, arg := range args {
210 val, err := s.Stmt.(*fakeStmt).ColumnConverter(i).ConvertValue(arg)
211 if err != nil {
212 return nil, fmt.Errorf("sql: converting argument $%v type: %w", i+1, err)
213 }
214 nv[i] = driver.NamedValue{
215 Ordinal: i + 1,
216 Value: val,
217 }
218 }
219 return nv, nil
220 }
221
222 func (s *basicStmt) Exec(args []driver.Value) (driver.Result, error) {
223 nvs, err := s.valuesToNamedValues(args)
224 if err != nil {
225 return nil, err
226 }
227 return s.Stmt.(*fakeStmt).ExecContext(context.Background(), nvs)
228 }
229
230 func (s *basicStmt) Query(args []driver.Value) (driver.Rows, error) {
231 nvs, err := s.valuesToNamedValues(args)
232 if err != nil {
233 return nil, err
234 }
235 return s.Stmt.(*fakeStmt).QueryContext(context.Background(), nvs)
236 }
237
238 type basicConnector struct {
239 fakeConnector
240 }
241
242 func (c *basicConnector) Connect(ctx context.Context) (driver.Conn, error) {
243 conn, err := c.fakeConnector.Connect(ctx)
244 if err != nil {
245 return nil, err
246 }
247 fc := getFakeConn(conn)
248 fc.skipDirtySession = true
249 return &basicConn{fc}, nil
250 }
251
252 func TestOpenDB(t *testing.T) {
253 synctest.Test(t, testOpenDB)
254 }
255 func testOpenDB(t *testing.T) {
256 db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver})
257 defer db.Close()
258 if db.Driver() != fdriver {
259 t.Fatalf("OpenDB should return the driver of the Connector")
260 }
261 }
262
263 func TestDriverPanic(t *testing.T) {
264 synctest.Test(t, testDriverPanic)
265 }
266 func testDriverPanic(t *testing.T) {
267 db, err := Open("test", fakeDBName)
268 if err != nil {
269 t.Fatalf("Open: %v", err)
270 }
271 defer db.Close()
272
273 expectPanic := func(name string, f func()) {
274 defer func() {
275 err := recover()
276 if err == nil {
277 t.Fatalf("%s did not panic", name)
278 }
279 }()
280 f()
281 }
282
283 expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
284 exec(t, db, "WIPE")
285 expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
286 exec(t, db, "WIPE")
287 expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
288 exec(t, db, "WIPE")
289 exec(t, db, "PANIC|Query|WIPE")
290 exec(t, db, "WIPE")
291
292 exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
293
294 expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
295 expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
296 expectPanic("Query Close", func() {
297 rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
298 if err != nil {
299 t.Fatal(err)
300 }
301 rows.Close()
302 })
303 db.Query("PANIC|Exec|SELECT|people|age,name|")
304 exec(t, db, "WIPE")
305 }
306
307 func exec(t testing.TB, db *DB, query string, args ...any) {
308 t.Helper()
309 _, err := db.Exec(query, args...)
310 if err != nil {
311 t.Fatalf("Exec of %q: %v", query, err)
312 }
313 }
314
// closeDB closes db at the end of a test and verifies that it shut down cleanly.
//
// It reports, via t.Errorf, errors passed to the post-close-conn hook and idle
// connections that still have open statements. It fails the test if db.Close
// returns an error or if any connection is still open afterwards.
//
// The recover below has an effect only when closeDB is itself the deferred call
// (defer closeDB(t, db)). In that case it prints the test body's panic and
// re-panics. When closeDB is reached any other way, for example from a
// t.Cleanup closure, recover returns nil and the check does nothing.
func closeDB(t testing.TB, db *DB) {
	t.Helper()

	if e := recover(); e != nil {
		fmt.Printf("Panic: %v\n", e)
		panic(e)
	}
	// For the duration of this call, report any error handed to the
	// post-close-conn hook (presumably invoked after each fakeConn closes).
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v", err)
		}
	})
	// Idle connections should not be holding prepared statements any more.
	db.mu.Lock()
	for i, dc := range db.freeConn {
		if n := len(dc.openStmt); n > 0 {
			t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
		}
	}
	db.mu.Unlock()

	err := db.Close()
	if err != nil {
		t.Fatalf("error closing DB: %v", err)
	}

	// Connections may finish closing asynchronously, so wait before counting.
	numOpenConns := func() int {
		db.mu.Lock()
		defer db.mu.Unlock()
		return db.numOpen
	}
	if _, ok := t.(*testing.B); ok {
		// Benchmarks poll for up to 5s of real time.
		deadline := time.Now().Add(5 * time.Second)
		for numOpenConns() > 0 && time.Now().Before(deadline) {
			time.Sleep(1 * time.Millisecond)
		}
	} else {
		// Tests are assumed to run inside a synctest bubble. synctest.Wait
		// blocks until every other goroutine in the bubble is durably blocked,
		// and it panics when called outside a bubble.
		synctest.Wait()
	}
	if numOpen := numOpenConns(); numOpen != 0 {
		t.Fatalf("%d connections still open after closing DB", numOpen)
	}
}
365
366
367
368 func numPrepares(t *testing.T, db *DB) int {
369 if n := len(db.freeConn); n != 1 {
370 t.Fatalf("free conns = %d; want 1", n)
371 }
372 return getFakeConn(db.freeConn[0].ci).numPrepare
373 }
374
375 func (db *DB) numDeps() int {
376 synctest.Wait()
377 db.mu.Lock()
378 defer db.mu.Unlock()
379 return len(db.dep)
380 }
381
382 func (db *DB) numFreeConns() int {
383 synctest.Wait()
384 db.mu.Lock()
385 defer db.mu.Unlock()
386 return len(db.freeConn)
387 }
388
389
390 func (db *DB) clearAllConns(t *testing.T) {
391 db.SetMaxIdleConns(0)
392
393 if g, w := db.numFreeConns(), 0; g != w {
394 t.Errorf("free conns = %d; want %d", g, w)
395 }
396
397 if n := db.numDeps(); n > 0 {
398 t.Errorf("number of dependencies = %d; expected 0", n)
399 db.dumpDeps(t)
400 }
401 }
402
403 func (db *DB) dumpDeps(t *testing.T) {
404 for fc := range db.dep {
405 db.dumpDep(t, 0, fc, map[finalCloser]bool{})
406 }
407 }
408
409 func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) {
410 seen[dep] = true
411 indent := strings.Repeat(" ", depth)
412 ds := db.dep[dep]
413 for k := range ds {
414 t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k)
415 if fc, ok := k.(finalCloser); ok {
416 if !seen[fc] {
417 db.dumpDep(t, depth+1, fc, seen)
418 }
419 }
420 }
421 }
422
423 func TestQuery(t *testing.T) {
424 testDatabase(t, testQuery)
425 }
426 func testQuery(t *testing.T, db *DB) {
427 populate(t, db, "people")
428 prepares0 := numPrepares(t, db)
429 rows, err := db.Query("SELECT|people|age,name|")
430 if err != nil {
431 t.Fatalf("Query: %v", err)
432 }
433 defer rows.Close()
434 type row struct {
435 age int
436 name string
437 }
438 got := []row{}
439 for rows.Next() {
440 var r row
441 err = rows.Scan(&r.age, &r.name)
442 if err != nil {
443 t.Fatalf("Scan: %v", err)
444 }
445 got = append(got, r)
446 }
447 err = rows.Err()
448 if err != nil {
449 t.Fatalf("Err: %v", err)
450 }
451 want := []row{
452 {age: 1, name: "Alice"},
453 {age: 2, name: "Bob"},
454 {age: 3, name: "Chris"},
455 }
456 if !slices.Equal(got, want) {
457 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
458 }
459
460
461
462 if n := db.numFreeConns(); n != 1 {
463 t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
464 }
465 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
466 t.Errorf("executed %d Prepare statements; want 1", prepares)
467 }
468 }
469
470
// TestQueryContext checks the effect of cancelling a query's context partway
// through iteration. The Rows must be closed, the row that was about to be
// scanned must fail to scan, and the connection must return to the pool.
func TestQueryContext(t *testing.T) {
	testDatabase(t, testQueryContext)
}
func testQueryContext(t *testing.T, db *DB) {
	populate(t, db, "people")
	prepares0 := numPrepares(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
	if err != nil {
		t.Fatalf("Query: %v", err)
	}
	type row struct {
		age  int
		name string
	}
	got := []row{}
	index := 0
	for rows.Next() {
		// After two rows have been read, cancel and wait for the Rows to be
		// closed as a result.
		if index == 2 {
			cancel()
			waitForRowsClose(t, rows)
		}
		var r row
		err = rows.Scan(&r.age, &r.name)
		if err != nil {
			// Scanning the third row is expected to fail once Rows are closed.
			if index == 2 {
				break
			}
			t.Fatalf("Scan: %v", err)
		}
		// Reached only when Scan succeeded, so at index 2 this fails the test
		// because Scan worked after cancellation.
		if index == 2 && err != context.Canceled {
			t.Fatalf("Scan: %v; want context.Canceled", err)
		}
		got = append(got, r)
		index++
	}
	select {
	case <-ctx.Done():
		if err := ctx.Err(); err != context.Canceled {
			t.Fatalf("context err = %v; want context.Canceled", err)
		}
	default:
		t.Fatalf("context err = nil; want context.Canceled")
	}
	want := []row{
		{age: 1, name: "Alice"},
		{age: 2, name: "Bob"},
	}
	if !slices.Equal(got, want) {
		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
	}

	// The cancelled Rows must stay closed, and their connection must be back
	// in the pool with the statement prepared only once.
	waitForRowsClose(t, rows)
	waitForFree(t, db, 1)
	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
		t.Errorf("executed %d Prepare statements; want 1", prepares)
	}
}
534
535
536
537 func waitForFree(t *testing.T, db *DB, want int) {
538 synctest.Wait()
539 numFree := db.numFreeConns()
540 if numFree != want {
541 t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want)
542 }
543 }
544
545 func waitForRowsClose(t *testing.T, rows *Rows) {
546 synctest.Wait()
547 rows.closemu.RLock()
548 defer rows.closemu.RUnlock()
549 if !rows.closed {
550 t.Fatal("failed to close rows")
551 }
552 }
553
554
555
556 func TestQueryContextWait(t *testing.T) {
557 testDatabase(t, testQueryContextWait)
558 }
559 func testQueryContextWait(t *testing.T, db *DB) {
560 populate(t, db, "people")
561 prepares0 := numPrepares(t, db)
562
563 ctx, cancel := context.WithCancel(context.Background())
564 defer cancel()
565
566
567
568
569 c, err := db.Conn(ctx)
570 if err != nil {
571 t.Fatal(err)
572 }
573
574 getFakeConn(c.dc.ci).waiter = func(c context.Context) {
575 cancel()
576 <-ctx.Done()
577 }
578 _, err = c.QueryContext(ctx, "SELECT|people|age,name|")
579 c.Close()
580 if err != context.Canceled {
581 t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
582 }
583
584
585 waitForFree(t, db, 1)
586 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
587 t.Fatalf("executed %d Prepare statements; want 1", prepares)
588 }
589 }
590
591
592
593 func TestTxContextWait(t *testing.T) {
594 testDatabase(t, func(t *testing.T, db *DB) {
595 testContextWait(t, false, db)
596 })
597 }
598
599
600
601 func TestTxContextWaitNoDiscard(t *testing.T) {
602 testDatabase(t, func(t *testing.T, db *DB) {
603 testContextWait(t, true, db)
604 })
605 }
606
// testContextWait begins a transaction and installs a waiter hook on its
// connection that cancels the transaction's context. It then runs a query and
// expects context.Canceled.
//
// After the implicit rollback, one idle connection is expected when
// keepConnOnRollback is true and none when it is false.
func testContextWait(t *testing.T, keepConnOnRollback bool, db *DB) {
	populate(t, db, "people")

	// cancel is not deferred: the waiter below is what calls it.
	// NOTE(review): if BeginTx fails, ctx is never cancelled.
	ctx, cancel := context.WithCancel(context.Background())

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}
	tx.keepConnOnRollback = keepConnOnRollback

	// NOTE(review): presumably invoked by the fake driver while the query
	// runs. It cancels ctx and blocks until cancellation is visible.
	getFakeConn(tx.dc.ci).waiter = func(c context.Context) {
		cancel()
		<-ctx.Done()
	}

	_, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
	if err != context.Canceled {
		t.Fatalf("expected QueryContext to error with context canceled but returned %v", err)
	}

	// keepConnOnRollback decides whether the connection goes back to the pool
	// or is discarded.
	if keepConnOnRollback {
		waitForFree(t, db, 1)
	} else {
		waitForFree(t, db, 0)
	}
}
636
637
638
639
640 func TestUnsupportedOptions(t *testing.T) {
641 testDatabase(t, testUnsupportedOptions)
642 }
643 func testUnsupportedOptions(t *testing.T, db *DB) {
644 populate(t, db, "people")
645 _, err := db.BeginTx(context.Background(), &TxOptions{
646 Isolation: LevelSerializable, ReadOnly: true,
647 })
648 if err == nil {
649 t.Fatal("expected error when using unsupported options, got nil")
650 }
651 }
652
653 func TestMultiResultSetQuery(t *testing.T) {
654 testDatabase(t, testMultiResultSetQuery)
655 }
656 func testMultiResultSetQuery(t *testing.T, db *DB) {
657 populate(t, db, "people")
658 prepares0 := numPrepares(t, db)
659 rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
660 if err != nil {
661 t.Fatalf("Query: %v", err)
662 }
663 type row1 struct {
664 age int
665 name string
666 }
667 type row2 struct {
668 name string
669 }
670 got1 := []row1{}
671 for rows.Next() {
672 var r row1
673 err = rows.Scan(&r.age, &r.name)
674 if err != nil {
675 t.Fatalf("Scan: %v", err)
676 }
677 got1 = append(got1, r)
678 }
679 err = rows.Err()
680 if err != nil {
681 t.Fatalf("Err: %v", err)
682 }
683 want1 := []row1{
684 {age: 1, name: "Alice"},
685 {age: 2, name: "Bob"},
686 {age: 3, name: "Chris"},
687 }
688 if !slices.Equal(got1, want1) {
689 t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
690 }
691
692 if !rows.NextResultSet() {
693 t.Errorf("expected another result set")
694 }
695
696 got2 := []row2{}
697 for rows.Next() {
698 var r row2
699 err = rows.Scan(&r.name)
700 if err != nil {
701 t.Fatalf("Scan: %v", err)
702 }
703 got2 = append(got2, r)
704 }
705 err = rows.Err()
706 if err != nil {
707 t.Fatalf("Err: %v", err)
708 }
709 want2 := []row2{
710 {name: "Alice"},
711 {name: "Bob"},
712 {name: "Chris"},
713 }
714 if !slices.Equal(got2, want2) {
715 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
716 }
717 if rows.NextResultSet() {
718 t.Errorf("expected no more result sets")
719 }
720
721
722
723 waitForFree(t, db, 1)
724 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
725 t.Errorf("executed %d Prepare statements; want 1", prepares)
726 }
727 }
728
729 func TestQueryNamedArg(t *testing.T) {
730 testDatabase(t, testQueryNamedArg, requireFeature("NamedValue"))
731 }
732 func testQueryNamedArg(t *testing.T, db *DB) {
733 populate(t, db, "people")
734 prepares0 := numPrepares(t, db)
735 rows, err := db.Query(
736
737 "SELECT|people|age,name|name=?name,age=?age",
738 Named("age", 2),
739 Named("name", "Bob"),
740 )
741 if err != nil {
742 t.Fatalf("Query: %v", err)
743 }
744 type row struct {
745 age int
746 name string
747 }
748 got := []row{}
749 for rows.Next() {
750 var r row
751 err = rows.Scan(&r.age, &r.name)
752 if err != nil {
753 t.Fatalf("Scan: %v", err)
754 }
755 got = append(got, r)
756 }
757 err = rows.Err()
758 if err != nil {
759 t.Fatalf("Err: %v", err)
760 }
761 want := []row{
762 {age: 2, name: "Bob"},
763 }
764 if !slices.Equal(got, want) {
765 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
766 }
767
768
769
770 if n := db.numFreeConns(); n != 1 {
771 t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
772 }
773 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
774 t.Errorf("executed %d Prepare statements; want 1", prepares)
775 }
776 }
777
// TestPoolExhaustOnCancel checks what happens when every allowed connection is
// busy. Requests that are waiting for a connection and get cancelled must fail
// with context.Canceled without leaking pool slots. Once the busy queries
// finish, a normal request must succeed.
func TestPoolExhaustOnCancel(t *testing.T) {
	synctest.Test(t, testPoolExhaustOnCancel)
}
func testPoolExhaustOnCancel(t *testing.T) {
	// NOTE(review): max shadows the builtin of the same name in this function.
	max := 3
	// saturate is released once all max queries are parked inside the driver.
	// saturateDone is released once they have all finished.
	var saturate, saturateDone sync.WaitGroup
	saturate.Add(max)
	saturateDone.Add(max)

	donePing := make(chan bool)
	state := 0

	// waiter is installed on every fake connection, and its behavior depends
	// on state:
	//   0: setup (populate); no effect.
	//   1: saturation; signal saturate, then hold the connection until the
	//      query's context is done or donePing is closed.
	//   2: exhaustion checks; no effect.
	// NOTE(review): state is a plain variable shared with other goroutines.
	// Ordering relies on the go statements and saturate.Wait; confirm with
	// -race.
	waiter := func(ctx context.Context) {
		switch state {
		case 0:
			// Setup: nothing to do.
		case 1:
			saturate.Done()
			select {
			case <-ctx.Done():
			case <-donePing:
			}
		case 2:
		}
	}
	db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")

	db.SetMaxOpenConns(max)

	// Start max queries that each occupy one connection inside the waiter.
	state = 1
	for i := 0; i < max; i++ {
		go func() {
			rows, err := db.Query("SELECT|people|name,photo|")
			if err != nil {
				t.Errorf("Query: %v", err)
				return
			}
			rows.Close()
			saturateDone.Done()
		}()
	}

	saturate.Wait()
	if t.Failed() {
		t.FailNow()
	}
	state = 2

	// With the pool exhausted, each ping waits for a connection until its own
	// context is cancelled 100ms later (synctest's fake clock inside the
	// bubble).
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < max; i++ {
		ctxReq, cancelReq := context.WithCancel(ctx)
		go func() {
			time.Sleep(100 * time.Millisecond)
			cancelReq()
		}()
		err := db.PingContext(ctxReq)
		if err != context.Canceled {
			t.Fatalf("PingContext (Exhaust): %v", err)
		}
	}
	// Release the parked queries and wait for them to finish.
	close(donePing)
	saturateDone.Wait()

	// The cancelled waits must not have consumed pool capacity.
	err := db.PingContext(ctx)
	if err != nil {
		t.Fatalf("PingContext (Normal): %v", err)
	}
}
859
860 func TestRowsColumns(t *testing.T) {
861 testDatabase(t, testRowsColumns)
862 }
863 func testRowsColumns(t *testing.T, db *DB) {
864 populate(t, db, "people")
865 rows, err := db.Query("SELECT|people|age,name|")
866 if err != nil {
867 t.Fatalf("Query: %v", err)
868 }
869 cols, err := rows.Columns()
870 if err != nil {
871 t.Fatalf("Columns: %v", err)
872 }
873 want := []string{"age", "name"}
874 if !slices.Equal(cols, want) {
875 t.Errorf("got %#v; want %#v", cols, want)
876 }
877 if err := rows.Close(); err != nil {
878 t.Errorf("error closing rows: %s", err)
879 }
880 }
881
882 func TestRowsColumnTypes(t *testing.T) {
883 testDatabase(t, testRowsColumnTypes)
884 }
885 func testRowsColumnTypes(t *testing.T, db *DB) {
886 populate(t, db, "people")
887 rows, err := db.Query("SELECT|people|age,name|")
888 if err != nil {
889 t.Fatalf("Query: %v", err)
890 }
891 tt, err := rows.ColumnTypes()
892 if err != nil {
893 t.Fatalf("ColumnTypes: %v", err)
894 }
895
896 types := make([]reflect.Type, len(tt))
897 for i, tp := range tt {
898 st := tp.ScanType()
899 if st == nil {
900 t.Errorf("scantype is null for column %q", tp.Name())
901 continue
902 }
903 types[i] = st
904 }
905 values := make([]any, len(tt))
906 for i := range values {
907 values[i] = reflect.New(types[i]).Interface()
908 }
909 ct := 0
910 for rows.Next() {
911 err = rows.Scan(values...)
912 if err != nil {
913 t.Fatalf("failed to scan values in %v", err)
914 }
915 if ct == 1 {
916 if age := *values[0].(*int32); age != 2 {
917 t.Errorf("Expected 2, got %v", age)
918 }
919 if name := *values[1].(*string); name != "Bob" {
920 t.Errorf("Expected Bob, got %v", name)
921 }
922 }
923 ct++
924 }
925 if ct != 3 {
926 t.Errorf("expected 3 rows, got %d", ct)
927 }
928
929 if err := rows.Close(); err != nil {
930 t.Errorf("error closing rows: %s", err)
931 }
932 }
933
934 func TestQueryRow(t *testing.T) {
935 testDatabase(t, testQueryRow)
936 }
937 func testQueryRow(t *testing.T, db *DB) {
938 populate(t, db, "people")
939 var name string
940 var age int
941 var birthday time.Time
942
943 err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
944 if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
945 t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
946 }
947
948 err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
949 if err != nil || !birthday.Equal(chrisBirthday) {
950 t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
951 }
952
953 err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
954 if err != nil {
955 t.Fatalf("age QueryRow+Scan: %v", err)
956 }
957 if name != "Bob" {
958 t.Errorf("expected name Bob, got %q", name)
959 }
960 if age != 2 {
961 t.Errorf("expected age 2, got %d", age)
962 }
963
964 err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name)
965 if err != nil {
966 t.Fatalf("name QueryRow+Scan: %v", err)
967 }
968 if name != "Alice" {
969 t.Errorf("expected name Alice, got %q", name)
970 }
971 if age != 1 {
972 t.Errorf("expected age 1, got %d", age)
973 }
974
975 var photo []byte
976 err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
977 if err != nil {
978 t.Fatalf("photo QueryRow+Scan: %v", err)
979 }
980 want := []byte("APHOTO")
981 if !slices.Equal(photo, want) {
982 t.Errorf("photo = %q; want %q", photo, want)
983 }
984 }
985
986 func TestRowErr(t *testing.T) {
987 testDatabase(t, testRowErr)
988 }
989 func testRowErr(t *testing.T, db *DB) {
990 populate(t, db, "people")
991
992 row := db.QueryRowContext(context.Background(), "SELECT|people|bdate|age=?", 3)
993 if err := row.Err(); err != nil {
994 t.Errorf("Unexpected err = %v; want %v", err, nil)
995 }
996 row.Scan()
997
998 ctx, cancel := context.WithCancel(context.Background())
999 cancel()
1000
1001 err := db.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 3).Err()
1002 exp := "context canceled"
1003 if err == nil || !strings.Contains(err.Error(), exp) {
1004 t.Errorf("Expected err = %v; got %v", exp, err)
1005 }
1006 }
1007
1008 func TestTxRollbackCommitErr(t *testing.T) {
1009 testDatabase(t, testTxRollbackCommitErr)
1010 }
1011 func testTxRollbackCommitErr(t *testing.T, db *DB) {
1012 populate(t, db, "people")
1013
1014 tx, err := db.Begin()
1015 if err != nil {
1016 t.Fatal(err)
1017 }
1018 err = tx.Rollback()
1019 if err != nil {
1020 t.Errorf("expected nil error from Rollback; got %v", err)
1021 }
1022 err = tx.Commit()
1023 if err != ErrTxDone {
1024 t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
1025 }
1026
1027 tx, err = db.Begin()
1028 if err != nil {
1029 t.Fatal(err)
1030 }
1031 err = tx.Commit()
1032 if err != nil {
1033 t.Errorf("expected nil error from Commit; got %v", err)
1034 }
1035 err = tx.Rollback()
1036 if err != ErrTxDone {
1037 t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
1038 }
1039 }
1040
1041 func TestStatementErrorAfterClose(t *testing.T) {
1042 testDatabase(t, testStatementErrorAfterClose)
1043 }
1044 func testStatementErrorAfterClose(t *testing.T, db *DB) {
1045 populate(t, db, "people")
1046 stmt, err := db.Prepare("SELECT|people|age|name=?")
1047 if err != nil {
1048 t.Fatalf("Prepare: %v", err)
1049 }
1050 err = stmt.Close()
1051 if err != nil {
1052 t.Fatalf("Close: %v", err)
1053 }
1054 var name string
1055 err = stmt.QueryRow("foo").Scan(&name)
1056 if err == nil {
1057 t.Errorf("expected error from QueryRow.Scan after Stmt.Close")
1058 }
1059 }
1060
1061 func TestStatementQueryRow(t *testing.T) {
1062 testDatabase(t, testStatementQueryRow)
1063 }
1064 func testStatementQueryRow(t *testing.T, db *DB) {
1065 populate(t, db, "people")
1066 stmt, err := db.Prepare("SELECT|people|age|name=?")
1067 if err != nil {
1068 t.Fatalf("Prepare: %v", err)
1069 }
1070 defer stmt.Close()
1071 var age int
1072 for n, tt := range []struct {
1073 name string
1074 want int
1075 }{
1076 {"Alice", 1},
1077 {"Bob", 2},
1078 {"Chris", 3},
1079 } {
1080 if err := stmt.QueryRow(tt.name).Scan(&age); err != nil {
1081 t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err)
1082 } else if age != tt.want {
1083 t.Errorf("%d: age=%d, want %d", n, age, tt.want)
1084 }
1085 }
1086 }
1087
// stubDriverStmt is a minimal driver.Stmt used to test error propagation.
// Close returns err, NumInput reports -1, and Exec and Query return nil
// results with nil errors.
type stubDriverStmt struct {
	err error
}

// Close returns the configured error.
func (s stubDriverStmt) Close() error {
	return s.err
}

// NumInput returns -1 to indicate an unknown number of placeholders.
func (stubDriverStmt) NumInput() int {
	return -1
}

// Exec does nothing and returns a nil result and nil error.
func (stubDriverStmt) Exec(_ []driver.Value) (driver.Result, error) {
	return nil, nil
}

// Query does nothing and returns nil rows and nil error.
func (stubDriverStmt) Query(_ []driver.Value) (driver.Rows, error) {
	return nil, nil
}
1107
1108
1109 func TestStatementClose(t *testing.T) {
1110 synctest.Test(t, testStatementClose)
1111 }
1112 func testStatementClose(t *testing.T) {
1113 want := errors.New("STMT ERROR")
1114
1115 tests := []struct {
1116 stmt *Stmt
1117 msg string
1118 }{
1119 {&Stmt{stickyErr: want}, "stickyErr not propagated"},
1120 {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
1121 }
1122 for _, test := range tests {
1123 if err := test.stmt.Close(); err != want {
1124 t.Errorf("%s. Got stmt.Close() = %v, want = %v", test.msg, err, want)
1125 }
1126 }
1127 }
1128
1129
1130 func TestStatementQueryRowConcurrent(t *testing.T) {
1131 testDatabase(t, testStatementQueryRowConcurrent)
1132 }
1133 func testStatementQueryRowConcurrent(t *testing.T, db *DB) {
1134 populate(t, db, "people")
1135 stmt, err := db.Prepare("SELECT|people|age|name=?")
1136 if err != nil {
1137 t.Fatalf("Prepare: %v", err)
1138 }
1139 defer stmt.Close()
1140
1141 const n = 10
1142 ch := make(chan error, n)
1143 for i := 0; i < n; i++ {
1144 go func() {
1145 var age int
1146 err := stmt.QueryRow("Alice").Scan(&age)
1147 if err == nil && age != 1 {
1148 err = fmt.Errorf("unexpected age %d", age)
1149 }
1150 ch <- err
1151 }()
1152 }
1153 for i := 0; i < n; i++ {
1154 if err := <-ch; err != nil {
1155 t.Error(err)
1156 }
1157 }
1158 }
1159
1160
1161 func TestBogusPreboundParameters(t *testing.T) {
1162 testDatabase(t, testBogusPreboundParameters)
1163 }
1164 func testBogusPreboundParameters(t *testing.T, db *DB) {
1165 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1166 _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
1167 if err == nil {
1168 t.Fatalf("expected error")
1169 }
1170 if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` {
1171 t.Errorf("unexpected error: %v", err)
1172 }
1173 }
1174
1175 func TestExec(t *testing.T) {
1176 testDatabase(t, testExec)
1177 }
1178 func testExec(t *testing.T, db *DB) {
1179 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1180 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1181 if err != nil {
1182 t.Errorf("Stmt, err = %v, %v", stmt, err)
1183 }
1184 defer stmt.Close()
1185
1186 type execTest struct {
1187 args []any
1188 wantErr string
1189 }
1190 execTests := []execTest{
1191
1192 {[]any{"Brad", 31}, ""},
1193 {[]any{"Brad", int64(31)}, ""},
1194 {[]any{"Bob", "32"}, ""},
1195 {[]any{7, 9}, ""},
1196
1197
1198 {[]any{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
1199 {[]any{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
1200
1201
1202 {[]any{}, "sql: expected 2 arguments, got 0"},
1203 {[]any{1, 2, 3}, "sql: expected 2 arguments, got 3"},
1204 }
1205 for n, et := range execTests {
1206 _, err := stmt.Exec(et.args...)
1207 errStr := ""
1208 if err != nil {
1209 errStr = err.Error()
1210 }
1211 if errStr != et.wantErr {
1212 t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q",
1213 n, et.args, errStr, et.wantErr)
1214 }
1215 }
1216 }
1217
1218 func TestTxPrepare(t *testing.T) {
1219 testDatabase(t, testTxPrepare)
1220 }
1221 func testTxPrepare(t *testing.T, db *DB) {
1222 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1223 tx, err := db.Begin()
1224 if err != nil {
1225 t.Fatalf("Begin = %v", err)
1226 }
1227 stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
1228 if err != nil {
1229 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1230 }
1231 defer stmt.Close()
1232 _, err = stmt.Exec("Bobby", 7)
1233 if err != nil {
1234 t.Fatalf("Exec = %v", err)
1235 }
1236 err = tx.Commit()
1237 if err != nil {
1238 t.Fatalf("Commit = %v", err)
1239 }
1240
1241 if !stmt.closed {
1242 t.Fatal("Stmt not closed after Commit")
1243 }
1244 }
1245
1246 func TestTxStmt(t *testing.T) {
1247 testDatabase(t, testTxStmt)
1248 }
1249 func testTxStmt(t *testing.T, db *DB) {
1250 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1251 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1252 if err != nil {
1253 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1254 }
1255 defer stmt.Close()
1256 tx, err := db.Begin()
1257 if err != nil {
1258 t.Fatalf("Begin = %v", err)
1259 }
1260 txs := tx.Stmt(stmt)
1261 defer txs.Close()
1262 _, err = txs.Exec("Bobby", 7)
1263 if err != nil {
1264 t.Fatalf("Exec = %v", err)
1265 }
1266 err = tx.Commit()
1267 if err != nil {
1268 t.Fatalf("Commit = %v", err)
1269 }
1270
1271 if !txs.closed {
1272 t.Fatal("Stmt not closed after Commit")
1273 }
1274 }
1275
1276 func TestTxStmtPreparedOnce(t *testing.T) {
1277 testDatabase(t, testTxStmtPreparedOnce)
1278 }
1279 func testTxStmtPreparedOnce(t *testing.T, db *DB) {
1280 exec(t, db, "CREATE|t1|name=string,age=int32")
1281
1282 prepares0 := numPrepares(t, db)
1283
1284
1285 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1286 if err != nil {
1287 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1288 }
1289 defer stmt.Close()
1290
1291 tx, err := db.Begin()
1292 if err != nil {
1293 t.Fatalf("Begin = %v", err)
1294 }
1295
1296 txs1 := tx.Stmt(stmt)
1297 txs2 := tx.Stmt(stmt)
1298
1299 _, err = txs1.Exec("Go", 7)
1300 if err != nil {
1301 t.Fatalf("Exec = %v", err)
1302 }
1303 txs1.Close()
1304
1305 _, err = txs2.Exec("Gopher", 8)
1306 if err != nil {
1307 t.Fatalf("Exec = %v", err)
1308 }
1309 txs2.Close()
1310
1311 err = tx.Commit()
1312 if err != nil {
1313 t.Fatalf("Commit = %v", err)
1314 }
1315
1316 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
1317 t.Errorf("executed %d Prepare statements; want 1", prepares)
1318 }
1319 }
1320
1321 func TestTxStmtClosedRePrepares(t *testing.T) {
1322 testDatabase(t, testTxStmtClosedRePrepares)
1323 }
1324 func testTxStmtClosedRePrepares(t *testing.T, db *DB) {
1325 exec(t, db, "CREATE|t1|name=string,age=int32")
1326
1327 prepares0 := numPrepares(t, db)
1328
1329
1330 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1331 if err != nil {
1332 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1333 }
1334 tx, err := db.Begin()
1335 if err != nil {
1336 t.Fatalf("Begin = %v", err)
1337 }
1338 err = stmt.Close()
1339 if err != nil {
1340 t.Fatalf("stmt.Close() = %v", err)
1341 }
1342
1343 txs := tx.Stmt(stmt)
1344 if txs.stickyErr != nil {
1345 t.Fatal(txs.stickyErr)
1346 }
1347 if txs.parentStmt != nil {
1348 t.Fatal("expected nil parentStmt")
1349 }
1350 _, err = txs.Exec(`Eric`, 82)
1351 if err != nil {
1352 t.Fatalf("txs.Exec = %v", err)
1353 }
1354
1355 err = txs.Close()
1356 if err != nil {
1357 t.Fatalf("txs.Close = %v", err)
1358 }
1359
1360 tx.Rollback()
1361
1362 if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
1363 t.Errorf("executed %d Prepare statements; want 2", prepares)
1364 }
1365 }
1366
1367 func TestParentStmtOutlivesTxStmt(t *testing.T) {
1368 testDatabase(t, testParentStmtOutlivesTxStmt)
1369 }
1370 func testParentStmtOutlivesTxStmt(t *testing.T, db *DB) {
1371 exec(t, db, "CREATE|t1|name=string,age=int32")
1372
1373
1374 db.SetMaxOpenConns(1)
1375
1376 prepares0 := numPrepares(t, db)
1377
1378
1379 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1380 if err != nil {
1381 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1382 }
1383 defer stmt.Close()
1384 tx, err := db.Begin()
1385 if err != nil {
1386 t.Fatalf("Begin = %v", err)
1387 }
1388 txs := tx.Stmt(stmt)
1389 if len(stmt.css) != 1 {
1390 t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
1391 }
1392 err = txs.Close()
1393 if err != nil {
1394 t.Fatalf("txs.Close() = %v", err)
1395 }
1396 err = tx.Rollback()
1397 if err != nil {
1398 t.Fatalf("tx.Rollback() = %v", err)
1399 }
1400
1401 _, err = txs.Exec("Suzan", 30)
1402 if err == nil {
1403 t.Fatalf("txs.Exec(), expected err")
1404 }
1405
1406 _, err = stmt.Exec("Janina", 25)
1407 if err != nil {
1408 t.Fatalf("stmt.Exec() = %v", err)
1409 }
1410
1411 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
1412 t.Errorf("executed %d Prepare statements; want 1", prepares)
1413 }
1414 }
1415
1416
1417
1418
1419 func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
1420 testDatabase(t, testTxStmtFromTxStmtRePrepares)
1421 }
1422 func testTxStmtFromTxStmtRePrepares(t *testing.T, db *DB) {
1423 exec(t, db, "CREATE|t1|name=string,age=int32")
1424 prepares0 := numPrepares(t, db)
1425
1426 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1427 if err != nil {
1428 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1429 }
1430 defer stmt.Close()
1431
1432 tx, err := db.Begin()
1433 if err != nil {
1434 t.Fatalf("Begin = %v", err)
1435 }
1436 txs1 := tx.Stmt(stmt)
1437
1438
1439
1440 txs2 := tx.Stmt(txs1)
1441 if txs2.stickyErr != nil {
1442 t.Fatal(txs2.stickyErr)
1443 }
1444 if txs2.parentStmt != nil {
1445 t.Fatal("expected nil parentStmt")
1446 }
1447 _, err = txs2.Exec(`Eric`, 82)
1448 if err != nil {
1449 t.Fatal(err)
1450 }
1451
1452 err = txs1.Close()
1453 if err != nil {
1454 t.Fatalf("txs1.Close = %v", err)
1455 }
1456 err = txs2.Close()
1457 if err != nil {
1458 t.Fatalf("txs1.Close = %v", err)
1459 }
1460 err = tx.Rollback()
1461 if err != nil {
1462 t.Fatalf("tx.Rollback = %v", err)
1463 }
1464
1465 if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
1466 t.Errorf("executed %d Prepare statements; want 2", prepares)
1467 }
1468 }
1469
1470
1471
1472
1473 func TestTxQuery(t *testing.T) {
1474 testDatabase(t, testTxQuery)
1475 }
1476 func testTxQuery(t *testing.T, db *DB) {
1477 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1478 exec(t, db, "INSERT|t1|name=Alice")
1479
1480 tx, err := db.Begin()
1481 if err != nil {
1482 t.Fatal(err)
1483 }
1484 defer tx.Rollback()
1485
1486 r, err := tx.Query("SELECT|t1|name|")
1487 if err != nil {
1488 t.Fatal(err)
1489 }
1490 defer r.Close()
1491
1492 if !r.Next() {
1493 if r.Err() != nil {
1494 t.Fatal(r.Err())
1495 }
1496 t.Fatal("expected one row")
1497 }
1498
1499 var x string
1500 err = r.Scan(&x)
1501 if err != nil {
1502 t.Fatal(err)
1503 }
1504 }
1505
1506 func TestTxQueryInvalid(t *testing.T) {
1507 testDatabase(t, testTxQueryInvalid)
1508 }
1509 func testTxQueryInvalid(t *testing.T, db *DB) {
1510 defer closeDB(t, db)
1511
1512 tx, err := db.Begin()
1513 if err != nil {
1514 t.Fatal(err)
1515 }
1516 defer tx.Rollback()
1517
1518 _, err = tx.Query("SELECT|t1|name|")
1519 if err == nil {
1520 t.Fatal("Error expected")
1521 }
1522 }
1523
1524
1525
1526 func TestTxErrBadConn(t *testing.T) {
1527 synctest.Test(t, testTxErrBadConn)
1528 }
1529 func testTxErrBadConn(t *testing.T) {
1530 db, err := Open("test", fakeDBName+";badConn")
1531 if err != nil {
1532 t.Fatalf("Open: %v", err)
1533 }
1534 if _, err := db.Exec("WIPE"); err != nil {
1535 t.Fatalf("exec wipe: %v", err)
1536 }
1537 defer closeDB(t, db)
1538 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1539 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1540 if err != nil {
1541 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1542 }
1543 defer stmt.Close()
1544 tx, err := db.Begin()
1545 if err != nil {
1546 t.Fatalf("Begin = %v", err)
1547 }
1548 txs := tx.Stmt(stmt)
1549 defer txs.Close()
1550 _, err = txs.Exec("Bobby", 7)
1551 if err != nil {
1552 t.Fatalf("Exec = %v", err)
1553 }
1554 err = tx.Commit()
1555 if err != nil {
1556 t.Fatalf("Commit = %v", err)
1557 }
1558 }
1559
1560 func TestConnQuery(t *testing.T) {
1561 testDatabase(t, testConnQuery)
1562 }
1563 func testConnQuery(t *testing.T, db *DB) {
1564 populate(t, db, "people")
1565
1566 ctx := t.Context()
1567 conn, err := db.Conn(ctx)
1568 if err != nil {
1569 t.Fatal(err)
1570 }
1571 getFakeConn(conn.dc.ci).skipDirtySession = true
1572 defer conn.Close()
1573
1574 var name string
1575 err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name)
1576 if err != nil {
1577 t.Fatal(err)
1578 }
1579 if name != "Chris" {
1580 t.Fatalf("unexpected result, got %q want Chris", name)
1581 }
1582
1583 err = conn.PingContext(ctx)
1584 if err != nil {
1585 t.Fatal(err)
1586 }
1587 }
1588
1589 func TestConnRaw(t *testing.T) {
1590 testDatabase(t, testConnRaw)
1591 }
1592 func testConnRaw(t *testing.T, db *DB) {
1593 populate(t, db, "people")
1594
1595 ctx := t.Context()
1596 conn, err := db.Conn(ctx)
1597 if err != nil {
1598 t.Fatal(err)
1599 }
1600 getFakeConn(conn.dc.ci).skipDirtySession = true
1601 defer conn.Close()
1602
1603 sawFunc := false
1604 err = conn.Raw(func(dc any) error {
1605 sawFunc = true
1606 _ = getFakeConn(dc.(driver.Conn))
1607 return nil
1608 })
1609 if err != nil {
1610 t.Fatal(err)
1611 }
1612 if !sawFunc {
1613 t.Fatal("Raw func not called")
1614 }
1615
1616 func() {
1617 defer func() {
1618 x := recover()
1619 if x == nil {
1620 t.Fatal("expected panic")
1621 }
1622 conn.closemu.Lock()
1623 closed := conn.dc == nil
1624 conn.closemu.Unlock()
1625 if !closed {
1626 t.Fatal("expected connection to be closed after panic")
1627 }
1628 }()
1629 err = conn.Raw(func(dc any) error {
1630 panic("Conn.Raw panic should return an error")
1631 })
1632 t.Fatal("expected panic from Raw func")
1633 }()
1634 }
1635
1636 func TestCursorFake(t *testing.T) {
1637 testDatabase(t, testCursorFake)
1638 }
1639 func testCursorFake(t *testing.T, db *DB) {
1640 populate(t, db, "people")
1641
1642 ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
1643 defer cancel()
1644
1645 exec(t, db, "CREATE|peoplecursor|list=table")
1646 exec(t, db, "INSERT|peoplecursor|list=people!name!age")
1647
1648 rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
1649 if err != nil {
1650 t.Fatal(err)
1651 }
1652 defer rows.Close()
1653
1654 if !rows.Next() {
1655 t.Fatal("no rows")
1656 }
1657 var cursor = &Rows{}
1658 err = rows.Scan(cursor)
1659 if err != nil {
1660 t.Fatal(err)
1661 }
1662 defer cursor.Close()
1663
1664 const expectedRows = 3
1665 var currentRow int64
1666
1667 var n int64
1668 var s string
1669 for cursor.Next() {
1670 currentRow++
1671 err = cursor.Scan(&s, &n)
1672 if err != nil {
1673 t.Fatal(err)
1674 }
1675 if n != currentRow {
1676 t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
1677 }
1678 }
1679 if currentRow != expectedRows {
1680 t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
1681 }
1682 }
1683
1684 func TestCursorDoubleRowsPointer(t *testing.T) {
1685 testDatabase(t, testCursorDoubleRowsPointer)
1686 }
1687 func testCursorDoubleRowsPointer(t *testing.T, db *DB) {
1688 exec(t, db, "CREATE|table1|col=string")
1689 exec(t, db, "INSERT|table1|col=value")
1690 exec(t, db, "CREATE|cursor|list=table")
1691 exec(t, db, "INSERT|cursor|list=table1!col")
1692
1693 rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
1694 if err != nil {
1695 t.Fatal(err)
1696 }
1697 defer rows.Close()
1698
1699 if !rows.Next() {
1700 t.Fatal("no rows")
1701 }
1702 var cursor *Rows
1703 if err := rows.Scan(&cursor); err != nil {
1704 t.Fatal(err)
1705 }
1706 defer cursor.Close()
1707
1708 if !cursor.Next() {
1709 t.Fatal("no child rows")
1710 }
1711 var col string
1712 if err := cursor.Scan(&col); err != nil {
1713 t.Fatal(err)
1714 }
1715 if got, want := col, "value"; got != want {
1716 t.Errorf("read col=%q, want %q", got, want)
1717 }
1718 }
1719
1720 func TestCursorNull(t *testing.T) {
1721 testDatabase(t, testCursorNull)
1722 }
1723 func testCursorNull(t *testing.T, db *DB) {
1724 exec(t, db, "CREATE|cursor|list=nulltable")
1725 exec(t, db, "INSERT|cursor|list=?", nil)
1726
1727 rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
1728 if err != nil {
1729 t.Fatal(err)
1730 }
1731 defer rows.Close()
1732
1733 if !rows.Next() {
1734 t.Fatal("no rows")
1735 }
1736
1737 var cursor *Rows
1738 if err := rows.Scan(&cursor); err != nil {
1739 t.Fatal(err)
1740 }
1741 if cursor != nil {
1742 t.Errorf("Scan returned cursor, expected nil")
1743 }
1744 }
1745
1746
1747
1748 func TestCursorCancel(t *testing.T) {
1749 for _, test := range []struct {
1750 name string
1751 cancelOn string
1752 want []string
1753 }{{
1754
1755 name: "no cancel",
1756 want: []string{
1757 "table1",
1758 "1.1",
1759 "1.2",
1760 "table2",
1761 "2.1",
1762 "2.2",
1763 },
1764 }, {
1765 name: "outer cancel",
1766 cancelOn: "table2",
1767 want: []string{
1768 "table1",
1769 "1.1",
1770 "1.2",
1771 "table2",
1772 },
1773 }, {
1774 name: "inner cancel",
1775 cancelOn: "1.1",
1776 want: []string{
1777 "table1",
1778 "1.1",
1779 "table2",
1780 "2.1",
1781 "2.2",
1782 },
1783 }} {
1784 t.Run(test.name, func(t *testing.T) {
1785 testDatabase(t, func(t *testing.T, db *DB) {
1786 testCursorCancel(t, db, test.cancelOn, test.want)
1787 })
1788 })
1789 }
1790 }
1791 func testCursorCancel(t *testing.T, db *DB, cancelOn string, want []string) {
1792 exec(t, db, "CREATE|table1|col=string")
1793 exec(t, db, "INSERT|table1|col=1.1")
1794 exec(t, db, "INSERT|table1|col=1.2")
1795 exec(t, db, "CREATE|table2|col=string")
1796 exec(t, db, "INSERT|table2|col=2.1")
1797 exec(t, db, "INSERT|table2|col=2.2")
1798
1799 exec(t, db, "CREATE|cursor|name=string,list=table")
1800 exec(t, db, "INSERT|cursor|name=table1,list=table1!col")
1801 exec(t, db, "INSERT|cursor|name=table2,list=table2!col")
1802
1803 ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
1804 defer cancel()
1805
1806 rows, err := db.QueryContext(ctx, `SELECT|cursor|name,list|`)
1807 if err != nil {
1808 t.Fatal(err)
1809 }
1810 defer rows.Close()
1811
1812 var got []string
1813 for rows.Next() {
1814 var name string
1815 cursor := &Rows{}
1816 if err := rows.Scan(&name, cursor); err != nil {
1817 t.Fatal(err)
1818 }
1819 got = append(got, name)
1820 if name == cancelOn {
1821 rows.Close()
1822 }
1823 for cursor.Next() {
1824 var col string
1825 if err := cursor.Scan(&col); err != nil {
1826 t.Fatal(err)
1827 }
1828 got = append(got, col)
1829 if col == cancelOn {
1830 cursor.Close()
1831 }
1832 }
1833 }
1834
1835 if !slices.Equal(got, want) {
1836 t.Errorf("cancel after reading %q:\ngot: %v\nwant: %v", cancelOn, got, want)
1837 }
1838 }
1839
1840 func TestInvalidNilValues(t *testing.T) {
1841 var date1 time.Time
1842 var date2 int
1843
1844 tests := []struct {
1845 name string
1846 input any
1847 expectedError string
1848 }{
1849 {
1850 name: "time.Time",
1851 input: &date1,
1852 expectedError: `sql: Scan error on column index 0, name "bdate": unsupported Scan, storing driver.Value type <nil> into type *time.Time`,
1853 },
1854 {
1855 name: "int",
1856 input: &date2,
1857 expectedError: `sql: Scan error on column index 0, name "bdate": converting NULL to int is unsupported`,
1858 },
1859 }
1860
1861 for _, tt := range tests {
1862 synctestSubtest(t, tt.name, func(t *testing.T) {
1863 db := newTestDB(t, "people")
1864 defer closeDB(t, db)
1865
1866 ctx := t.Context()
1867 conn, err := db.Conn(ctx)
1868 if err != nil {
1869 t.Fatal(err)
1870 }
1871 getFakeConn(conn.dc.ci).skipDirtySession = true
1872 defer conn.Close()
1873
1874 err = conn.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 1).Scan(tt.input)
1875 if err == nil {
1876 t.Fatal("expected error when querying nil column, but succeeded")
1877 }
1878 if err.Error() != tt.expectedError {
1879 t.Fatalf("Expected error: %s\nReceived: %s", tt.expectedError, err.Error())
1880 }
1881
1882 err = conn.PingContext(ctx)
1883 if err != nil {
1884 t.Fatal(err)
1885 }
1886 })
1887 }
1888 }
1889
1890 func TestConnTx(t *testing.T) {
1891 testDatabase(t, testConnTx)
1892 }
1893 func testConnTx(t *testing.T, db *DB) {
1894 populate(t, db, "people")
1895
1896 ctx := t.Context()
1897 conn, err := db.Conn(ctx)
1898 if err != nil {
1899 t.Fatal(err)
1900 }
1901 getFakeConn(conn.dc.ci).skipDirtySession = true
1902 defer conn.Close()
1903
1904 tx, err := conn.BeginTx(ctx, nil)
1905 if err != nil {
1906 t.Fatal(err)
1907 }
1908 insertName, insertAge := "Nancy", 33
1909 _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge)
1910 if err != nil {
1911 t.Fatal(err)
1912 }
1913 err = tx.Commit()
1914 if err != nil {
1915 t.Fatal(err)
1916 }
1917
1918 var selectName string
1919 err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName)
1920 if err != nil {
1921 t.Fatal(err)
1922 }
1923 if selectName != insertName {
1924 t.Fatalf("got %q want %q", selectName, insertName)
1925 }
1926 }
1927
1928
1929
1930
1931 func TestConnIsValid(t *testing.T) {
1932 testDatabase(t, testConnIsValid, requireFeature("Validator"))
1933 }
1934 func testConnIsValid(t *testing.T, db *DB) {
1935 populate(t, db, "people")
1936
1937 db.SetMaxOpenConns(1)
1938
1939 ctx := context.Background()
1940
1941 c, err := db.Conn(ctx)
1942 if err != nil {
1943 t.Fatal(err)
1944 }
1945
1946 err = c.Raw(func(raw any) error {
1947 dc := getFakeConn(raw.(driver.Conn))
1948 dc.stickyBad = true
1949 return nil
1950 })
1951 if err != nil {
1952 t.Fatal(err)
1953 }
1954 c.Close()
1955
1956 if len(db.freeConn) > 0 && getFakeConn(db.freeConn[0].ci).stickyBad {
1957 t.Fatal("bad connection returned to pool; expected bad connection to be discarded")
1958 }
1959 }
1960
1961
1962
1963 func TestIssue2542Deadlock(t *testing.T) {
1964 testDatabase(t, testIssue2542Deadlock)
1965 }
1966 func testIssue2542Deadlock(t *testing.T, db *DB) {
1967 populate(t, db, "people")
1968 closeDB(t, db)
1969 for i := 0; i < 2; i++ {
1970 _, err := db.Query("SELECT|people|age,name|")
1971 if err == nil {
1972 t.Fatalf("expected error")
1973 }
1974 }
1975 }
1976
1977
1978 func TestCloseStmtBeforeRows(t *testing.T) {
1979 testDatabase(t, testCloseStmtBeforeRows)
1980 }
1981 func testCloseStmtBeforeRows(t *testing.T, db *DB) {
1982 populate(t, db, "people")
1983
1984 s, err := db.Prepare("SELECT|people|name|")
1985 if err != nil {
1986 t.Fatal(err)
1987 }
1988
1989 r, err := s.Query()
1990 if err != nil {
1991 s.Close()
1992 t.Fatal(err)
1993 }
1994
1995 err = s.Close()
1996 if err != nil {
1997 t.Fatal(err)
1998 }
1999
2000 r.Close()
2001 }
2002
2003
2004
2005 func TestNullByteSlice(t *testing.T) {
2006 testDatabase(t, testNullByteSlice)
2007 }
2008 func testNullByteSlice(t *testing.T, db *DB) {
2009 exec(t, db, "CREATE|t|id=int32,name=nullstring")
2010 exec(t, db, "INSERT|t|id=10,name=?", nil)
2011
2012 var name []byte
2013
2014 err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
2015 if err != nil {
2016 t.Fatal(err)
2017 }
2018 if name != nil {
2019 t.Fatalf("name []byte should be nil for null column value, got: %#v", name)
2020 }
2021
2022 exec(t, db, "INSERT|t|id=11,name=?", "bob")
2023 err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name)
2024 if err != nil {
2025 t.Fatal(err)
2026 }
2027 if string(name) != "bob" {
2028 t.Fatalf("name []byte should be bob, got: %q", string(name))
2029 }
2030 }
2031
2032 func TestPointerParamsAndScans(t *testing.T) {
2033 testDatabase(t, testPointerParamsAndScans)
2034 }
2035 func testPointerParamsAndScans(t *testing.T, db *DB) {
2036 exec(t, db, "CREATE|t|id=int32,name=nullstring")
2037
2038 bob := "bob"
2039 var name *string
2040
2041 name = &bob
2042 exec(t, db, "INSERT|t|id=10,name=?", name)
2043 name = nil
2044 exec(t, db, "INSERT|t|id=20,name=?", name)
2045
2046 err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
2047 if err != nil {
2048 t.Fatalf("querying id 10: %v", err)
2049 }
2050 if name == nil {
2051 t.Errorf("id 10's name = nil; want bob")
2052 } else if *name != "bob" {
2053 t.Errorf("id 10's name = %q; want bob", *name)
2054 }
2055
2056 err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name)
2057 if err != nil {
2058 t.Fatalf("querying id 20: %v", err)
2059 }
2060 if name != nil {
2061 t.Errorf("id 20 = %q; want nil", *name)
2062 }
2063 }
2064
2065 func TestQueryRowClosingStmt(t *testing.T) {
2066 testDatabase(t, testQueryRowClosingStmt)
2067 }
2068 func testQueryRowClosingStmt(t *testing.T, db *DB) {
2069 populate(t, db, "people")
2070 var name string
2071 var age int
2072 err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
2073 if err != nil {
2074 t.Fatal(err)
2075 }
2076 if len(db.freeConn) != 1 {
2077 t.Fatalf("expected 1 free conn")
2078 }
2079 fakeConn := getFakeConn(db.freeConn[0].ci)
2080 if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
2081 t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
2082 }
2083 }
2084
2085 var atomicRowsCloseHook atomic.Value
2086
2087 func init() {
2088 rowsCloseHook = func() func(*Rows, *error) {
2089 fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
2090 return fn
2091 }
2092 }
2093
2094 func setRowsCloseHook(fn func(*Rows, *error)) {
2095 if fn == nil {
2096
2097
2098 fn = func(*Rows, *error) {}
2099 }
2100 atomicRowsCloseHook.Store(fn)
2101 }
2102
2103
2104 func TestIssue6651(t *testing.T) {
2105 testDatabase(t, testIssue6651)
2106 }
2107 func testIssue6651(t *testing.T, db *DB) {
2108 populate(t, db, "people")
2109
2110 var v string
2111
2112 want := "error in rows.Next"
2113 rowsCursorNextHook = func(dest []driver.Value) error {
2114 return errors.New(want)
2115 }
2116 defer func() { rowsCursorNextHook = nil }()
2117
2118 err := db.QueryRow("SELECT|people|name|").Scan(&v)
2119 if err == nil || err.Error() != want {
2120 t.Errorf("error = %q; want %q", err, want)
2121 }
2122 rowsCursorNextHook = nil
2123
2124 want = "error in rows.Close"
2125 setRowsCloseHook(func(rows *Rows, err *error) {
2126 *err = errors.New(want)
2127 })
2128 defer setRowsCloseHook(nil)
2129 err = db.QueryRow("SELECT|people|name|").Scan(&v)
2130 if err == nil || err.Error() != want {
2131 t.Errorf("error = %q; want %q", err, want)
2132 }
2133 }
2134
// nullTestRow is one row of a nullTestRun scenario.
//
// Instances are built with positional composite literals, so field order
// matters.
type nullTestRow struct {
	// nullParam is the argument bound to column nullf (of type
	// nullTestSpec.nullType) on insert.
	nullParam any
	// notNullParam is the argument bound to column notnullf (of type
	// nullTestSpec.notNullType) on insert.
	notNullParam any
	// scanNullVal is the expected value when nullf is scanned back into a
	// fresh value of rows[0].nullParam's type. nullTestRun only checks this
	// for the first five rows.
	scanNullVal any
}
2140
// nullTestSpec parameterizes nullTestRun for one nullable column type.
//
// Instances are built with positional composite literals, so field order
// matters.
type nullTestSpec struct {
	// nullType is the fake-driver column type of column nullf.
	nullType string
	// notNullType is the fake-driver column type of column notnullf.
	notNullType string
	// rows are inserted as ids 1-6. rows[5] is expected to fail to insert.
	rows [6]nullTestRow
}
2146
2147 func TestNullStringParam(t *testing.T) {
2148 spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
2149 {NullString{"aqua", true}, "", NullString{"aqua", true}},
2150 {NullString{"brown", false}, "", NullString{"", false}},
2151 {"chartreuse", "", NullString{"chartreuse", true}},
2152 {NullString{"darkred", true}, "", NullString{"darkred", true}},
2153 {NullString{"eel", false}, "", NullString{"", false}},
2154 {"foo", NullString{"black", false}, nil},
2155 }}
2156 testDatabase(t, func(t *testing.T, db *DB) {
2157 nullTestRun(t, spec, db)
2158 })
2159 }
2160
2161 func TestGenericNullStringParam(t *testing.T) {
2162 spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
2163 {Null[string]{"aqua", true}, "", Null[string]{"aqua", true}},
2164 {Null[string]{"brown", false}, "", Null[string]{"", false}},
2165 {"chartreuse", "", Null[string]{"chartreuse", true}},
2166 {Null[string]{"darkred", true}, "", Null[string]{"darkred", true}},
2167 {Null[string]{"eel", false}, "", Null[string]{"", false}},
2168 {"foo", Null[string]{"black", false}, nil},
2169 }}
2170 testDatabase(t, func(t *testing.T, db *DB) {
2171 nullTestRun(t, spec, db)
2172 })
2173 }
2174
2175 func TestNullInt64Param(t *testing.T) {
2176 spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
2177 {NullInt64{31, true}, 1, NullInt64{31, true}},
2178 {NullInt64{-22, false}, 1, NullInt64{0, false}},
2179 {22, 1, NullInt64{22, true}},
2180 {NullInt64{33, true}, 1, NullInt64{33, true}},
2181 {NullInt64{222, false}, 1, NullInt64{0, false}},
2182 {0, NullInt64{31, false}, nil},
2183 }}
2184 testDatabase(t, func(t *testing.T, db *DB) {
2185 nullTestRun(t, spec, db)
2186 })
2187 }
2188
2189 func TestNullInt32Param(t *testing.T) {
2190 spec := nullTestSpec{"nullint32", "int32", [6]nullTestRow{
2191 {NullInt32{31, true}, 1, NullInt32{31, true}},
2192 {NullInt32{-22, false}, 1, NullInt32{0, false}},
2193 {22, 1, NullInt32{22, true}},
2194 {NullInt32{33, true}, 1, NullInt32{33, true}},
2195 {NullInt32{222, false}, 1, NullInt32{0, false}},
2196 {0, NullInt32{31, false}, nil},
2197 }}
2198 testDatabase(t, func(t *testing.T, db *DB) {
2199 nullTestRun(t, spec, db)
2200 })
2201 }
2202
2203 func TestNullInt16Param(t *testing.T) {
2204 spec := nullTestSpec{"nullint16", "int16", [6]nullTestRow{
2205 {NullInt16{31, true}, 1, NullInt16{31, true}},
2206 {NullInt16{-22, false}, 1, NullInt16{0, false}},
2207 {22, 1, NullInt16{22, true}},
2208 {NullInt16{33, true}, 1, NullInt16{33, true}},
2209 {NullInt16{222, false}, 1, NullInt16{0, false}},
2210 {0, NullInt16{31, false}, nil},
2211 }}
2212 testDatabase(t, func(t *testing.T, db *DB) {
2213 nullTestRun(t, spec, db)
2214 })
2215 }
2216
2217 func TestNullByteParam(t *testing.T) {
2218 spec := nullTestSpec{"nullbyte", "byte", [6]nullTestRow{
2219 {NullByte{31, true}, 1, NullByte{31, true}},
2220 {NullByte{0, false}, 1, NullByte{0, false}},
2221 {22, 1, NullByte{22, true}},
2222 {NullByte{33, true}, 1, NullByte{33, true}},
2223 {NullByte{222, false}, 1, NullByte{0, false}},
2224 {0, NullByte{31, false}, nil},
2225 }}
2226 testDatabase(t, func(t *testing.T, db *DB) {
2227 nullTestRun(t, spec, db)
2228 })
2229 }
2230
2231 func TestNullFloat64Param(t *testing.T) {
2232 spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
2233 {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
2234 {NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
2235 {-22.9, 1, NullFloat64{-22.9, true}},
2236 {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
2237 {NullFloat64{222, false}, 1, NullFloat64{0, false}},
2238 {10, NullFloat64{31.2, false}, nil},
2239 }}
2240 testDatabase(t, func(t *testing.T, db *DB) {
2241 nullTestRun(t, spec, db)
2242 })
2243 }
2244
2245 func TestNullBoolParam(t *testing.T) {
2246 spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
2247 {NullBool{false, true}, true, NullBool{false, true}},
2248 {NullBool{true, false}, false, NullBool{false, false}},
2249 {true, true, NullBool{true, true}},
2250 {NullBool{true, true}, false, NullBool{true, true}},
2251 {NullBool{true, false}, true, NullBool{false, false}},
2252 {true, NullBool{true, false}, nil},
2253 }}
2254 testDatabase(t, func(t *testing.T, db *DB) {
2255 nullTestRun(t, spec, db)
2256 })
2257 }
2258
2259 func TestNullTimeParam(t *testing.T) {
2260 t0 := time.Time{}
2261 t1 := time.Date(2000, 1, 1, 8, 9, 10, 11, time.UTC)
2262 t2 := time.Date(2010, 1, 1, 8, 9, 10, 11, time.UTC)
2263 spec := nullTestSpec{"nulldatetime", "datetime", [6]nullTestRow{
2264 {NullTime{t1, true}, t2, NullTime{t1, true}},
2265 {NullTime{t1, false}, t2, NullTime{t0, false}},
2266 {t1, t2, NullTime{t1, true}},
2267 {NullTime{t1, true}, t2, NullTime{t1, true}},
2268 {NullTime{t1, false}, t2, NullTime{t0, false}},
2269 {t2, NullTime{t1, false}, nil},
2270 }}
2271 testDatabase(t, func(t *testing.T, db *DB) {
2272 nullTestRun(t, spec, db)
2273 })
2274 }
2275
2276 func TestNullUUIDParam(t *testing.T) {
2277 u0 := uuid.UUID{}
2278 u1 := uuid.MustParse("46cd2740-6081-4289-a659-03b61ebb92f7")
2279 u2 := uuid.MustParse("46cd2740-6081-4289-a659-03b61ebb92f7")
2280 spec := nullTestSpec{"nulluuid", "uuid", [6]nullTestRow{
2281 {Null[uuid.UUID]{u1, true}, u2, Null[uuid.UUID]{u1, true}},
2282 {Null[uuid.UUID]{u1, false}, u2, Null[uuid.UUID]{u0, false}},
2283 {u1, u2, Null[uuid.UUID]{u1, true}},
2284 {Null[uuid.UUID]{u1, true}, u2, Null[uuid.UUID]{u1, true}},
2285 {Null[uuid.UUID]{u1, false}, u2, Null[uuid.UUID]{u0, false}},
2286 {u2, Null[uuid.UUID]{u1, false}, nil},
2287 }}
2288 testDatabase(t, func(t *testing.T, db *DB) {
2289 nullTestRun(t, spec, db)
2290 })
2291 }
2292
2293 func nullTestRun(t *testing.T, spec nullTestSpec, db *DB) {
2294 exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType))
2295
2296
2297 exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam)
2298 exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam)
2299
2300
2301 stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?")
2302 if err != nil {
2303 t.Fatalf("prepare: %v", err)
2304 }
2305 defer stmt.Close()
2306 if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
2307 t.Errorf("exec insert chris: %v", err)
2308 }
2309 if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil {
2310 t.Errorf("exec insert dave: %v", err)
2311 }
2312 if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil {
2313 t.Errorf("exec insert eleanor: %v", err)
2314 }
2315
2316
2317 row5 := spec.rows[5]
2318 if _, err := stmt.Exec(6, "bob", row5.nullParam, row5.notNullParam); err == nil {
2319 t.Errorf("expected error inserting nil val with prepared statement Exec: NULL=%#v, NOT-NULL=%#v", row5.nullParam, row5.notNullParam)
2320 }
2321
2322 _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil)
2323 if err == nil {
2324
2325
2326
2327
2328
2329
2330 }
2331
2332 paramtype := reflect.TypeOf(spec.rows[0].nullParam)
2333 bindVal := reflect.New(paramtype).Interface()
2334
2335 for i := 0; i < 5; i++ {
2336 id := i + 1
2337 if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil {
2338 t.Errorf("id=%d Scan: %v", id, err)
2339 }
2340 bindValDeref := reflect.ValueOf(bindVal).Elem().Interface()
2341 if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) {
2342 t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal)
2343 }
2344 }
2345 }
2346
2347
2348 func TestQueryRowNilScanDest(t *testing.T) {
2349 testDatabase(t, testQueryRowNilScanDest)
2350 }
2351 func testQueryRowNilScanDest(t *testing.T, db *DB) {
2352 populate(t, db, "people")
2353 var name *string
2354 err := db.QueryRow("SELECT|people|name|").Scan(name)
2355 want := `sql: Scan error on column index 0, name "name": destination pointer is nil`
2356 if err == nil || err.Error() != want {
2357 t.Errorf("error = %q; want %q", err.Error(), want)
2358 }
2359 }
2360
2361 func TestIssue4902(t *testing.T) {
2362 testDatabase(t, testIssue4902)
2363 }
2364 func testIssue4902(t *testing.T, db *DB) {
2365 populate(t, db, "people")
2366
2367 driver := db.Driver().(*fakeDriver)
2368 opens0 := driver.openCount
2369
2370 var stmt *Stmt
2371 var err error
2372 for i := 0; i < 10; i++ {
2373 stmt, err = db.Prepare("SELECT|people|name|")
2374 if err != nil {
2375 t.Fatal(err)
2376 }
2377 err = stmt.Close()
2378 if err != nil {
2379 t.Fatal(err)
2380 }
2381 }
2382
2383 opens := driver.openCount - opens0
2384 if opens > 1 {
2385 t.Errorf("opens = %d; want <= 1", opens)
2386 t.Logf("db = %#v", db)
2387 t.Logf("driver = %#v", driver)
2388 t.Logf("stmt = %#v", stmt)
2389 }
2390 }
2391
2392
2393
2394 func TestSimultaneousQueries(t *testing.T) {
2395 testDatabase(t, testSimultaneousQueries)
2396 }
2397 func testSimultaneousQueries(t *testing.T, db *DB) {
2398 populate(t, db, "people")
2399
2400 tx, err := db.Begin()
2401 if err != nil {
2402 t.Fatal(err)
2403 }
2404 defer tx.Rollback()
2405
2406 r1, err := tx.Query("SELECT|people|name|")
2407 if err != nil {
2408 t.Fatal(err)
2409 }
2410 defer r1.Close()
2411
2412 r2, err := tx.Query("SELECT|people|name|")
2413 if err != nil {
2414 t.Fatal(err)
2415 }
2416 defer r2.Close()
2417 }
2418
2419 func TestMaxIdleConns(t *testing.T) {
2420 testDatabase(t, testMaxIdleConns)
2421 }
2422 func testMaxIdleConns(t *testing.T, db *DB) {
2423 populate(t, db, "people")
2424
2425 tx, err := db.Begin()
2426 if err != nil {
2427 t.Fatal(err)
2428 }
2429 tx.Commit()
2430 if got := len(db.freeConn); got != 1 {
2431 t.Errorf("freeConns = %d; want 1", got)
2432 }
2433
2434 db.SetMaxIdleConns(0)
2435
2436 if got := len(db.freeConn); got != 0 {
2437 t.Errorf("freeConns after set to zero = %d; want 0", got)
2438 }
2439
2440 tx, err = db.Begin()
2441 if err != nil {
2442 t.Fatal(err)
2443 }
2444 tx.Commit()
2445 if got := len(db.freeConn); got != 0 {
2446 t.Errorf("freeConns = %d; want 0", got)
2447 }
2448 }
2449
2450 func TestMaxOpenConns(t *testing.T) {
2451 testDatabase(t, testMaxOpenConns)
2452 }
2453 func testMaxOpenConns(t *testing.T, db *DB) {
2454 if testing.Short() {
2455 t.Skip("skipping in short mode")
2456 }
2457 defer setHookpostCloseConn(nil)
2458 setHookpostCloseConn(func(_ *fakeConn, err error) {
2459 if err != nil {
2460 t.Errorf("Error closing fakeConn: %v", err)
2461 }
2462 })
2463
2464 populate(t, db, "magicquery")
2465
2466 driver := db.Driver().(*fakeDriver)
2467
2468
2469
2470 db.clearAllConns(t)
2471
2472 driver.mu.Lock()
2473 opens0 := driver.openCount
2474 closes0 := driver.closeCount
2475 driver.mu.Unlock()
2476
2477 db.SetMaxIdleConns(10)
2478 db.SetMaxOpenConns(10)
2479
2480 stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
2481 if err != nil {
2482 t.Fatal(err)
2483 }
2484
2485
2486 const (
2487 nquery = 50
2488 sleepMillis = 25
2489 nbatch = 2
2490 )
2491 var wg sync.WaitGroup
2492 for batch := 0; batch < nbatch; batch++ {
2493 for i := 0; i < nquery; i++ {
2494 wg.Add(1)
2495 go func() {
2496 defer wg.Done()
2497 var op string
2498 if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
2499 t.Error(err)
2500 }
2501 }()
2502 }
2503
2504 wg.Wait()
2505 }
2506
2507 if g, w := db.numFreeConns(), 10; g != w {
2508 t.Errorf("free conns = %d; want %d", g, w)
2509 }
2510
2511 if n := db.numDeps(); n > 20 {
2512 t.Errorf("number of dependencies = %d; expected <= 20", n)
2513 db.dumpDeps(t)
2514 }
2515
2516 driver.mu.Lock()
2517 opens := driver.openCount - opens0
2518 closes := driver.closeCount - closes0
2519 driver.mu.Unlock()
2520
2521 if opens > 10 {
2522 t.Logf("open calls = %d", opens)
2523 t.Logf("close calls = %d", closes)
2524 t.Errorf("db connections opened = %d; want <= 10", opens)
2525 db.dumpDeps(t)
2526 }
2527
2528 if err := stmt.Close(); err != nil {
2529 t.Fatal(err)
2530 }
2531
2532 if g, w := db.numFreeConns(), 10; g != w {
2533 t.Errorf("free conns = %d; want %d", g, w)
2534 }
2535
2536 if n := db.numDeps(); n > 10 {
2537 t.Errorf("number of dependencies = %d; expected <= 10", n)
2538 db.dumpDeps(t)
2539 }
2540
2541 db.SetMaxOpenConns(5)
2542
2543 if g, w := db.numFreeConns(), 5; g != w {
2544 t.Errorf("free conns = %d; want %d", g, w)
2545 }
2546
2547 if n := db.numDeps(); n > 5 {
2548 t.Errorf("number of dependencies = %d; expected 0", n)
2549 db.dumpDeps(t)
2550 }
2551
2552 db.SetMaxOpenConns(0)
2553
2554 if g, w := db.numFreeConns(), 5; g != w {
2555 t.Errorf("free conns = %d; want %d", g, w)
2556 }
2557
2558 if n := db.numDeps(); n > 5 {
2559 t.Errorf("number of dependencies = %d; expected 0", n)
2560 db.dumpDeps(t)
2561 }
2562
2563 db.clearAllConns(t)
2564 }
2565
2566
2567
2568 func TestMaxOpenConnsOnBusy(t *testing.T) {
2569 testDatabase(t, testMaxOpenConnsOnBusy)
2570 }
2571 func testMaxOpenConnsOnBusy(t *testing.T, db *DB) {
2572 defer setHookpostCloseConn(nil)
2573 setHookpostCloseConn(func(_ *fakeConn, err error) {
2574 if err != nil {
2575 t.Errorf("Error closing fakeConn: %v", err)
2576 }
2577 })
2578
2579 populate(t, db, "magicquery")
2580
2581 db.SetMaxOpenConns(3)
2582
2583 ctx := context.Background()
2584
2585 conn0, err := db.conn(ctx, cachedOrNewConn)
2586 if err != nil {
2587 t.Fatalf("db open conn fail: %v", err)
2588 }
2589
2590 conn1, err := db.conn(ctx, cachedOrNewConn)
2591 if err != nil {
2592 t.Fatalf("db open conn fail: %v", err)
2593 }
2594
2595 conn2, err := db.conn(ctx, cachedOrNewConn)
2596 if err != nil {
2597 t.Fatalf("db open conn fail: %v", err)
2598 }
2599
2600 if g, w := db.numOpen, 3; g != w {
2601 t.Errorf("free conns = %d; want %d", g, w)
2602 }
2603
2604 db.SetMaxOpenConns(2)
2605 if g, w := db.numOpen, 3; g != w {
2606 t.Errorf("free conns = %d; want %d", g, w)
2607 }
2608
2609 conn0.releaseConn(nil)
2610 conn1.releaseConn(nil)
2611 if g, w := db.numOpen, 2; g != w {
2612 t.Errorf("free conns = %d; want %d", g, w)
2613 }
2614
2615 conn2.releaseConn(nil)
2616 if g, w := db.numOpen, 2; g != w {
2617 t.Errorf("free conns = %d; want %d", g, w)
2618 }
2619 }
2620
2621
2622
2623 func TestPendingConnsAfterErr(t *testing.T) {
2624 synctest.Test(t, testPendingConnsAfterErr)
2625 }
2626 func testPendingConnsAfterErr(t *testing.T) {
2627 const (
2628 maxOpen = 2
2629 tryOpen = maxOpen*2 + 2
2630 )
2631
2632
2633 db, err := Open("test", fakeDBName)
2634 if err != nil {
2635 t.Fatalf("Open: %v", err)
2636 }
2637 defer closeDB(t, db)
2638 defer func() {
2639 for k, v := range db.lastPut {
2640 t.Logf("%p: %v", k, v)
2641 }
2642 }()
2643
2644 db.SetMaxOpenConns(maxOpen)
2645 db.SetMaxIdleConns(0)
2646
2647 errOffline := errors.New("db offline")
2648
2649 defer func() { setHookOpenErr(nil) }()
2650
2651 errs := make(chan error, tryOpen)
2652
2653 var opening sync.WaitGroup
2654 opening.Add(tryOpen)
2655
2656 setHookOpenErr(func() error {
2657
2658 opening.Wait()
2659 return errOffline
2660 })
2661
2662 for i := 0; i < tryOpen; i++ {
2663 go func() {
2664 opening.Done()
2665 _, err := db.Exec("will never run")
2666 errs <- err
2667 }()
2668 }
2669
2670 opening.Wait()
2671
2672 const timeout = 5 * time.Second
2673 to := time.NewTimer(timeout)
2674 defer to.Stop()
2675
2676
2677 for i := 0; i < tryOpen; i++ {
2678 select {
2679 case err := <-errs:
2680 if got, want := err, errOffline; got != want {
2681 t.Errorf("unexpected err: got %v, want %v", got, want)
2682 }
2683 case <-to.C:
2684 t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
2685 }
2686 }
2687
2688
2689 tick := time.NewTicker(3 * time.Millisecond)
2690 defer tick.Stop()
2691 for {
2692 select {
2693 case <-tick.C:
2694 db.mu.Lock()
2695 if db.numOpen == 0 {
2696 db.mu.Unlock()
2697 return
2698 }
2699 db.mu.Unlock()
2700 case <-to.C:
2701
2702 return
2703 }
2704 }
2705 }
2706
2707 func TestSingleOpenConn(t *testing.T) {
2708 testDatabase(t, testSingleOpenConn)
2709 }
2710 func testSingleOpenConn(t *testing.T, db *DB) {
2711 populate(t, db, "people")
2712
2713 db.SetMaxOpenConns(1)
2714
2715 rows, err := db.Query("SELECT|people|name|")
2716 if err != nil {
2717 t.Fatal(err)
2718 }
2719 if err = rows.Close(); err != nil {
2720 t.Fatal(err)
2721 }
2722
2723 rows, err = db.Query("SELECT|people|name|")
2724 if err != nil {
2725 t.Fatal(err)
2726 }
2727 if err = rows.Close(); err != nil {
2728 t.Fatal(err)
2729 }
2730 }
2731
2732 func TestStats(t *testing.T) {
2733 testDatabase(t, testStats)
2734 }
2735 func testStats(t *testing.T, db *DB) {
2736 populate(t, db, "people")
2737 stats := db.Stats()
2738 if got := stats.OpenConnections; got != 1 {
2739 t.Errorf("stats.OpenConnections = %d; want 1", got)
2740 }
2741
2742 tx, err := db.Begin()
2743 if err != nil {
2744 t.Fatal(err)
2745 }
2746 tx.Commit()
2747
2748 closeDB(t, db)
2749 stats = db.Stats()
2750 if got := stats.OpenConnections; got != 0 {
2751 t.Errorf("stats.OpenConnections = %d; want 0", got)
2752 }
2753 }
2754
2755 func TestConnMaxLifetime(t *testing.T) {
2756 testDatabase(t, testConnMaxLifetime)
2757 }
2758 func testConnMaxLifetime(t *testing.T, db *DB) {
2759 populate(t, db, "magicquery")
2760
2761 driver := db.Driver().(*fakeDriver)
2762
2763
2764
2765 db.clearAllConns(t)
2766
2767 driver.mu.Lock()
2768 opens0 := driver.openCount
2769 closes0 := driver.closeCount
2770 driver.mu.Unlock()
2771
2772 db.SetMaxIdleConns(10)
2773 db.SetMaxOpenConns(10)
2774
2775 tx, err := db.Begin()
2776 if err != nil {
2777 t.Fatal(err)
2778 }
2779
2780 synctest.Sleep(1 * time.Second)
2781 tx2, err := db.Begin()
2782 if err != nil {
2783 t.Fatal(err)
2784 }
2785
2786 tx.Commit()
2787 tx2.Commit()
2788
2789 driver.mu.Lock()
2790 opens := driver.openCount - opens0
2791 closes := driver.closeCount - closes0
2792 driver.mu.Unlock()
2793
2794 if opens != 2 {
2795 t.Errorf("opens = %d; want 2", opens)
2796 }
2797 if closes != 0 {
2798 t.Errorf("closes = %d; want 0", closes)
2799 }
2800 if g, w := db.numFreeConns(), 2; g != w {
2801 t.Errorf("free conns = %d; want %d", g, w)
2802 }
2803
2804
2805 synctest.Sleep(10 * time.Second)
2806 db.SetConnMaxLifetime(10 * time.Second)
2807
2808 tx, err = db.Begin()
2809 if err != nil {
2810 t.Fatal(err)
2811 }
2812 tx2, err = db.Begin()
2813 if err != nil {
2814 t.Fatal(err)
2815 }
2816 tx.Commit()
2817 tx2.Commit()
2818
2819
2820 synctest.Wait()
2821 driver.mu.Lock()
2822 opens = driver.openCount - opens0
2823 closes = driver.closeCount - closes0
2824 driver.mu.Unlock()
2825
2826 if opens != 3 {
2827 t.Errorf("opens = %d; want 3", opens)
2828 }
2829 if closes != 1 {
2830 t.Errorf("closes = %d; want 1", closes)
2831 }
2832
2833 if s := db.Stats(); s.MaxLifetimeClosed != 1 {
2834 t.Errorf("MaxLifetimeClosed = %d; want 1 %#v", s.MaxLifetimeClosed, s)
2835 }
2836 }
2837
2838
2839 func TestStmtCloseDeps(t *testing.T) {
2840 testDatabase(t, testStmtCloseDeps)
2841 }
2842 func testStmtCloseDeps(t *testing.T, db *DB) {
2843 if testing.Short() {
2844 t.Skip("skipping in short mode")
2845 }
2846 defer setHookpostCloseConn(nil)
2847 setHookpostCloseConn(func(_ *fakeConn, err error) {
2848 if err != nil {
2849 t.Errorf("Error closing fakeConn: %v", err)
2850 }
2851 })
2852
2853 populate(t, db, "magicquery")
2854
2855 driver := db.Driver().(*fakeDriver)
2856
2857 driver.mu.Lock()
2858 opens0 := driver.openCount
2859 closes0 := driver.closeCount
2860 driver.mu.Unlock()
2861 openDelta0 := opens0 - closes0
2862
2863 stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
2864 if err != nil {
2865 t.Fatal(err)
2866 }
2867
2868
2869 const (
2870 nquery = 50
2871 sleepMillis = 25
2872 nbatch = 2
2873 )
2874 var wg sync.WaitGroup
2875 for batch := 0; batch < nbatch; batch++ {
2876 for i := 0; i < nquery; i++ {
2877 wg.Add(1)
2878 go func() {
2879 defer wg.Done()
2880 var op string
2881 if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
2882 t.Error(err)
2883 }
2884 }()
2885 }
2886
2887 wg.Wait()
2888 }
2889
2890 if g, w := db.numFreeConns(), 2; g != w {
2891 t.Errorf("free conns = %d; want %d", g, w)
2892 }
2893
2894 if n := db.numDeps(); n > 4 {
2895 t.Errorf("number of dependencies = %d; expected <= 4", n)
2896 db.dumpDeps(t)
2897 }
2898
2899 driver.mu.Lock()
2900 opens := driver.openCount - opens0
2901 closes := driver.closeCount - closes0
2902 openDelta := (driver.openCount - driver.closeCount) - openDelta0
2903 driver.mu.Unlock()
2904
2905 if openDelta > 2 {
2906 t.Logf("open calls = %d", opens)
2907 t.Logf("close calls = %d", closes)
2908 t.Logf("open delta = %d", openDelta)
2909 t.Errorf("db connections opened = %d; want <= 2", openDelta)
2910 db.dumpDeps(t)
2911 }
2912
2913 synctest.Wait()
2914 if len(stmt.css) > nquery {
2915 t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
2916 }
2917
2918 if err := stmt.Close(); err != nil {
2919 t.Fatal(err)
2920 }
2921
2922 if g, w := db.numFreeConns(), 2; g != w {
2923 t.Errorf("free conns = %d; want %d", g, w)
2924 }
2925
2926 if n := db.numDeps(); n > 2 {
2927 t.Errorf("number of dependencies = %d; expected <= 2", n)
2928 db.dumpDeps(t)
2929 }
2930
2931 db.clearAllConns(t)
2932 }
2933
2934
2935 func TestCloseConnBeforeStmts(t *testing.T) {
2936 testDatabase(t, testCloseConnBeforeStmts)
2937 }
2938 func testCloseConnBeforeStmts(t *testing.T, db *DB) {
2939 populate(t, db, "people")
2940
2941 defer setHookpostCloseConn(nil)
2942 setHookpostCloseConn(func(_ *fakeConn, err error) {
2943 if err != nil {
2944 t.Errorf("Error closing fakeConn: %v; from %s", err, stack())
2945 db.dumpDeps(t)
2946 t.Errorf("DB = %#v", db)
2947 }
2948 })
2949
2950 stmt, err := db.Prepare("SELECT|people|name|")
2951 if err != nil {
2952 t.Fatal(err)
2953 }
2954
2955 if len(db.freeConn) != 1 {
2956 t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
2957 }
2958 dc := db.freeConn[0]
2959 if dc.closed {
2960 t.Errorf("conn shouldn't be closed")
2961 }
2962
2963 if n := len(dc.openStmt); n != 1 {
2964 t.Errorf("driverConn num openStmt = %d; want 1", n)
2965 }
2966 err = db.Close()
2967 if err != nil {
2968 t.Errorf("db Close = %v", err)
2969 }
2970 if !dc.closed {
2971 t.Errorf("after db.Close, driverConn should be closed")
2972 }
2973 if n := len(dc.openStmt); n != 0 {
2974 t.Errorf("driverConn num openStmt = %d; want 0", n)
2975 }
2976
2977 err = stmt.Close()
2978 if err != nil {
2979 t.Errorf("Stmt close = %v", err)
2980 }
2981
2982 if !dc.closed {
2983 t.Errorf("conn should be closed")
2984 }
2985 if dc.ci != nil {
2986 t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
2987 }
2988 }
2989
2990
2991
2992 func TestRowsCloseOrder(t *testing.T) {
2993 testDatabase(t, testRowsCloseOrder)
2994 }
2995 func testRowsCloseOrder(t *testing.T, db *DB) {
2996 populate(t, db, "people")
2997
2998 db.SetMaxIdleConns(0)
2999 setStrictFakeConnClose(t)
3000 defer setStrictFakeConnClose(nil)
3001
3002 rows, err := db.Query("SELECT|people|age,name|")
3003 if err != nil {
3004 t.Fatal(err)
3005 }
3006 err = rows.Close()
3007 if err != nil {
3008 t.Fatal(err)
3009 }
3010 }
3011
3012 func TestRowsImplicitClose(t *testing.T) {
3013 testDatabase(t, testRowsImplicitClose)
3014 }
3015 func testRowsImplicitClose(t *testing.T, db *DB) {
3016 populate(t, db, "people")
3017
3018 rows, err := db.Query("SELECT|people|age,name|")
3019 if err != nil {
3020 t.Fatal(err)
3021 }
3022
3023 want, fail := 2, errors.New("fail")
3024 r := getRowsCursor(rows)
3025 r.errPos, r.err = want, fail
3026
3027 got := 0
3028 for rows.Next() {
3029 got++
3030 }
3031 if got != want {
3032 t.Errorf("got %d rows, want %d", got, want)
3033 }
3034 if err := rows.Err(); err != fail {
3035 t.Errorf("got error %v, want %v", err, fail)
3036 }
3037 if !r.closed {
3038 t.Errorf("r.closed is false, want true")
3039 }
3040 }
3041
3042 func TestRowsCloseError(t *testing.T) {
3043 testDatabase(t, testRowsCloseError)
3044 }
3045 func testRowsCloseError(t *testing.T, db *DB) {
3046 populate(t, db, "people")
3047 rows, err := db.Query("SELECT|people|age,name|")
3048 if err != nil {
3049 t.Fatalf("Query: %v", err)
3050 }
3051 type row struct {
3052 age int
3053 name string
3054 }
3055 got := []row{}
3056
3057 rc := getRowsCursor(rows)
3058 rc.closeErr = errors.New("rowsCursor: failed to close")
3059
3060 for rows.Next() {
3061 var r row
3062 err = rows.Scan(&r.age, &r.name)
3063 if err != nil {
3064 t.Fatalf("Scan: %v", err)
3065 }
3066 got = append(got, r)
3067 }
3068 err = rows.Err()
3069 if err != rc.closeErr {
3070 t.Fatalf("unexpected err: got %v, want %v", err, rc.closeErr)
3071 }
3072 }
3073
3074 func TestStmtCloseOrder(t *testing.T) {
3075 testDatabase(t, testStmtCloseOrder)
3076 }
3077 func testStmtCloseOrder(t *testing.T, db *DB) {
3078 populate(t, db, "people")
3079
3080 db.SetMaxIdleConns(0)
3081 setStrictFakeConnClose(t)
3082 defer setStrictFakeConnClose(nil)
3083
3084 _, err := db.Query("SELECT|non_existent|name|")
3085 if err == nil {
3086 t.Fatal("Querying non-existent table should fail")
3087 }
3088 }
3089
3090
3091
// TestManyErrBadConn checks that each kind of DB operation still succeeds
// when every idle connection in the pool (maxBadConnRetries+1 of them) has
// been flagged bad on the fake driver.
func TestManyErrBadConn(t *testing.T) {
	synctest.Test(t, testManyErrBadConn)
}
func testManyErrBadConn(t *testing.T) {
	// manyErrBadConnSetup returns a new "people" DB whose pool holds exactly
	// nconn idle connections, each flagged stickyBad (presumably making
	// every use of it fail as a bad connection; TODO confirm in fakeConn).
	// Any first callbacks run on the DB before the pool is filled; the Stmt
	// cases below use them to prepare a statement up front.
	manyErrBadConnSetup := func(first ...func(db *DB)) *DB {
		db := newTestDB(t, "people")

		for _, f := range first {
			f(db)
		}

		nconn := maxBadConnRetries + 1
		db.SetMaxIdleConns(nconn)
		db.SetMaxOpenConns(nconn)

		// Keep nconn Rows open at once so that nconn distinct connections
		// are opened; the deferred Closes run when this func returns,
		// putting all of them back into the idle pool.
		func() {
			for i := 0; i < nconn; i++ {
				rows, err := db.Query("SELECT|people|age,name|")
				if err != nil {
					t.Fatal(err)
				}
				defer rows.Close()
			}
		}()

		db.mu.Lock()
		defer db.mu.Unlock()
		if db.numOpen != nconn {
			t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
		} else if len(db.freeConn) != nconn {
			t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
		}
		// Flag every idle connection bad, under each driverConn's own lock.
		for _, conn := range db.freeConn {
			conn.Lock()
			getFakeConn(conn.ci).stickyBad = true
			conn.Unlock()
		}
		return db
	}

	// Query
	db := manyErrBadConnSetup()
	defer closeDB(t, db)
	rows, err := db.Query("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}

	// Exec
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	_, err = db.Exec("INSERT|people|name=Julia,age=19")
	if err != nil {
		t.Fatal(err)
	}

	// Begin
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	if err = tx.Rollback(); err != nil {
		t.Fatal(err)
	}

	// Prepare
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	stmt, err := db.Prepare("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Exec, with the statement prepared before the pool is poisoned.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("INSERT|people|name=Julia,age=19")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	_, err = stmt.Exec()
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Query, with the statement prepared before the pool is poisoned.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("SELECT|people|age,name|")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	rows, err = stmt.Query()
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Conn. skipDirtySession is set on the acquired connection before
	// Close; NOTE(review): presumably so Close is not rejected for a dirty
	// session on the fake driver — confirm.
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	ctx := t.Context()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	getFakeConn(conn.dc.ci).skipDirtySession = true
	err = conn.Close()
	if err != nil {
		t.Fatal(err)
	}

	// Ping
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	err = db.PingContext(ctx)
	if err != nil {
		t.Fatal(err)
	}
}
3230
3231
// TestTxCannotCommitAfterRollback checks that a query issued on a transaction
// whose context is canceled while the query is in flight still runs inside
// the transaction instead of leaking out to autocommit mode. The fake
// "tx_status" table reports which mode a query ran in.
func TestTxCannotCommitAfterRollback(t *testing.T) {
	testDatabase(t, testTxCannotCommitAfterRollback)
}
func testTxCannotCommitAfterRollback(t *testing.T, db *DB) {
	populate(t, db, "tx_status")

	// Sanity check: outside a transaction the status query reports
	// "autocommit".
	var txStatus string
	err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "autocommit"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	// Bypass the fake connection's dirty-session handling for this test.
	// NOTE(review): presumably the final query below would fail without
	// this flag, which is not what this test is about — confirm against
	// fakeConn.
	tx.txi.(*fakeTx).c.skipDirtySession = true

	defer tx.Rollback()

	// Inside the transaction the status query reports "transaction".
	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	// Race a query against a cancellation-triggered rollback:
	//  1. hookTxGrabConn (by its name, invoked while the query grabs the
	//     tx's connection) cancels ctx and then blocks on sendQuery.
	//  2. rollbackHook closes sendQuery, releasing the query.
	//  3. The query must still report "transaction", not "autocommit".
	sendQuery := make(chan struct{})

	// The status comes back through the row results; bypassRowsAwaitDone
	// presumably keeps those rows from being torn down by the cancellation.
	bypassRowsAwaitDone = true
	hookTxGrabConn = func() {
		cancel()
		<-sendQuery
	}
	rollbackHook = func() {
		close(sendQuery)
	}
	defer func() {
		hookTxGrabConn = nil
		rollbackHook = nil
		bypassRowsAwaitDone = false
	}()

	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		// NOTE(review): presumably expected to fail here if
		// skipDirtySession had not been set above.
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}
}
3300
3301
3302 func TestTxStmtDeadlock(t *testing.T) {
3303 testDatabase(t, testTxStmtDeadlock)
3304 }
3305 func testTxStmtDeadlock(t *testing.T, db *DB) {
3306 populate(t, db, "people")
3307
3308 ctx, cancel := context.WithCancel(context.Background())
3309 defer cancel()
3310 tx, err := db.BeginTx(ctx, nil)
3311 if err != nil {
3312 t.Fatal(err)
3313 }
3314
3315 stmt, err := tx.Prepare("SELECT|people|name,age|age=?")
3316 if err != nil {
3317 t.Fatal(err)
3318 }
3319 cancel()
3320
3321 for i := 0; i < 1e3; i++ {
3322
3323
3324 _, err = stmt.Query(1)
3325 if err != nil {
3326 break
3327 }
3328 }
3329 _ = tx.Rollback()
3330 }
3331
3332 func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
3333 t.Run(name, func(t *testing.T) {
3334 synctest.Test(t, f)
3335 })
3336 }
3337
3338
3339
3340
3341 func TestConnExpiresFreshOutOfPool(t *testing.T) {
3342 execCases := []struct {
3343 expired bool
3344 badReset bool
3345 }{
3346 {false, false},
3347 {true, false},
3348 {false, true},
3349 }
3350
3351 for _, ec := range execCases {
3352 name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset)
3353 synctestSubtest(t, name, func(t *testing.T) {
3354 ctx := t.Context()
3355
3356 db := newTestDB(t, "magicquery")
3357
3358 db.SetMaxOpenConns(1)
3359
3360 db.clearAllConns(t)
3361
3362 db.SetMaxIdleConns(1)
3363 db.SetConnMaxLifetime(10 * time.Second)
3364
3365 conn, err := db.conn(ctx, alwaysNewConn)
3366 if err != nil {
3367 t.Fatal(err)
3368 }
3369
3370 afterPutConn := make(chan struct{})
3371
3372 go func() {
3373 defer close(afterPutConn)
3374
3375 conn, err := db.conn(ctx, alwaysNewConn)
3376 if err == nil {
3377 db.putConn(conn, err, false)
3378 } else {
3379 t.Errorf("db.conn: %v", err)
3380 }
3381 }()
3382 synctest.Wait()
3383
3384 if t.Failed() {
3385 return
3386 }
3387
3388 synctest.Sleep(11 * time.Second)
3389
3390 getFakeConn(conn.ci).stickyBad = ec.badReset
3391
3392 db.putConn(conn, err, true)
3393
3394 <-afterPutConn
3395 })
3396 }
3397 }
3398
3399
3400
3401 func TestIssue20575(t *testing.T) {
3402 testDatabase(t, testIssue20575)
3403 }
3404 func testIssue20575(t *testing.T, db *DB) {
3405 populate(t, db, "people")
3406
3407 tx, err := db.Begin()
3408 if err != nil {
3409 t.Fatal(err)
3410 }
3411 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
3412 defer cancel()
3413 _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
3414 if err != nil {
3415 t.Fatal(err)
3416 }
3417
3418 err = tx.Rollback()
3419 if err != nil {
3420 t.Fatal(err)
3421 }
3422 select {
3423 default:
3424 case <-ctx.Done():
3425 t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
3426 }
3427 }
3428
3429
3430
3431 func TestIssue20622(t *testing.T) {
3432 testDatabase(t, testIssue20622)
3433 }
3434 func testIssue20622(t *testing.T, db *DB) {
3435 populate(t, db, "people")
3436
3437 ctx, cancel := context.WithCancel(context.Background())
3438 defer cancel()
3439
3440 tx, err := db.BeginTx(ctx, nil)
3441 if err != nil {
3442 t.Fatal(err)
3443 }
3444
3445 rows, err := tx.Query("SELECT|people|age,name|")
3446 if err != nil {
3447 t.Fatal(err)
3448 }
3449
3450 count := 0
3451 for rows.Next() {
3452 count++
3453 var age int
3454 var name string
3455 if err := rows.Scan(&age, &name); err != nil {
3456 t.Fatal("scan failed", err)
3457 }
3458
3459 if count == 1 {
3460 cancel()
3461 }
3462 time.Sleep(100 * time.Millisecond)
3463 }
3464 rows.Close()
3465 tx.Commit()
3466 }
3467
3468
3469 func TestErrBadConnReconnect(t *testing.T) {
3470 testDatabase(t, testErrBadConnReconnect)
3471 }
3472 func testErrBadConnReconnect(t *testing.T, db *DB) {
3473 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
3474
3475 simulateBadConn := func(name string, hook *func() bool, op func() error) {
3476 broken, retried := false, false
3477 numOpen := db.numOpen
3478
3479
3480 *hook = func() bool {
3481 if !broken {
3482 broken = true
3483 return true
3484 }
3485 retried = true
3486 return false
3487 }
3488
3489 if err := op(); err != nil {
3490 t.Errorf(name+": %v", err)
3491 return
3492 }
3493
3494 if !broken || !retried {
3495 t.Error(name + ": Failed to simulate broken connection")
3496 }
3497 *hook = nil
3498
3499 if numOpen != db.numOpen {
3500 t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
3501 numOpen = db.numOpen
3502 }
3503 }
3504
3505
3506 dbExec := func() error {
3507 _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
3508 return err
3509 }
3510 simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
3511 simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
3512
3513
3514 dbQuery := func() error {
3515 rows, err := db.Query("SELECT|t1|age,name|")
3516 if err == nil {
3517 err = rows.Close()
3518 }
3519 return err
3520 }
3521 simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
3522 simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
3523
3524
3525 simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
3526 stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
3527 if err != nil {
3528 return err
3529 }
3530 stmt.Close()
3531 return nil
3532 })
3533
3534
3535 forcePrepare := func(stmt *Stmt) {
3536 stmt.css = nil
3537 }
3538
3539
3540 stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
3541 if err != nil {
3542 t.Fatalf("prepare: %v", err)
3543 }
3544 defer stmt1.Close()
3545
3546 forcePrepare(stmt1)
3547
3548 stmtExec := func() error {
3549 _, err := stmt1.Exec("Gopher", 3, false)
3550 return err
3551 }
3552 simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
3553 simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
3554
3555
3556 stmt2, err := db.Prepare("SELECT|t1|age,name|")
3557 if err != nil {
3558 t.Fatalf("prepare: %v", err)
3559 }
3560 defer stmt2.Close()
3561
3562 forcePrepare(stmt2)
3563
3564 stmtQuery := func() error {
3565 rows, err := stmt2.Query()
3566 if err == nil {
3567 err = rows.Close()
3568 }
3569 return err
3570 }
3571 simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
3572 simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
3573 }
3574
3575
3576 func TestTxEndBadConn(t *testing.T) {
3577 testDatabase(t, testTxEndBadConn)
3578 }
3579 func testTxEndBadConn(t *testing.T, db *DB) {
3580 db.SetMaxIdleConns(0)
3581 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
3582 db.SetMaxIdleConns(1)
3583
3584 simulateBadConn := func(name string, hook *func() bool, op func() error) {
3585 broken := false
3586 numOpen := db.numOpen
3587
3588 *hook = func() bool {
3589 if !broken {
3590 broken = true
3591 }
3592 return broken
3593 }
3594
3595 if err := op(); !errors.Is(err, driver.ErrBadConn) {
3596 t.Errorf(name+": %v", err)
3597 return
3598 }
3599
3600 if !broken {
3601 t.Error(name + ": Failed to simulate broken connection")
3602 }
3603 *hook = nil
3604
3605 if numOpen != db.numOpen {
3606 t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
3607 }
3608 }
3609
3610
3611 dbExec := func(endTx func(tx *Tx) error) func() error {
3612 return func() error {
3613 tx, err := db.Begin()
3614 if err != nil {
3615 return err
3616 }
3617 _, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
3618 if err != nil {
3619 return err
3620 }
3621 return endTx(tx)
3622 }
3623 }
3624 simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit))
3625 simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback))
3626
3627
3628 dbQuery := func(endTx func(tx *Tx) error) func() error {
3629 return func() error {
3630 tx, err := db.Begin()
3631 if err != nil {
3632 return err
3633 }
3634 rows, err := tx.Query("SELECT|t1|age,name|")
3635 if err == nil {
3636 err = rows.Close()
3637 } else {
3638 return err
3639 }
3640 return endTx(tx)
3641 }
3642 }
3643 simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit))
3644 simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback))
3645 }
3646
3647 type concurrentTest interface {
3648 init(t testing.TB, db *DB)
3649 finish(t testing.TB)
3650 test(t testing.TB) error
3651 }
3652
3653 type concurrentDBQueryTest struct {
3654 db *DB
3655 }
3656
3657 func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
3658 c.db = db
3659 }
3660
3661 func (c *concurrentDBQueryTest) finish(t testing.TB) {
3662 c.db = nil
3663 }
3664
3665 func (c *concurrentDBQueryTest) test(t testing.TB) error {
3666 rows, err := c.db.Query("SELECT|people|name|")
3667 if err != nil {
3668 t.Error(err)
3669 return err
3670 }
3671 var name string
3672 for rows.Next() {
3673 rows.Scan(&name)
3674 }
3675 rows.Close()
3676 return nil
3677 }
3678
3679 type concurrentDBExecTest struct {
3680 db *DB
3681 }
3682
3683 func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
3684 c.db = db
3685 }
3686
3687 func (c *concurrentDBExecTest) finish(t testing.TB) {
3688 c.db = nil
3689 }
3690
3691 func (c *concurrentDBExecTest) test(t testing.TB) error {
3692 _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
3693 if err != nil {
3694 t.Error(err)
3695 return err
3696 }
3697 return nil
3698 }
3699
3700 type concurrentStmtQueryTest struct {
3701 db *DB
3702 stmt *Stmt
3703 }
3704
3705 func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
3706 c.db = db
3707 var err error
3708 c.stmt, err = db.Prepare("SELECT|people|name|")
3709 if err != nil {
3710 t.Fatal(err)
3711 }
3712 }
3713
3714 func (c *concurrentStmtQueryTest) finish(t testing.TB) {
3715 if c.stmt != nil {
3716 c.stmt.Close()
3717 c.stmt = nil
3718 }
3719 c.db = nil
3720 }
3721
3722 func (c *concurrentStmtQueryTest) test(t testing.TB) error {
3723 rows, err := c.stmt.Query()
3724 if err != nil {
3725 t.Errorf("error on query: %v", err)
3726 return err
3727 }
3728
3729 var name string
3730 for rows.Next() {
3731 rows.Scan(&name)
3732 }
3733 rows.Close()
3734 return nil
3735 }
3736
3737 type concurrentStmtExecTest struct {
3738 db *DB
3739 stmt *Stmt
3740 }
3741
3742 func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
3743 c.db = db
3744 var err error
3745 c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
3746 if err != nil {
3747 t.Fatal(err)
3748 }
3749 }
3750
3751 func (c *concurrentStmtExecTest) finish(t testing.TB) {
3752 if c.stmt != nil {
3753 c.stmt.Close()
3754 c.stmt = nil
3755 }
3756 c.db = nil
3757 }
3758
3759 func (c *concurrentStmtExecTest) test(t testing.TB) error {
3760 _, err := c.stmt.Exec(3, chrisBirthday)
3761 if err != nil {
3762 t.Errorf("error on exec: %v", err)
3763 return err
3764 }
3765 return nil
3766 }
3767
3768 type concurrentTxQueryTest struct {
3769 db *DB
3770 tx *Tx
3771 }
3772
3773 func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
3774 c.db = db
3775 var err error
3776 c.tx, err = c.db.Begin()
3777 if err != nil {
3778 t.Fatal(err)
3779 }
3780 }
3781
3782 func (c *concurrentTxQueryTest) finish(t testing.TB) {
3783 if c.tx != nil {
3784 c.tx.Rollback()
3785 c.tx = nil
3786 }
3787 c.db = nil
3788 }
3789
3790 func (c *concurrentTxQueryTest) test(t testing.TB) error {
3791 rows, err := c.db.Query("SELECT|people|name|")
3792 if err != nil {
3793 t.Error(err)
3794 return err
3795 }
3796 var name string
3797 for rows.Next() {
3798 rows.Scan(&name)
3799 }
3800 rows.Close()
3801 return nil
3802 }
3803
3804 type concurrentTxExecTest struct {
3805 db *DB
3806 tx *Tx
3807 }
3808
3809 func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
3810 c.db = db
3811 var err error
3812 c.tx, err = c.db.Begin()
3813 if err != nil {
3814 t.Fatal(err)
3815 }
3816 }
3817
3818 func (c *concurrentTxExecTest) finish(t testing.TB) {
3819 if c.tx != nil {
3820 c.tx.Rollback()
3821 c.tx = nil
3822 }
3823 c.db = nil
3824 }
3825
3826 func (c *concurrentTxExecTest) test(t testing.TB) error {
3827 _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
3828 if err != nil {
3829 t.Error(err)
3830 return err
3831 }
3832 return nil
3833 }
3834
3835 type concurrentTxStmtQueryTest struct {
3836 db *DB
3837 tx *Tx
3838 stmt *Stmt
3839 }
3840
3841 func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
3842 c.db = db
3843 var err error
3844 c.tx, err = c.db.Begin()
3845 if err != nil {
3846 t.Fatal(err)
3847 }
3848 c.stmt, err = c.tx.Prepare("SELECT|people|name|")
3849 if err != nil {
3850 t.Fatal(err)
3851 }
3852 }
3853
3854 func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
3855 if c.stmt != nil {
3856 c.stmt.Close()
3857 c.stmt = nil
3858 }
3859 if c.tx != nil {
3860 c.tx.Rollback()
3861 c.tx = nil
3862 }
3863 c.db = nil
3864 }
3865
3866 func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
3867 rows, err := c.stmt.Query()
3868 if err != nil {
3869 t.Errorf("error on query: %v", err)
3870 return err
3871 }
3872
3873 var name string
3874 for rows.Next() {
3875 rows.Scan(&name)
3876 }
3877 rows.Close()
3878 return nil
3879 }
3880
3881 type concurrentTxStmtExecTest struct {
3882 db *DB
3883 tx *Tx
3884 stmt *Stmt
3885 }
3886
3887 func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
3888 c.db = db
3889 var err error
3890 c.tx, err = c.db.Begin()
3891 if err != nil {
3892 t.Fatal(err)
3893 }
3894 c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
3895 if err != nil {
3896 t.Fatal(err)
3897 }
3898 }
3899
3900 func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
3901 if c.stmt != nil {
3902 c.stmt.Close()
3903 c.stmt = nil
3904 }
3905 if c.tx != nil {
3906 c.tx.Rollback()
3907 c.tx = nil
3908 }
3909 c.db = nil
3910 }
3911
3912 func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
3913 _, err := c.stmt.Exec(3, chrisBirthday)
3914 if err != nil {
3915 t.Errorf("error on exec: %v", err)
3916 return err
3917 }
3918 return nil
3919 }
3920
3921 type concurrentRandomTest struct {
3922 tests []concurrentTest
3923 }
3924
3925 func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
3926 c.tests = []concurrentTest{
3927 new(concurrentDBQueryTest),
3928 new(concurrentDBExecTest),
3929 new(concurrentStmtQueryTest),
3930 new(concurrentStmtExecTest),
3931 new(concurrentTxQueryTest),
3932 new(concurrentTxExecTest),
3933 new(concurrentTxStmtQueryTest),
3934 new(concurrentTxStmtExecTest),
3935 }
3936 for _, ct := range c.tests {
3937 ct.init(t, db)
3938 }
3939 }
3940
3941 func (c *concurrentRandomTest) finish(t testing.TB) {
3942 for _, ct := range c.tests {
3943 ct.finish(t)
3944 }
3945 }
3946
3947 func (c *concurrentRandomTest) test(t testing.TB) error {
3948 ct := c.tests[rand.Intn(len(c.tests))]
3949 return ct.test(t)
3950 }
3951
3952 func doConcurrentTest(t testing.TB, ct concurrentTest) {
3953 maxProcs, numReqs := 1, 500
3954 if testing.Short() {
3955 maxProcs, numReqs = 4, 50
3956 }
3957 defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
3958
3959 db := newTestDB(t, "people")
3960 defer closeDB(t, db)
3961
3962 ct.init(t, db)
3963 defer ct.finish(t)
3964
3965 var wg sync.WaitGroup
3966 wg.Add(numReqs)
3967
3968 reqs := make(chan bool)
3969 defer close(reqs)
3970
3971 for i := 0; i < maxProcs*2; i++ {
3972 go func() {
3973 for range reqs {
3974 err := ct.test(t)
3975 if err != nil {
3976 wg.Done()
3977 continue
3978 }
3979 wg.Done()
3980 }
3981 }()
3982 }
3983
3984 for i := 0; i < numReqs; i++ {
3985 reqs <- true
3986 }
3987
3988 wg.Wait()
3989 }
3990
3991 func TestIssue6081(t *testing.T) {
3992 testDatabase(t, testIssue6081)
3993 }
3994 func testIssue6081(t *testing.T, db *DB) {
3995 populate(t, db, "people")
3996
3997 drv := db.Driver().(*fakeDriver)
3998 drv.mu.Lock()
3999 opens0 := drv.openCount
4000 closes0 := drv.closeCount
4001 drv.mu.Unlock()
4002
4003 stmt, err := db.Prepare("SELECT|people|name|")
4004 if err != nil {
4005 t.Fatal(err)
4006 }
4007 setRowsCloseHook(func(rows *Rows, err *error) {
4008 *err = driver.ErrBadConn
4009 })
4010 defer setRowsCloseHook(nil)
4011 for i := 0; i < 10; i++ {
4012 rows, err := stmt.Query()
4013 if err != nil {
4014 t.Fatal(err)
4015 }
4016 rows.Close()
4017 }
4018 if n := len(stmt.css); n > 1 {
4019 t.Errorf("len(css slice) = %d; want <= 1", n)
4020 }
4021 stmt.Close()
4022 if n := len(stmt.css); n != 0 {
4023 t.Errorf("len(css slice) after Close = %d; want 0", n)
4024 }
4025
4026 drv.mu.Lock()
4027 opens := drv.openCount - opens0
4028 closes := drv.closeCount - closes0
4029 drv.mu.Unlock()
4030 if opens < 9 {
4031 t.Errorf("opens = %d; want >= 9", opens)
4032 }
4033 if closes < 9 {
4034 t.Errorf("closes = %d; want >= 9", closes)
4035 }
4036 }
4037
4038
4039
4040
4041
4042
4043
4044
4045
4046
4047
4048
// TestIssue18429 runs testIssue18429 against every connector configured in
// testDatabase.
func TestIssue18429(t *testing.T) {
	testDatabase(t, testIssue18429)
}

// testIssue18429 stresses transactions whose context can expire at any
// point. It starts 100 goroutines, with at most 20 in flight (bounded by sem).
// Each goroutine:
//   - begins a transaction under a random timeout of up to milliWait ms;
//   - issues a "WAIT|<random duration>|" fake-driver query;
//   - drains any rows it gets back and rolls back.
//
// Every error is ignored and nothing is asserted, so the test can only fail
// by panicking, deadlocking or (presumably, under -race) racing.
// NOTE(review): regression test for Go issue 18429; confirm the exact
// failure mode it guards against.
func testIssue18429(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx := context.Background()
	sem := make(chan bool, 20)
	var wg sync.WaitGroup

	const milliWait = 30

	for i := 0; i < 100; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer func() {
				<-sem
				wg.Done()
			}()
			qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()

			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
			defer cancel()

			tx, err := db.BeginTx(ctx, nil)
			if err != nil {
				// The timeout may already have fired; nothing else to do.
				return
			}

			// The query may fail, for example when the timeout fires first.
			// The error is intentionally ignored; only a non-nil Rows is used.
			// NOTE(review): the "WAIT|d|" prefix presumably makes the fake
			// driver delay the query by d; confirm in fakedb.
			rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
			if rows != nil {
				var name string
				// Drain the rows; the context may expire during iteration.
				for rows.Next() {
					// Scan errors are irrelevant to this test.
					rows.Scan(&name)
				}
				rows.Close()
			}

			// Rollback's error is ignored. It may fail if the expired context
			// already ended the transaction.
			tx.Rollback()
		}()
	}
	wg.Wait()
}
4098
4099
// TestIssue20160 runs testIssue20160 against every connector configured in
// testDatabase.
func TestIssue20160(t *testing.T) {
	testDatabase(t, testIssue20160)
}

// testIssue20160 shares one prepared Stmt among 100 goroutines, with at most
// 20 in flight (bounded by sem). Each goroutine queries the Stmt under a random
// timeout of up to milliWait ms and closes any Rows it gets back. Errors are
// ignored and nothing is asserted. The test presumably guards against panics
// or races when a Stmt's query context expires mid-flight.
// NOTE(review): regression test for Go issue 20160; confirm the failure mode.
func testIssue20160(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx := context.Background()
	sem := make(chan bool, 20)
	var wg sync.WaitGroup

	const milliWait = 30

	stmt, err := db.PrepareContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	for i := 0; i < 100; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer func() {
				<-sem
				wg.Done()
			}()
			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
			defer cancel()

			// The timeout may fire before, during or after the query. The
			// error is intentionally ignored, and only a non-nil Rows is
			// closed.
			rows, _ := stmt.QueryContext(ctx)
			if rows != nil {
				rows.Close()
			}
		}()
	}
	wg.Wait()
}
4140
4141
4142
4143
4144
4145
4146 func TestIssue18719(t *testing.T) {
4147 testDatabase(t, testIssue18719, requireFeature("BeginTx"))
4148 }
4149 func testIssue18719(t *testing.T, db *DB) {
4150 populate(t, db, "people")
4151
4152 ctx, cancel := context.WithCancel(context.Background())
4153 defer cancel()
4154
4155 tx, err := db.BeginTx(ctx, nil)
4156 if err != nil {
4157 t.Fatal(err)
4158 }
4159
4160 hookTxGrabConn = func() {
4161 cancel()
4162 synctest.Wait()
4163 if !tx.isDone() {
4164 t.Errorf("tx is not done")
4165 }
4166 }
4167 defer func() { hookTxGrabConn = nil }()
4168
4169
4170
4171 _, err = tx.QueryContext(ctx, "SELECT|people|name|")
4172 if err != nil {
4173 t.Fatalf("expected error %v but got %v", nil, err)
4174 }
4175
4176
4177
4178
4179
4180
4181 cancel()
4182 }
4183
4184 func TestIssue20647(t *testing.T) {
4185 testDatabase(t, testIssue20647)
4186 }
4187 func testIssue20647(t *testing.T, db *DB) {
4188 populate(t, db, "people")
4189
4190 ctx := t.Context()
4191
4192 conn, err := db.Conn(ctx)
4193 if err != nil {
4194 t.Fatal(err)
4195 }
4196 getFakeConn(conn.dc.ci).skipDirtySession = true
4197 defer conn.Close()
4198
4199 stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
4200 if err != nil {
4201 t.Fatal(err)
4202 }
4203 defer stmt.Close()
4204
4205 rows1, err := stmt.QueryContext(ctx)
4206 if err != nil {
4207 t.Fatal("rows1", err)
4208 }
4209 defer rows1.Close()
4210
4211 rows2, err := stmt.QueryContext(ctx)
4212 if err != nil {
4213 t.Fatal("rows2", err)
4214 }
4215 defer rows2.Close()
4216
4217 if rows1.dc != rows2.dc {
4218 t.Fatal("stmt prepared on Conn does not use same connection")
4219 }
4220 }
4221
4222 func TestConcurrency(t *testing.T) {
4223 list := []struct {
4224 name string
4225 ct concurrentTest
4226 }{
4227 {"Query", new(concurrentDBQueryTest)},
4228 {"Exec", new(concurrentDBExecTest)},
4229 {"StmtQuery", new(concurrentStmtQueryTest)},
4230 {"StmtExec", new(concurrentStmtExecTest)},
4231 {"TxQuery", new(concurrentTxQueryTest)},
4232 {"TxExec", new(concurrentTxExecTest)},
4233 {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
4234 {"TxStmtExec", new(concurrentTxStmtExecTest)},
4235 {"Random", new(concurrentRandomTest)},
4236 }
4237 for _, item := range list {
4238 synctestSubtest(t, item.name, func(t *testing.T) {
4239 doConcurrentTest(t, item.ct)
4240 })
4241 }
4242 }
4243
4244 func TestConnectionLeak(t *testing.T) {
4245 testDatabase(t, testConnectionLeak)
4246 }
4247 func testConnectionLeak(t *testing.T, db *DB) {
4248 populate(t, db, "people")
4249
4250 rows := make([]*Rows, defaultMaxIdleConns)
4251
4252
4253
4254 db.SetMaxOpenConns(len(rows) + 1)
4255 for ii := range rows {
4256 r, err := db.Query("SELECT|people|name|")
4257 if err != nil {
4258 t.Fatal(err)
4259 }
4260 r.Next()
4261 if err := r.Err(); err != nil {
4262 t.Fatal(err)
4263 }
4264 rows[ii] = r
4265 }
4266
4267
4268
4269 drv := db.Driver().(*fakeDriver)
4270 drv.waitCh = make(chan struct{}, 1)
4271 drv.waitingCh = make(chan struct{}, 1)
4272 var wg sync.WaitGroup
4273 wg.Add(1)
4274 go func() {
4275 r, err := db.Query("SELECT|people|name|")
4276 if err != nil {
4277 t.Error(err)
4278 return
4279 }
4280 r.Close()
4281 wg.Done()
4282 }()
4283
4284 <-drv.waitingCh
4285
4286
4287 for _, v := range rows {
4288 v.Close()
4289 }
4290
4291
4292
4293
4294 drv.waitCh <- struct{}{}
4295 wg.Wait()
4296 }
4297
4298 func TestStatsMaxIdleClosedZero(t *testing.T) {
4299 testDatabase(t, testStatsMaxIdleClosedZero)
4300 }
4301 func testStatsMaxIdleClosedZero(t *testing.T, db *DB) {
4302 populate(t, db, "people")
4303
4304 db.SetMaxOpenConns(1)
4305 db.SetMaxIdleConns(1)
4306 db.SetConnMaxLifetime(0)
4307
4308 preMaxIdleClosed := db.Stats().MaxIdleClosed
4309
4310 for i := 0; i < 10; i++ {
4311 rows, err := db.Query("SELECT|people|name|")
4312 if err != nil {
4313 t.Fatal(err)
4314 }
4315 rows.Close()
4316 }
4317
4318 st := db.Stats()
4319 maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
4320 t.Logf("MaxIdleClosed: %d", maxIdleClosed)
4321 if maxIdleClosed != 0 {
4322 t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
4323 }
4324 }
4325
4326 func TestStatsMaxIdleClosedTen(t *testing.T) {
4327 testDatabase(t, testStatsMaxIdleClosedTen)
4328 }
4329 func testStatsMaxIdleClosedTen(t *testing.T, db *DB) {
4330 populate(t, db, "people")
4331
4332 db.SetMaxOpenConns(1)
4333 db.SetMaxIdleConns(0)
4334 db.SetConnMaxLifetime(0)
4335
4336 preMaxIdleClosed := db.Stats().MaxIdleClosed
4337
4338 for i := 0; i < 10; i++ {
4339 rows, err := db.Query("SELECT|people|name|")
4340 if err != nil {
4341 t.Fatal(err)
4342 }
4343 rows.Close()
4344 }
4345
4346 st := db.Stats()
4347 maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
4348 t.Logf("MaxIdleClosed: %d", maxIdleClosed)
4349 if maxIdleClosed != 10 {
4350 t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
4351 }
4352 }
4353
4354
4355 func testUseConns(t *testing.T, count int, db *DB) {
4356 conns := make([]*Conn, count)
4357 ctx := context.Background()
4358 for i := range conns {
4359 time.Sleep(1 * time.Nanosecond)
4360 c, err := db.Conn(ctx)
4361 if err != nil {
4362 t.Error(err)
4363 }
4364 conns[i] = c
4365 }
4366
4367 for i := len(conns) - 1; i >= 0; i-- {
4368 time.Sleep(1 * time.Nanosecond)
4369 if err := conns[i].Close(); err != nil {
4370 t.Error(err)
4371 }
4372 }
4373 }
4374
// TestMaxIdleTime exercises connectionCleanerRunLocked with
// SetConnMaxIdleTime and, in one case, SetConnMaxLifetime. Each case runs
// these steps:
//   - use usedConns connections;
//   - sleep until timeOffset after the subtest started;
//   - reuse reusedConns connections;
//   - sleep secondTimeOffset more;
//   - run one cleaner pass.
//
// It then checks the returned next-check duration, the number of connections
// selected for closing, the MaxIdleTimeClosed stat delta, and that
// db.freeConn is in nondecreasing returnedAt order.
// NOTE(review): the sleeps presumably advance synctest's fake clock inside
// synctestSubtest; confirm.
func TestMaxIdleTime(t *testing.T) {
	usedConns := 5
	reusedConns := 2
	list := []struct {
		wantMaxIdleTime   time.Duration
		wantMaxLifetime   time.Duration
		wantNextCheck     time.Duration
		wantIdleClosed    int64
		wantMaxIdleClosed int64
		timeOffset        time.Duration
		secondTimeOffset  time.Duration
	}{
		{
			time.Millisecond,
			0,
			time.Millisecond - time.Nanosecond,
			int64(usedConns - reusedConns),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			0,
		},
		{
			// wantMaxIdleTime: the non-reused connections exceed this.
			time.Millisecond,
			// wantMaxLifetime: tuned so that, with the offsets below, one
			// extra connection is expected to exceed its lifetime.
			// NOTE(review): the +3ns presumably accounts for the 1ns sleeps
			// in testUseConns; confirm.
			10*time.Millisecond + 100*time.Nanosecond + 3*time.Nanosecond,
			time.Nanosecond,
			// wantIdleClosed: the non-reused connections plus one more
			// closed by max lifetime (MaxIdleTimeClosed below does not count
			// that extra one).
			int64(usedConns - reusedConns + 1),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			// secondTimeOffset: an extra delay after reusing connections.
			100 * time.Nanosecond,
		},
		{
			time.Hour,
			0,
			time.Second,
			0,
			0,
			10 * time.Millisecond,
			0},
	}
	for _, item := range list {
		synctestSubtest(t, fmt.Sprintf("%v", item.wantMaxIdleTime), func(t *testing.T) {
			startTime := time.Now()
			db := newTestDB(t, "people")
			defer closeDB(t, db)

			db.SetMaxOpenConns(usedConns)
			db.SetMaxIdleConns(usedConns)
			db.SetConnMaxIdleTime(item.wantMaxIdleTime)
			db.SetConnMaxLifetime(item.wantMaxLifetime)

			preMaxIdleClosed := db.Stats().MaxIdleTimeClosed

			// Acquire and release usedConns connections.
			testUseConns(t, usedConns, db)

			// Advance to timeOffset after the subtest started.
			time.Sleep(time.Until(startTime.Add(item.timeOffset)))

			// Reuse a subset of the connections; presumably this refreshes
			// their returnedAt so they survive the idle-time check.
			testUseConns(t, reusedConns, db)

			time.Sleep(item.secondTimeOffset)

			db.mu.Lock()
			nc, closing := db.connectionCleanerRunLocked(time.Second)
			if nc != item.wantNextCheck {
				t.Errorf("got %v; want %v next check duration", nc, item.wantNextCheck)
			}

			// freeConn must be in nondecreasing returnedAt order.
			var last time.Time
			for _, c := range db.freeConn {
				if last.After(c.returnedAt) {
					t.Error("freeConn is not ordered by returnedAt")
					break
				}
				last = c.returnedAt
			}

			db.mu.Unlock()
			for _, c := range closing {
				c.Close()
			}
			if g, w := int64(len(closing)), item.wantIdleClosed; g != w {
				t.Errorf("got: %d; want %d closed conns", g, w)
			}

			st := db.Stats()
			maxIdleClosed := st.MaxIdleTimeClosed - preMaxIdleClosed
			if g, w := maxIdleClosed, item.wantMaxIdleClosed; g != w {
				t.Errorf("got: %d; want %d max idle closed conns", g, w)
			}
		})
	}
}
4476
4477 type nvcDriver struct {
4478 fakeDriver
4479 skipNamedValueCheck bool
4480 }
4481
4482 func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
4483 c, err := d.fakeDriver.Open(dsn)
4484 fc := getFakeConn(c)
4485 fc.db.allowAny = true
4486 return &nvcConn{fc, d.skipNamedValueCheck}, err
4487 }
4488
4489 type nvcConn struct {
4490 *fakeConn
4491 skipNamedValueCheck bool
4492 }
4493
4494 type decimalInt struct {
4495 value int
4496 }
4497
4498 type doNotInclude struct{}
4499
4500 var _ driver.NamedValueChecker = &nvcConn{}
4501
4502 func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
4503 if c.skipNamedValueCheck {
4504 return driver.ErrSkip
4505 }
4506 switch v := nv.Value.(type) {
4507 default:
4508 return driver.ErrSkip
4509 case Out:
4510 switch ov := v.Dest.(type) {
4511 default:
4512 return errors.New("unknown NameValueCheck OUTPUT type")
4513 case *string:
4514 *ov = "from-server"
4515 nv.Value = "OUT:*string"
4516 }
4517 return nil
4518 case decimalInt, []int64:
4519 return nil
4520 case doNotInclude:
4521 return driver.ErrRemoveArgument
4522 }
4523 }
4524
4525 func TestNamedValueChecker(t *testing.T) {
4526 synctest.Test(t, testNamedValueChecker)
4527 }
4528 func testNamedValueChecker(t *testing.T) {
4529 Register("NamedValueCheck", &nvcDriver{})
4530 db, err := Open("NamedValueCheck", "")
4531 if err != nil {
4532 t.Fatal(err)
4533 }
4534 defer db.Close()
4535
4536 ctx := t.Context()
4537 _, err = db.ExecContext(ctx, "WIPE")
4538 if err != nil {
4539 t.Fatal("exec wipe", err)
4540 }
4541
4542 _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
4543 if err != nil {
4544 t.Fatal("exec create", err)
4545 }
4546
4547 o1 := ""
4548 _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimalInt{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
4549 if err != nil {
4550 t.Fatal("exec insert", err)
4551 }
4552 var (
4553 str1 string
4554 dec1 decimalInt
4555 arr1 []int64
4556 )
4557 err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
4558 if err != nil {
4559 t.Fatal("select", err)
4560 }
4561
4562 list := []struct{ got, want any }{
4563 {o1, "from-server"},
4564 {dec1, decimalInt{123}},
4565 {str1, "hello"},
4566 {arr1, []int64{42, 128, 707}},
4567 }
4568
4569 for index, item := range list {
4570 if !reflect.DeepEqual(item.got, item.want) {
4571 t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
4572 }
4573 }
4574 }
4575
4576 func TestNamedValueCheckerSkip(t *testing.T) {
4577 synctest.Test(t, testNamedValueCheckerSkip)
4578 }
4579 func testNamedValueCheckerSkip(t *testing.T) {
4580 Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
4581 db, err := Open("NamedValueCheckSkip", "")
4582 if err != nil {
4583 t.Fatal(err)
4584 }
4585 defer db.Close()
4586
4587 ctx := t.Context()
4588 _, err = db.ExecContext(ctx, "WIPE")
4589 if err != nil {
4590 t.Fatal("exec wipe", err)
4591 }
4592
4593 _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
4594 if err != nil {
4595 t.Fatal("exec create", err)
4596 }
4597
4598 _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimalInt{123}))
4599 if err == nil {
4600 t.Fatalf("expected error with bad argument, got %v", err)
4601 }
4602 }
4603
4604 func TestOpenConnector(t *testing.T) {
4605 synctest.Test(t, testOpenConnector)
4606 }
4607 func testOpenConnector(t *testing.T) {
4608 Register("testctx", &fakeDriverCtx{})
4609 db, err := Open("testctx", "people")
4610 if err != nil {
4611 t.Fatal(err)
4612 }
4613 defer db.Close()
4614
4615 c, ok := db.connector.(*fakeConnector)
4616 if !ok {
4617 t.Fatal("not using *fakeConnector")
4618 }
4619
4620 if err := db.Close(); err != nil {
4621 t.Fatal(err)
4622 }
4623
4624 if !c.closed {
4625 t.Fatal("connector is not closed")
4626 }
4627 }
4628
4629 type ctxOnlyDriver struct {
4630 fakeDriver
4631 }
4632
4633 func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
4634 conn, err := d.fakeDriver.Open(dsn)
4635 if err != nil {
4636 return nil, err
4637 }
4638 return &ctxOnlyConn{fc: getFakeConn(conn)}, nil
4639 }
4640
4641 var (
4642 _ driver.Conn = &ctxOnlyConn{}
4643 _ driver.QueryerContext = &ctxOnlyConn{}
4644 _ driver.ExecerContext = &ctxOnlyConn{}
4645 )
4646
4647 type ctxOnlyConn struct {
4648 fc *fakeConn
4649
4650 queryCtxCalled bool
4651 execCtxCalled bool
4652 }
4653
4654 func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
4655 return c.fc.Begin()
4656 }
4657
4658 func (c *ctxOnlyConn) Close() error {
4659 return c.fc.Close()
4660 }
4661
4662
4663
4664 func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) {
4665 panic("not used")
4666 }
4667
4668 func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
4669 return c.fc.PrepareContext(ctx, q)
4670 }
4671
4672 func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
4673 c.queryCtxCalled = true
4674 return c.fc.QueryContext(ctx, q, args)
4675 }
4676
4677 func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
4678 c.execCtxCalled = true
4679 return c.fc.ExecContext(ctx, q, args)
4680 }
4681
4682
4683
4684 func TestQueryExecContextOnly(t *testing.T) {
4685 synctest.Test(t, testQueryExecContextOnly)
4686 }
4687 func testQueryExecContextOnly(t *testing.T) {
4688
4689 var connType driver.Conn = &ctxOnlyConn{}
4690 if _, ok := connType.(driver.Execer); ok {
4691 t.Fatalf("%T must not implement driver.Execer", connType)
4692 }
4693 if _, ok := connType.(driver.Queryer); ok {
4694 t.Fatalf("%T must not implement driver.Queryer", connType)
4695 }
4696
4697 Register("ContextOnly", &ctxOnlyDriver{})
4698 db, err := Open("ContextOnly", "")
4699 if err != nil {
4700 t.Fatal(err)
4701 }
4702 defer db.Close()
4703
4704 ctx := t.Context()
4705
4706 conn, err := db.Conn(ctx)
4707 if err != nil {
4708 t.Fatal("db.Conn", err)
4709 }
4710 defer conn.Close()
4711 coc := conn.dc.ci.(*ctxOnlyConn)
4712 coc.fc.skipDirtySession = true
4713
4714 _, err = conn.ExecContext(ctx, "WIPE")
4715 if err != nil {
4716 t.Fatal("exec wipe", err)
4717 }
4718
4719 _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
4720 if err != nil {
4721 t.Fatal("exec create", err)
4722 }
4723 expectedValue := "value1"
4724 _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
4725 if err != nil {
4726 t.Fatal("exec insert", err)
4727 }
4728 rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
4729 if err != nil {
4730 t.Fatal("query select", err)
4731 }
4732 v1 := ""
4733 for rows.Next() {
4734 err = rows.Scan(&v1)
4735 if err != nil {
4736 t.Fatal("rows scan", err)
4737 }
4738 }
4739 rows.Close()
4740
4741 if v1 != expectedValue {
4742 t.Fatalf("expected %q, got %q", expectedValue, v1)
4743 }
4744
4745 if !coc.execCtxCalled {
4746 t.Error("ExecContext not called")
4747 }
4748 if !coc.queryCtxCalled {
4749 t.Error("QueryContext not called")
4750 }
4751 }
4752
// alwaysErrScanner is a Scanner whose Scan always fails with
// errTestScanWrap. It is used to check that Rows.Scan wraps scanner errors.
type alwaysErrScanner struct{}

// errTestScanWrap is the sentinel returned by alwaysErrScanner.Scan.
var errTestScanWrap = errors.New("errTestScanWrap")

// Scan ignores its source and fails with errTestScanWrap.
func (alwaysErrScanner) Scan(any) error { return errTestScanWrap }
4760
4761
4762 func TestRowsScanProperlyWrapsErrors(t *testing.T) {
4763 testDatabase(t, testRowsScanProperlyWrapsErrors)
4764 }
4765 func testRowsScanProperlyWrapsErrors(t *testing.T, db *DB) {
4766 populate(t, db, "people")
4767
4768 rows, err := db.Query("SELECT|people|age|")
4769 if err != nil {
4770 t.Fatalf("Query: %v", err)
4771 }
4772
4773 var res alwaysErrScanner
4774
4775 for rows.Next() {
4776 err = rows.Scan(&res)
4777 if err == nil {
4778 t.Fatal("expecting back an error")
4779 }
4780 if !errors.Is(err, errTestScanWrap) {
4781 t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errTestScanWrap)
4782 }
4783
4784 if !strings.Contains(err.Error(), errTestScanWrap.Error()) {
4785 t.Fatalf("Error %v does not contain %v", err, errTestScanWrap)
4786 }
4787 }
4788 }
4789
// alwaysErrValuer is a driver.Valuer whose Value always fails with errEmpty.
// It is used to check that argument-conversion errors are wrapped.
type alwaysErrValuer struct{}

// errEmpty is the error returned by alwaysErrValuer.Value.
var errEmpty = errors.New("empty value")

// Value returns a nil value together with errEmpty.
func (alwaysErrValuer) Value() (driver.Value, error) {
	var none driver.Value
	return none, errEmpty
}
4798
4799
4800 func TestDriverArgsWrapsErrors(t *testing.T) {
4801
4802 synctestSubtest(t, "exec", func(t *testing.T) {
4803 db := newTestDB(t, "people")
4804 defer closeDB(t, db)
4805 _, err := db.Exec("INSERT|keys|dec1=?", alwaysErrValuer{})
4806 if err == nil {
4807 t.Fatal("expecting back an error")
4808 }
4809 if !errors.Is(err, errEmpty) {
4810 t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errEmpty)
4811 }
4812
4813 if !strings.Contains(err.Error(), errEmpty.Error()) {
4814 t.Fatalf("Error %v does not contain %v", err, errEmpty)
4815 }
4816 })
4817
4818 synctestSubtest(t, "query", func(t *testing.T) {
4819 db := newTestDB(t, "people")
4820 defer closeDB(t, db)
4821 _, err := db.Query("INSERT|keys|dec1=?", alwaysErrValuer{})
4822 if err == nil {
4823 t.Fatal("expecting back an error")
4824 }
4825 if !errors.Is(err, errEmpty) {
4826 t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errEmpty)
4827 }
4828
4829 if !strings.Contains(err.Error(), errEmpty.Error()) {
4830 t.Fatalf("Error %v does not contain %v", err, errEmpty)
4831 }
4832 })
4833 }
4834
4835 func TestContextCancelDuringRawBytesScan(t *testing.T) {
4836 for _, mode := range []string{"nocancel", "top", "bottom", "go"} {
4837 synctestSubtest(t, mode, func(t *testing.T) {
4838 testContextCancelDuringRawBytesScan(t, mode)
4839 })
4840 }
4841 }
4842
4843
// testContextCancelDuringRawBytesScan scans the "people" names into RawBytes
// and cancels the query context at a point selected by mode:
//   - "nocancel": never cancel;
//   - "top": cancel after Next returns true for the third row, before Scan;
//   - "bottom": cancel right after the second row has been scanned;
//   - "go": cancel from a separate goroutine after the second row.
//
// Throughout, it checks Rows.closemuScanHold. The flag is expected to be:
//   - set after every successful RawBytes Scan;
//   - clear after a Scan that fails with context.Canceled;
//   - clear once the loop has finished.
func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// A cancelable context, so the modes above can cancel mid-iteration.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	numRows := 0
	var sink byte
	for r.Next() {
		if mode == "top" && numRows == 2 {
			// Cancel between Next and Scan of the third row, and let the
			// bubble's goroutines react before scanning.
			cancel()
			synctest.Wait()
		}
		numRows++
		var s RawBytes
		err = r.Scan(&s)
		if numRows == 3 && err == context.Canceled {
			// A Scan that reports cancellation must not leave closemu held.
			if r.closemuScanHold {
				t.Errorf("expected closemu NOT to be held")
			}
			break
		}
		// After a RawBytes Scan, closemu is expected to stay held.
		// NOTE(review): presumably this keeps s valid until the next
		// Next/Scan/Close.
		if !r.closemuScanHold {
			t.Errorf("expected closemu to be held")
		}
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("read %q", s)
		if mode == "bottom" && numRows == 2 {
			// Cancel while s from the second row is still in use below.
			cancel()
			synctest.Wait()
		}
		if mode == "go" && numRows == 2 {
			// Cancel concurrently; whether it lands before the next
			// Next/Scan is not controlled.
			go cancel()
		}
		// Read every byte of s. NOTE(review): presumably this lets the race
		// detector flag concurrent reuse of the RawBytes buffer.
		for _, b := range s {
			sink += b
		}

		// NOTE(review): presumably checks that Columns does not deadlock
		// while closemu is held by the RawBytes scan.
		_, _ = r.Columns()
	}
	if r.closemuScanHold {
		t.Errorf("closemu held; should not be")
	}

	// Expected outcomes by row count:
	//   - fewer than 2 rows is always a failure;
	//   - exactly 2 means iteration stopped at Next because of the
	//     cancellation, so Err must report context.Canceled;
	//   - 3 or more (no cancellation, or one seen by Scan) is acceptable.
	switch numRows {
	case 0, 1:
		t.Errorf("got %d rows; want 2+", numRows)
	case 2:
		if err := r.Err(); err != context.Canceled {
			t.Errorf("unexpected error: %v (%T)", err, err)
		}
	default:
		// Nothing further to check.
	}

	if err := r.Close(); err != nil {
		t.Fatal(err)
	}
}
4917
// TestContextCancelBetweenNextAndErr runs testContextCancelBetweenNextAndErr
// against every connector configured in testDatabase.
func TestContextCancelBetweenNextAndErr(t *testing.T) {
	testDatabase(t, testContextCancelBetweenNextAndErr)
}

// testContextCancelBetweenNextAndErr drains every row, then cancels the
// query's context and waits for the bubble's goroutines to settle. Rows.Err
// must still be nil: a cancellation that arrives after iteration has finished
// must not be reported as a row error.
func testContextCancelBetweenNextAndErr(t *testing.T, db *DB) {
	populate(t, db, "people")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	// Drain all rows so that Next has already returned false.
	for r.Next() {
	}
	// Cancel only after iteration is complete, then let goroutines react.
	cancel()
	synctest.Wait()
	if err := r.Err(); err != nil {
		t.Fatal(err)
	}
}
4938
// testScanner adapts a plain function to the Scanner interface.
type testScanner struct {
	scanf func(src any) error
}

// Scan forwards src to the wrapped function and returns its result.
func (ts testScanner) Scan(src any) error {
	scan := ts.scanf
	return scan(src)
}
4944
// TestContextCancelDuringScan runs testContextCancelDuringScan against every
// connector configured in testDatabase.
func TestContextCancelDuringScan(t *testing.T) {
	testDatabase(t, testContextCancelDuringScan)
}

// testContextCancelDuringScan blocks a custom Scanner in the middle of
// Rows.Scan and cancels the query context while the scanner still holds its
// source value. It checks that the []byte handed to the scanner is not
// modified by the cancellation.
func testContextCancelDuringScan(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// scanner publishes each source value on scanStart, then blocks until it
	// receives its return value from scanEnd.
	scanStart := make(chan any)
	scanEnd := make(chan error)
	scanner := &testScanner{
		scanf: func(src any) error {
			scanStart <- src
			return <-scanEnd
		},
	}

	// Query the row(s) named "Alice".
	want := []byte("Alice")
	r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
	if err != nil {
		t.Fatal(err)
	}
	if !r.Next() {
		t.Fatalf("r.Next() = false, want true")
	}
	// Scan in another goroutine so it can stay blocked inside scanner while
	// the context is cancelled.
	go func() {
		r.Scan(scanner)
	}()
	got := <-scanStart
	// Closing scanEnd on return makes the blocked scanner receive nil and
	// lets the Scan goroutine finish.
	defer close(scanEnd)
	gotBytes, ok := got.([]byte)
	if !ok {
		t.Fatalf("r.Scan returned %T, want []byte", got)
	}
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
	}

	// Cancel while Scan is still in progress and let the bubble settle.
	cancel()
	synctest.Wait()

	// The bytes already handed to the scanner must be unchanged.
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
	}
}
4994
// TestNilErrorAfterClose runs testNilErrorAfterClose against every connector
// configured in testDatabase.
func TestNilErrorAfterClose(t *testing.T) {
	testDatabase(t, testNilErrorAfterClose)
}

// testNilErrorAfterClose closes Rows explicitly, while the query context is
// still live (cancel only runs on return). It then checks that Rows.Err
// reports nil.
func testNilErrorAfterClose(t *testing.T, db *DB) {
	populate(t, db, "people")

	// NOTE(review): a cancelable context looks deliberate here. Presumably
	// it makes Rows start its context-watching machinery, which is what
	// this test exercises; confirm.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}

	if err := r.Close(); err != nil {
		t.Fatal(err)
	}

	// Let any goroutines started for r settle before inspecting Err.
	synctest.Wait()
	if err := r.Err(); err != nil {
		t.Fatal(err)
	}
}
5021
5022
5023
5024
5025
5026 func TestRawBytesReuse(t *testing.T) {
5027 testDatabase(t, testRawBytesReuse)
5028 }
5029 func testRawBytesReuse(t *testing.T, db *DB) {
5030 populate(t, db, "people")
5031
5032 var raw RawBytes
5033
5034
5035 rows, err := db.Query("SELECT|people|name|")
5036 if err != nil {
5037 t.Fatal(err)
5038 }
5039 rows.Next()
5040 rows.Scan(&raw)
5041 name1 := string(raw)
5042 rows.Close()
5043
5044
5045 rows, err = db.Query("SELECT|people|age|")
5046 if err != nil {
5047 t.Fatal(err)
5048 }
5049 rows.Next()
5050 rows.Scan(&raw)
5051 rows.Close()
5052
5053
5054 rows, err = db.Query("SELECT|people|name|")
5055 if err != nil {
5056 t.Fatal(err)
5057 }
5058 rows.Next()
5059 rows.Scan(&raw)
5060 name2 := string(raw)
5061 rows.Close()
5062 if name1 != name2 {
5063 t.Fatalf("Scan read name %q, want %q", name2, name1)
5064 }
5065 }
5066
5067
5068
// badConn is a driver.Conn whose behavior is fixed:
//   - Prepare and Begin always fail;
//   - Close always succeeds;
//   - Exec panics with the string "badConn.Exec".
type badConn struct{}

// Prepare always fails.
func (badConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("badConn Prepare")
}

// Close is a no-op that always succeeds.
func (badConn) Close() error { return nil }

// Begin always fails.
func (badConn) Begin() (driver.Tx, error) {
	return nil, errors.New("badConn Begin")
}

// Exec panics, so callers can verify that driver panics propagate.
func (badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	panic("badConn.Exec")
}
5086
5087
5088 type badDriver struct{}
5089
5090 func (bd badDriver) Open(name string) (driver.Conn, error) {
5091 return badConn{}, nil
5092 }
5093
5094
5095 func TestBadDriver(t *testing.T) {
5096 synctest.Test(t, testBadDriver)
5097 }
5098 func testBadDriver(t *testing.T) {
5099 Register("bad", badDriver{})
5100 db, err := Open("bad", "ignored")
5101 if err != nil {
5102 t.Fatal(err)
5103 }
5104 defer func() {
5105 if r := recover(); r == nil {
5106 t.Error("expected panic")
5107 } else {
5108 if want := "badConn.Exec"; r.(string) != want {
5109 t.Errorf("panic was %v, expected %v", r, want)
5110 }
5111 }
5112 }()
5113 defer db.Close()
5114 db.Exec("ignored")
5115 }
5116
5117 type pingDriver struct {
5118 fails bool
5119 }
5120
5121 type pingConn struct {
5122 badConn
5123 driver *pingDriver
5124 }
5125
5126 var pingError = errors.New("Ping failed")
5127
5128 func (pc pingConn) Ping(ctx context.Context) error {
5129 if pc.driver.fails {
5130 return pingError
5131 }
5132 return nil
5133 }
5134
5135 var _ driver.Pinger = pingConn{}
5136
5137 func (pd *pingDriver) Open(name string) (driver.Conn, error) {
5138 return pingConn{driver: pd}, nil
5139 }
5140
5141 func TestPing(t *testing.T) {
5142 synctest.Test(t, testPing)
5143 }
5144 func testPing(t *testing.T) {
5145 driver := &pingDriver{}
5146 Register("ping", driver)
5147
5148 db, err := Open("ping", "ignored")
5149 if err != nil {
5150 t.Fatal(err)
5151 }
5152 defer db.Close()
5153
5154 if err := db.Ping(); err != nil {
5155 t.Errorf("err was %#v, expected nil", err)
5156 return
5157 }
5158
5159 driver.fails = true
5160 if err := db.Ping(); err != pingError {
5161 t.Errorf("err was %#v, expected pingError", err)
5162 }
5163 }
5164
5165
5166 func TestTypedString(t *testing.T) {
5167 testDatabase(t, testTypedString)
5168 }
5169 func testTypedString(t *testing.T, db *DB) {
5170 populate(t, db, "people")
5171
5172 type Str string
5173 var scanned Str
5174
5175 err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned)
5176 if err != nil {
5177 t.Fatal(err)
5178 }
5179 expected := Str("Alice")
5180 if scanned != expected {
5181 t.Errorf("expected %+v, got %+v", expected, scanned)
5182 }
5183 }
5184
5185 func BenchmarkConcurrentDBExec(b *testing.B) {
5186 b.ReportAllocs()
5187 ct := new(concurrentDBExecTest)
5188 for i := 0; i < b.N; i++ {
5189 doConcurrentTest(b, ct)
5190 }
5191 }
5192
5193 func BenchmarkConcurrentStmtQuery(b *testing.B) {
5194 b.ReportAllocs()
5195 ct := new(concurrentStmtQueryTest)
5196 for i := 0; i < b.N; i++ {
5197 doConcurrentTest(b, ct)
5198 }
5199 }
5200
5201 func BenchmarkConcurrentStmtExec(b *testing.B) {
5202 b.ReportAllocs()
5203 ct := new(concurrentStmtExecTest)
5204 for i := 0; i < b.N; i++ {
5205 doConcurrentTest(b, ct)
5206 }
5207 }
5208
5209 func BenchmarkConcurrentTxQuery(b *testing.B) {
5210 b.ReportAllocs()
5211 ct := new(concurrentTxQueryTest)
5212 for i := 0; i < b.N; i++ {
5213 doConcurrentTest(b, ct)
5214 }
5215 }
5216
5217 func BenchmarkConcurrentTxExec(b *testing.B) {
5218 b.ReportAllocs()
5219 ct := new(concurrentTxExecTest)
5220 for i := 0; i < b.N; i++ {
5221 doConcurrentTest(b, ct)
5222 }
5223 }
5224
5225 func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
5226 b.ReportAllocs()
5227 ct := new(concurrentTxStmtQueryTest)
5228 for i := 0; i < b.N; i++ {
5229 doConcurrentTest(b, ct)
5230 }
5231 }
5232
5233 func BenchmarkConcurrentTxStmtExec(b *testing.B) {
5234 b.ReportAllocs()
5235 ct := new(concurrentTxStmtExecTest)
5236 for i := 0; i < b.N; i++ {
5237 doConcurrentTest(b, ct)
5238 }
5239 }
5240
5241 func BenchmarkConcurrentRandom(b *testing.B) {
5242 b.ReportAllocs()
5243 ct := new(concurrentRandomTest)
5244 for i := 0; i < b.N; i++ {
5245 doConcurrentTest(b, ct)
5246 }
5247 }
5248
5249 func BenchmarkManyConcurrentQueries(b *testing.B) {
5250 b.ReportAllocs()
5251
5252 const parallelism = 16
5253
5254 db := newTestDB(b, "magicquery")
5255 defer closeDB(b, db)
5256 db.SetMaxIdleConns(runtime.GOMAXPROCS(0) * parallelism)
5257
5258 stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
5259 if err != nil {
5260 b.Fatal(err)
5261 }
5262 defer stmt.Close()
5263
5264 b.SetParallelism(parallelism)
5265 b.RunParallel(func(pb *testing.PB) {
5266 for pb.Next() {
5267 rows, err := stmt.Query("sleep", 1)
5268 if err != nil {
5269 b.Error(err)
5270 return
5271 }
5272 rows.Close()
5273 }
5274 })
5275 }
5276
5277 func TestGrabConnAllocs(t *testing.T) {
5278 synctest.Test(t, testGrabConnAllocs)
5279 }
5280 func testGrabConnAllocs(t *testing.T) {
5281 testenv.SkipIfOptimizationOff(t)
5282 if race.Enabled {
5283 t.Skip("skipping allocation test when using race detector")
5284 }
5285 c := new(Conn)
5286 ctx := context.Background()
5287 n := int(testing.AllocsPerRun(1000, func() {
5288 _, release, err := c.grabConn(ctx)
5289 if err != nil {
5290 t.Fatal(err)
5291 }
5292 release(nil)
5293 }))
5294 if n > 0 {
5295 t.Fatalf("Conn.grabConn allocated %v objects; want 0", n)
5296 }
5297 }
5298
// BenchmarkGrabConn measures a grabConn/release round trip on a zero Conn.
func BenchmarkGrabConn(b *testing.B) {
	b.ReportAllocs()
	conn, ctx := new(Conn), context.Background()
	for range b.N {
		_, release, err := conn.grabConn(ctx)
		if err != nil {
			b.Fatal(err)
		}
		release(nil)
	}
}
5311
// TestConnRequestSet exercises connRequestSet. It covers adding and deleting
// handles, taking an element with TakeRandom, the ordering of TakeRandom
// results, and CloseAndRemoveAll.
func TestConnRequestSet(t *testing.T) {
	var s connRequestSet
	// wantLen checks that s holds exactly want entries. When want is 0, it
	// also checks that TakeRandom reports the set as empty.
	//
	// It receives the running subtest's *testing.T instead of capturing the
	// parent's. Each t.Run subtest runs on its own goroutine, and
	// Fatalf/FailNow must be called from the goroutine running the test, so
	// calling them on the parent from a subtest is not allowed.
	wantLen := func(t *testing.T, want int) {
		t.Helper()
		if got := s.Len(); got != want {
			t.Errorf("Len = %d; want %d", got, want)
		}
		if want == 0 && !t.Failed() {
			if _, ok := s.TakeRandom(); ok {
				t.Fatalf("TakeRandom returned result when empty")
			}
		}
	}
	reset := func() { s = connRequestSet{} }

	t.Run("add-delete", func(t *testing.T) {
		reset()
		wantLen(t, 0)
		dh := s.Add(nil)
		wantLen(t, 1)
		if !s.Delete(dh) {
			t.Fatal("failed to delete")
		}
		wantLen(t, 0)
		if s.Delete(dh) {
			t.Error("delete worked twice")
		}
		wantLen(t, 0)
	})
	t.Run("take-before-delete", func(t *testing.T) {
		reset()
		ch1 := make(chan connRequest)
		dh := s.Add(ch1)
		wantLen(t, 1)
		if got, ok := s.TakeRandom(); !ok || got != ch1 {
			t.Fatalf("wrong take; ok=%v", ok)
		}
		wantLen(t, 0)
		// A handle whose element was already taken must not delete anything.
		if s.Delete(dh) {
			t.Error("unexpected delete after take")
		}
	})
	t.Run("get-take-many", func(t *testing.T) {
		reset()
		m := map[chan connRequest]bool{}
		const N = 100
		var inOrder, backOut []chan connRequest
		for range N {
			c := make(chan connRequest)
			m[c] = true
			s.Add(c)
			inOrder = append(inOrder, c)
		}
		if s.Len() != N {
			t.Fatalf("Len = %v; want %v", s.Len(), N)
		}
		for s.Len() > 0 {
			c, ok := s.TakeRandom()
			if !ok {
				t.Fatal("failed to take when non-empty")
			}
			if !m[c] {
				t.Fatal("returned item not in remaining set")
			}
			delete(m, c)
			backOut = append(backOut, c)
		}
		if len(m) > 0 {
			t.Error("items remain in expected map")
		}
		// Assuming TakeRandom picks uniformly, getting the insertion order
		// back for 100 items is vanishingly unlikely.
		if slices.Equal(inOrder, backOut) {
			t.Error("wasn't random")
		}
	})
	t.Run("close-delete", func(t *testing.T) {
		reset()
		ch := make(chan connRequest)
		dh := s.Add(ch)
		wantLen(t, 1)
		s.CloseAndRemoveAll()
		wantLen(t, 0)
		if s.Delete(dh) {
			t.Error("unexpected delete after CloseAndRemoveAll")
		}
	})
}
5398
// BenchmarkConnRequestSet measures a fixed mix of Add and TakeRandom calls
// per iteration. Each iteration fills the set to 16 entries, takes 8, adds 8
// more, then drains all 16 and checks the set is empty again.
func BenchmarkConnRequestSet(b *testing.B) {
	var set connRequestSet
	for range b.N {
		// Fill to 16.
		for range 16 {
			set.Add(nil)
		}
		// Take half.
		for range 8 {
			if _, ok := set.TakeRandom(); !ok {
				b.Fatal("want ok")
			}
		}
		// Top back up to 16.
		for range 8 {
			set.Add(nil)
		}
		// Drain everything.
		for range 16 {
			if _, ok := set.TakeRandom(); !ok {
				b.Fatal("want ok")
			}
		}
		// Nothing may remain.
		if _, ok := set.TakeRandom(); ok {
			b.Fatal("unexpected ok")
		}
	}
}
5423
// TestIssue69837 checks that driver.DefaultParameterConverter turns a valid
// Null[uint] into its underlying value, represented as an int64.
func TestIssue69837(t *testing.T) {
	in := Null[uint]{V: 1, Valid: true}
	got, err := driver.DefaultParameterConverter.ConvertValue(in)
	if err != nil {
		t.Errorf("ConvertValue() error = %v, want nil", err)
	}

	v, ok := got.(int64)
	switch {
	case !ok:
		t.Errorf("val.(type): got %T, expected int64", got)
	case v != 1:
		t.Errorf("val: got %d, expected 1", v)
	}
}
5437
5438 type issue69728Type struct {
5439 ID int
5440 Name string
5441 }
5442
5443 func (t issue69728Type) Value() (driver.Value, error) {
5444 return []byte(fmt.Sprintf("%d, %s", t.ID, t.Name)), nil
5445 }
5446
// TestIssue69728 checks that Value on a valid Null[T], where T implements
// driver.Valuer, returns the same result as calling T's Value directly.
func TestIssue69728(t *testing.T) {
	inner := issue69728Type{ID: 42, Name: "foobar"}
	forValue := Null[issue69728Type]{V: inner, Valid: true}

	fromNull, err := forValue.Value()
	if err != nil {
		t.Errorf("forValue.Value() error = %v, want nil", err)
	}
	fromInner, err := forValue.V.Value()
	if err != nil {
		t.Errorf("forValue.V.Value() error = %v, want nil", err)
	}

	if !reflect.DeepEqual(fromNull, fromInner) {
		t.Errorf("not equal; v1 = %v, v2 = %v", fromNull, fromInner)
	}
}
5470
// TestColumnConverterWithUnknownInputCount checks that arguments are still
// passed through the driver's ColumnConverter when the statement reports an
// unknown number of inputs (NumInput == -1). The fake statement's Exec only
// accepts a string first argument, and this test passes an int, so Exec
// succeeds only if the converter ran.
func TestColumnConverterWithUnknownInputCount(t *testing.T) {
	db := OpenDB(&unknownInputsConnector{})
	// Close the DB when the test ends. This releases what OpenDB set up,
	// including its background connection-opener goroutine, instead of
	// leaking it into later tests.
	defer db.Close()

	stmt, err := db.Prepare("SELECT ?")
	if err != nil {
		t.Fatal(err)
	}
	// Deferred after db.Close, so it runs first: the statement is closed
	// before its DB.
	defer stmt.Close()

	if _, err := stmt.Exec(1); err != nil {
		t.Fatal(err)
	}
}
5482
5483 type unknownInputsConnector struct{}
5484
5485 func (unknownInputsConnector) Connect(context.Context) (driver.Conn, error) {
5486 return unknownInputsConn{}, nil
5487 }
5488
5489 func (unknownInputsConnector) Driver() driver.Driver { return nil }
5490
5491 type unknownInputsConn struct{}
5492
5493 func (unknownInputsConn) Prepare(string) (driver.Stmt, error) { return unknownInputsStmt{}, nil }
5494 func (unknownInputsConn) Close() error { return nil }
5495 func (unknownInputsConn) Begin() (driver.Tx, error) { return nil, nil }
5496
5497 type unknownInputsStmt struct{}
5498
5499 func (unknownInputsStmt) Close() error { return nil }
5500 func (unknownInputsStmt) NumInput() int { return -1 }
5501 func (unknownInputsStmt) Exec(args []driver.Value) (driver.Result, error) {
5502 if _, ok := args[0].(string); !ok {
5503 return nil, fmt.Errorf("Expected string, got %T", args[0])
5504 }
5505 return nil, nil
5506 }
5507 func (unknownInputsStmt) Query([]driver.Value) (driver.Rows, error) { return nil, nil }
5508 func (unknownInputsStmt) ColumnConverter(idx int) driver.ValueConverter {
5509 return unknownInputsValueConverter{}
5510 }
5511
5512 type unknownInputsValueConverter struct{}
5513
5514 func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
5515 return "string", nil
5516 }
5517
// TestNullTypeScanErrorConsistency checks that the Null* scanners keep Valid
// consistent with the error Scan returns. Valid must be false when Scan fails
// and true when it succeeds.
func TestNullTypeScanErrorConsistency(t *testing.T) {
	tests := []struct {
		name    string
		scanner Scanner
		input   any
		wantErr bool
	}{
		// Inputs the underlying conversion cannot parse.
		{
			name:    "NullInt32 with invalid input",
			scanner: &NullInt32{},
			input:   []byte("not_a_number"),
			wantErr: true,
		},
		{
			name:    "NullInt64 with invalid input",
			scanner: &NullInt64{},
			input:   []byte("not_a_number"),
			wantErr: true,
		},
		{
			name:    "NullFloat64 with invalid input",
			scanner: &NullFloat64{},
			input:   []byte("not_a_float"),
			wantErr: true,
		},
		{
			name:    "NullBool with invalid input",
			scanner: &NullBool{},
			input:   []byte("not_a_bool"),
			wantErr: true,
		},
		// Inputs that convert cleanly.
		{
			name:    "NullInt32 with valid input",
			scanner: &NullInt32{},
			input:   int64(42),
			wantErr: false,
		},
		{
			name:    "NullInt64 with valid input",
			scanner: &NullInt64{},
			input:   int64(42),
			wantErr: false,
		},
		{
			name:    "NullFloat64 with valid input",
			scanner: &NullFloat64{},
			input:   float64(3.14),
			wantErr: false,
		},
		{
			name:    "NullBool with valid input",
			scanner: &NullBool{},
			input:   true,
			wantErr: false,
		},
		{
			name:    "NullString with valid input",
			scanner: &NullString{},
			input:   "hello",
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.scanner.Scan(tt.input)

			if (err != nil) != tt.wantErr {
				t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
			}

			// Whatever the outcome, Valid must agree with err.
			var valid bool
			switch s := tt.scanner.(type) {
			case *NullInt32:
				valid = s.Valid
			case *NullInt64:
				valid = s.Valid
			case *NullFloat64:
				valid = s.Valid
			case *NullBool:
				valid = s.Valid
			case *NullString:
				valid = s.Valid
			case *NullTime:
				valid = s.Valid
			default:
				// Without this, a table entry of an unlisted type would read
				// as Valid=false and fail below with a misleading message.
				t.Fatalf("unhandled scanner type %T", tt.scanner)
			}

			if err != nil && valid {
				t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil")
			}
			if err == nil && !valid {
				t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil")
			}
		})
	}
}
5622
5623
// TestNullTypeScanNil checks that scanning nil into an already valid Null*
// value succeeds and resets Valid to false.
func TestNullTypeScanNil(t *testing.T) {
	tests := []struct {
		name    string
		scanner Scanner
	}{
		{"NullString", &NullString{String: "preset", Valid: true}},
		{"NullInt64", &NullInt64{Int64: 42, Valid: true}},
		{"NullInt32", &NullInt32{Int32: 42, Valid: true}},
		{"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}},
		{"NullBool", &NullBool{Bool: true, Valid: true}},
		{"NullTime", &NullTime{Time: time.Now(), Valid: true}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.scanner.Scan(nil)
			if err != nil {
				t.Errorf("Scan(nil) error = %v; want nil", err)
			}

			var valid bool
			switch s := tt.scanner.(type) {
			case *NullString:
				valid = s.Valid
			case *NullInt64:
				valid = s.Valid
			case *NullInt32:
				valid = s.Valid
			case *NullFloat64:
				valid = s.Valid
			case *NullBool:
				valid = s.Valid
			case *NullTime:
				valid = s.Valid
			default:
				// Without this, an unlisted scanner type would leave valid at
				// its zero value (false) and the check below would pass
				// without testing anything.
				t.Fatalf("unhandled scanner type %T", tt.scanner)
			}

			if valid {
				t.Errorf("Scan(nil) left Valid=true; want Valid=false")
			}
		})
	}
}
5666
// testStringType is a custom destination type. The ScanColumn-capable test
// driver fills it directly from string or []byte column values.
type testStringType struct {
	s string // the scanned text
}
5670
// TestQueryRowsScanner runs testQueryRowsScanner against each test database
// configuration that offers the "ScanColumn" feature.
func TestQueryRowsScanner(t *testing.T) {
	testDatabase(t, testQueryRowsScanner, requireFeature("ScanColumn"))
}

// testQueryRowsScanner reads the populated "people" table into rows whose
// name column is a testStringType. It checks that every row arrives intact.
func testQueryRowsScanner(t *testing.T, db *DB) {
	populate(t, db, "people")
	rows, err := db.Query("SELECT|people|age,name|")
	if err != nil {
		t.Fatalf("Query: %v", err)
	}
	defer rows.Close()

	type row struct {
		age  int
		name testStringType
	}
	got := make([]row, 0)
	for rows.Next() {
		var r row
		if err := rows.Scan(&r.age, &r.name); err != nil {
			t.Fatalf("Scan: %v", err)
		}
		got = append(got, r)
	}
	if err := rows.Err(); err != nil {
		t.Fatalf("Err: %v", err)
	}

	want := []row{
		{age: 1, name: testStringType{"Alice"}},
		{age: 2, name: testStringType{"Bob"}},
		{age: 3, name: testStringType{"Chris"}},
	}
	if !slices.Equal(got, want) {
		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
	}
}
5707
5708 type rowsColumnScannerConnector struct {
5709 fakeConnector
5710 }
5711
5712 func (c *rowsColumnScannerConnector) Connect(ctx context.Context) (driver.Conn, error) {
5713 conn, err := c.fakeConnector.Connect(ctx)
5714 fc := getFakeConn(conn)
5715 return &rowsColumnScannerConn{fc}, err
5716 }
5717
5718
5719 type rowsColumnScannerConn struct {
5720 *fakeConn
5721 }
5722
5723 func (s *rowsColumnScannerConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
5724 stmt, err := s.fakeConn.PrepareContext(ctx, query)
5725 if err != nil {
5726 return nil, err
5727 }
5728 return &rowsColumnScannerStmt{stmt.(*fakeStmt)}, nil
5729 }
5730
5731 type rowsColumnScannerStmt struct {
5732 *fakeStmt
5733 }
5734
5735 func (s *rowsColumnScannerStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
5736 rows, err := s.fakeStmt.QueryContext(ctx, args)
5737 if err != nil {
5738 return nil, err
5739 }
5740 return &rowsColumnScannerRows{rowsCursor: rows.(*rowsCursor)}, nil
5741 }
5742
5743 type rowsColumnScannerRows struct {
5744 *rowsCursor
5745 row []driver.Value
5746 }
5747
5748 func (c *rowsColumnScannerRows) NextRow() error {
5749 if c.row == nil {
5750 c.row = make([]driver.Value, len(c.rowsCursor.Columns()))
5751 }
5752 return c.rowsCursor.Next(c.row)
5753 }
5754
5755 func (c *rowsColumnScannerRows) NextResultSet() error {
5756 c.row = nil
5757 return c.rowsCursor.NextResultSet()
5758 }
5759
5760 func (c *rowsColumnScannerRows) ScanColumn(ctx driver.ScanContext, index int, dest any) error {
5761 if index < 0 || index >= len(c.row) {
5762 return fmt.Errorf("index %v out of range", index)
5763 }
5764 switch d := dest.(type) {
5765 case *testStringType:
5766 switch s := c.row[index].(type) {
5767 case string:
5768 d.s = s
5769 return nil
5770 case []byte:
5771 d.s = string(s)
5772 return nil
5773 }
5774 }
5775 return ConvertAssign(ctx, dest, c.row[index])
5776 }
5777
View as plain text