1
2
3
4
5 package sql
6
7 import (
8 "bytes"
9 "context"
10 "database/sql/driver"
11 "errors"
12 "fmt"
13 "internal/race"
14 "internal/testenv"
15 "math/rand"
16 "reflect"
17 "runtime"
18 "slices"
19 "strings"
20 "sync"
21 "sync/atomic"
22 "testing"
23 "testing/synctest"
24 "time"
25 )
26
27 func init() {
28 type dbConn struct {
29 db *DB
30 c *driverConn
31 }
32 freedFrom := make(map[dbConn]string)
33 var mu sync.Mutex
34 getFreedFrom := func(c dbConn) string {
35 mu.Lock()
36 defer mu.Unlock()
37 return freedFrom[c]
38 }
39 setFreedFrom := func(c dbConn, s string) {
40 mu.Lock()
41 defer mu.Unlock()
42 freedFrom[c] = s
43 }
44 putConnHook = func(db *DB, c *driverConn) {
45 if slices.Contains(db.freeConn, c) {
46
47
48
49 println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
50 panic("double free of conn.")
51 }
52 setFreedFrom(dbConn{db, c}, stack())
53 }
54 }
55
// fakeDBName is the name every fake database in these tests is opened under:
// newTestDBConnector forces it onto the connector, and other tests pass it
// directly to Open and OpenDB.
const fakeDBName = "foo"

// chrisBirthday is the bdate inserted for the "Chris" row of the "people"
// fixture; testQueryRow reads it back and compares against it.
var chrisBirthday = time.Unix(123456789, 0)
59
60 func newTestDB(t testing.TB, name string) *DB {
61 return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
62 }
63
64 func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
65 fc.name = fakeDBName
66 db := OpenDB(fc)
67 if _, err := db.Exec("WIPE"); err != nil {
68 t.Fatalf("exec wipe: %v", err)
69 }
70 t.Cleanup(func() {
71 closeDB(t, db)
72 })
73 if name == "people" {
74 exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
75 exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
76 exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
77 exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
78 }
79 if name == "magicquery" {
80
81 exec(t, db, "CREATE|magicquery|op=string,millis=int32")
82 exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
83 }
84 if name == "tx_status" {
85
86 exec(t, db, "CREATE|tx_status|tx_status=string")
87 exec(t, db, "INSERT|tx_status|tx_status=invalid")
88 }
89 return db
90 }
91
92 func TestOpenDB(t *testing.T) {
93 synctest.Test(t, testOpenDB)
94 }
95 func testOpenDB(t *testing.T) {
96 db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver})
97 defer db.Close()
98 if db.Driver() != fdriver {
99 t.Fatalf("OpenDB should return the driver of the Connector")
100 }
101 }
102
103 func TestDriverPanic(t *testing.T) {
104 synctest.Test(t, testDriverPanic)
105 }
106 func testDriverPanic(t *testing.T) {
107 db, err := Open("test", fakeDBName)
108 if err != nil {
109 t.Fatalf("Open: %v", err)
110 }
111 defer db.Close()
112
113 expectPanic := func(name string, f func()) {
114 defer func() {
115 err := recover()
116 if err == nil {
117 t.Fatalf("%s did not panic", name)
118 }
119 }()
120 f()
121 }
122
123 expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
124 exec(t, db, "WIPE")
125 expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
126 exec(t, db, "WIPE")
127 expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
128 exec(t, db, "WIPE")
129 exec(t, db, "PANIC|Query|WIPE")
130 exec(t, db, "WIPE")
131
132 exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
133
134 expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
135 expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
136 expectPanic("Query Close", func() {
137 rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
138 if err != nil {
139 t.Fatal(err)
140 }
141 rows.Close()
142 })
143 db.Query("PANIC|Exec|SELECT|people|age,name|")
144 exec(t, db, "WIPE")
145 }
146
147 func exec(t testing.TB, db *DB, query string, args ...any) {
148 t.Helper()
149 _, err := db.Exec(query, args...)
150 if err != nil {
151 t.Fatalf("Exec of %q: %v", query, err)
152 }
153 }
154
// closeDB closes db and verifies that it shut down cleanly. newTestDBConnector
// registers it with t.Cleanup, and some tests also defer it directly. It
// fails the test if:
//   - a fakeConn reports an error while being closed (observed through the
//     postCloseConn hook, which is installed only for the duration of this
//     call),
//   - any idle connection in db.freeConn still has open statements,
//   - db.Close returns an error, or
//   - db.numOpen is not zero once closing has settled.
func closeDB(t testing.TB, db *DB) {
	t.Helper()

	// recover only has an effect when closeDB is itself the deferred call
	// (as in "defer closeDB(t, db)"). Print the panic value so it shows up
	// in the test output, then keep panicking.
	if e := recover(); e != nil {
		fmt.Printf("Panic: %v\n", e)
		panic(e)
	}
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v", err)
		}
	})
	db.mu.Lock()
	for i, dc := range db.freeConn {
		if n := len(dc.openStmt); n > 0 {
			// An idle connection should not still hold prepared statements
			// by the time the DB is closed; a nonzero count presumably
			// means some Stmt was never closed.
			t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
		}
	}
	db.mu.Unlock()

	err := db.Close()
	if err != nil {
		t.Fatalf("error closing DB: %v", err)
	}

	// Let connection teardown settle before checking that nothing is still
	// open.
	numOpenConns := func() int {
		db.mu.Lock()
		defer db.mu.Unlock()
		return db.numOpen
	}
	if _, ok := t.(*testing.B); ok {
		// Benchmarks poll in real time (for up to 5s) instead of calling
		// synctest.Wait, presumably because they are not run inside a
		// synctest bubble.
		deadline := time.Now().Add(5 * time.Second)
		for numOpenConns() > 0 && time.Now().Before(deadline) {
			time.Sleep(1 * time.Millisecond)
		}
	} else {
		synctest.Wait()
	}
	if numOpen := numOpenConns(); numOpen != 0 {
		t.Fatalf("%d connections still open after closing DB", numOpen)
	}
}
205
206
207
208 func numPrepares(t *testing.T, db *DB) int {
209 if n := len(db.freeConn); n != 1 {
210 t.Fatalf("free conns = %d; want 1", n)
211 }
212 return db.freeConn[0].ci.(*fakeConn).numPrepare
213 }
214
215 func (db *DB) numDeps() int {
216 synctest.Wait()
217 db.mu.Lock()
218 defer db.mu.Unlock()
219 return len(db.dep)
220 }
221
222 func (db *DB) numFreeConns() int {
223 synctest.Wait()
224 db.mu.Lock()
225 defer db.mu.Unlock()
226 return len(db.freeConn)
227 }
228
229
230 func (db *DB) clearAllConns(t *testing.T) {
231 db.SetMaxIdleConns(0)
232
233 if g, w := db.numFreeConns(), 0; g != w {
234 t.Errorf("free conns = %d; want %d", g, w)
235 }
236
237 if n := db.numDeps(); n > 0 {
238 t.Errorf("number of dependencies = %d; expected 0", n)
239 db.dumpDeps(t)
240 }
241 }
242
243 func (db *DB) dumpDeps(t *testing.T) {
244 for fc := range db.dep {
245 db.dumpDep(t, 0, fc, map[finalCloser]bool{})
246 }
247 }
248
249 func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) {
250 seen[dep] = true
251 indent := strings.Repeat(" ", depth)
252 ds := db.dep[dep]
253 for k := range ds {
254 t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k)
255 if fc, ok := k.(finalCloser); ok {
256 if !seen[fc] {
257 db.dumpDep(t, depth+1, fc, seen)
258 }
259 }
260 }
261 }
262
263 func TestQuery(t *testing.T) {
264 synctest.Test(t, testQuery)
265 }
266 func testQuery(t *testing.T) {
267 db := newTestDB(t, "people")
268 prepares0 := numPrepares(t, db)
269 rows, err := db.Query("SELECT|people|age,name|")
270 if err != nil {
271 t.Fatalf("Query: %v", err)
272 }
273 defer rows.Close()
274 type row struct {
275 age int
276 name string
277 }
278 got := []row{}
279 for rows.Next() {
280 var r row
281 err = rows.Scan(&r.age, &r.name)
282 if err != nil {
283 t.Fatalf("Scan: %v", err)
284 }
285 got = append(got, r)
286 }
287 err = rows.Err()
288 if err != nil {
289 t.Fatalf("Err: %v", err)
290 }
291 want := []row{
292 {age: 1, name: "Alice"},
293 {age: 2, name: "Bob"},
294 {age: 3, name: "Chris"},
295 }
296 if !slices.Equal(got, want) {
297 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
298 }
299
300
301
302 if n := db.numFreeConns(); n != 1 {
303 t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
304 }
305 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
306 t.Errorf("executed %d Prepare statements; want 1", prepares)
307 }
308 }
309
310
311 func TestQueryContext(t *testing.T) {
312 synctest.Test(t, testQueryContext)
313 }
314 func testQueryContext(t *testing.T) {
315 db := newTestDB(t, "people")
316 prepares0 := numPrepares(t, db)
317
318 ctx, cancel := context.WithCancel(context.Background())
319 defer cancel()
320
321 rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
322 if err != nil {
323 t.Fatalf("Query: %v", err)
324 }
325 type row struct {
326 age int
327 name string
328 }
329 got := []row{}
330 index := 0
331 for rows.Next() {
332 if index == 2 {
333 cancel()
334 waitForRowsClose(t, rows)
335 }
336 var r row
337 err = rows.Scan(&r.age, &r.name)
338 if err != nil {
339 if index == 2 {
340 break
341 }
342 t.Fatalf("Scan: %v", err)
343 }
344 if index == 2 && err != context.Canceled {
345 t.Fatalf("Scan: %v; want context.Canceled", err)
346 }
347 got = append(got, r)
348 index++
349 }
350 select {
351 case <-ctx.Done():
352 if err := ctx.Err(); err != context.Canceled {
353 t.Fatalf("context err = %v; want context.Canceled", err)
354 }
355 default:
356 t.Fatalf("context err = nil; want context.Canceled")
357 }
358 want := []row{
359 {age: 1, name: "Alice"},
360 {age: 2, name: "Bob"},
361 }
362 if !slices.Equal(got, want) {
363 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
364 }
365
366
367
368 waitForRowsClose(t, rows)
369 waitForFree(t, db, 1)
370 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
371 t.Errorf("executed %d Prepare statements; want 1", prepares)
372 }
373 }
374
375
376
377 func waitForFree(t *testing.T, db *DB, want int) {
378 synctest.Wait()
379 numFree := db.numFreeConns()
380 if numFree != want {
381 t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want)
382 }
383 }
384
385 func waitForRowsClose(t *testing.T, rows *Rows) {
386 synctest.Wait()
387 rows.closemu.RLock()
388 defer rows.closemu.RUnlock()
389 if !rows.closed {
390 t.Fatal("failed to close rows")
391 }
392 }
393
394
395
396 func TestQueryContextWait(t *testing.T) {
397 synctest.Test(t, testQueryContextWait)
398 }
399 func testQueryContextWait(t *testing.T) {
400 db := newTestDB(t, "people")
401 prepares0 := numPrepares(t, db)
402
403 ctx, cancel := context.WithCancel(context.Background())
404 defer cancel()
405
406
407
408
409 c, err := db.Conn(ctx)
410 if err != nil {
411 t.Fatal(err)
412 }
413
414 c.dc.ci.(*fakeConn).waiter = func(c context.Context) {
415 cancel()
416 <-ctx.Done()
417 }
418 _, err = c.QueryContext(ctx, "SELECT|people|age,name|")
419 c.Close()
420 if err != context.Canceled {
421 t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
422 }
423
424
425 waitForFree(t, db, 1)
426 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
427 t.Fatalf("executed %d Prepare statements; want 1", prepares)
428 }
429 }
430
431
432
433 func TestTxContextWait(t *testing.T) {
434 synctest.Test(t, func(t *testing.T) {
435 testContextWait(t, false)
436 })
437 }
438
439
440
441 func TestTxContextWaitNoDiscard(t *testing.T) {
442 synctest.Test(t, func(t *testing.T) {
443 testContextWait(t, true)
444 })
445 }
446
447 func testContextWait(t *testing.T, keepConnOnRollback bool) {
448 db := newTestDB(t, "people")
449
450 ctx, cancel := context.WithCancel(context.Background())
451
452 tx, err := db.BeginTx(ctx, nil)
453 if err != nil {
454 t.Fatal(err)
455 }
456 tx.keepConnOnRollback = keepConnOnRollback
457
458 tx.dc.ci.(*fakeConn).waiter = func(c context.Context) {
459 cancel()
460 <-ctx.Done()
461 }
462
463
464
465 _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
466 if err != context.Canceled {
467 t.Fatalf("expected QueryContext to error with context canceled but returned %v", err)
468 }
469
470 if keepConnOnRollback {
471 waitForFree(t, db, 1)
472 } else {
473 waitForFree(t, db, 0)
474 }
475 }
476
477
478
479
480 func TestUnsupportedOptions(t *testing.T) {
481 synctest.Test(t, testUnsupportedOptions)
482 }
483 func testUnsupportedOptions(t *testing.T) {
484 db := newTestDB(t, "people")
485 _, err := db.BeginTx(context.Background(), &TxOptions{
486 Isolation: LevelSerializable, ReadOnly: true,
487 })
488 if err == nil {
489 t.Fatal("expected error when using unsupported options, got nil")
490 }
491 }
492
493 func TestMultiResultSetQuery(t *testing.T) {
494 synctest.Test(t, testMultiResultSetQuery)
495 }
496 func testMultiResultSetQuery(t *testing.T) {
497 db := newTestDB(t, "people")
498 prepares0 := numPrepares(t, db)
499 rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
500 if err != nil {
501 t.Fatalf("Query: %v", err)
502 }
503 type row1 struct {
504 age int
505 name string
506 }
507 type row2 struct {
508 name string
509 }
510 got1 := []row1{}
511 for rows.Next() {
512 var r row1
513 err = rows.Scan(&r.age, &r.name)
514 if err != nil {
515 t.Fatalf("Scan: %v", err)
516 }
517 got1 = append(got1, r)
518 }
519 err = rows.Err()
520 if err != nil {
521 t.Fatalf("Err: %v", err)
522 }
523 want1 := []row1{
524 {age: 1, name: "Alice"},
525 {age: 2, name: "Bob"},
526 {age: 3, name: "Chris"},
527 }
528 if !slices.Equal(got1, want1) {
529 t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
530 }
531
532 if !rows.NextResultSet() {
533 t.Errorf("expected another result set")
534 }
535
536 got2 := []row2{}
537 for rows.Next() {
538 var r row2
539 err = rows.Scan(&r.name)
540 if err != nil {
541 t.Fatalf("Scan: %v", err)
542 }
543 got2 = append(got2, r)
544 }
545 err = rows.Err()
546 if err != nil {
547 t.Fatalf("Err: %v", err)
548 }
549 want2 := []row2{
550 {name: "Alice"},
551 {name: "Bob"},
552 {name: "Chris"},
553 }
554 if !slices.Equal(got2, want2) {
555 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
556 }
557 if rows.NextResultSet() {
558 t.Errorf("expected no more result sets")
559 }
560
561
562
563 waitForFree(t, db, 1)
564 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
565 t.Errorf("executed %d Prepare statements; want 1", prepares)
566 }
567 }
568
569 func TestQueryNamedArg(t *testing.T) {
570 synctest.Test(t, testQueryNamedArg)
571 }
572 func testQueryNamedArg(t *testing.T) {
573 db := newTestDB(t, "people")
574 prepares0 := numPrepares(t, db)
575 rows, err := db.Query(
576
577 "SELECT|people|age,name|name=?name,age=?age",
578 Named("age", 2),
579 Named("name", "Bob"),
580 )
581 if err != nil {
582 t.Fatalf("Query: %v", err)
583 }
584 type row struct {
585 age int
586 name string
587 }
588 got := []row{}
589 for rows.Next() {
590 var r row
591 err = rows.Scan(&r.age, &r.name)
592 if err != nil {
593 t.Fatalf("Scan: %v", err)
594 }
595 got = append(got, r)
596 }
597 err = rows.Err()
598 if err != nil {
599 t.Fatalf("Err: %v", err)
600 }
601 want := []row{
602 {age: 2, name: "Bob"},
603 }
604 if !slices.Equal(got, want) {
605 t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
606 }
607
608
609
610 if n := db.numFreeConns(); n != 1 {
611 t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
612 }
613 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
614 t.Errorf("executed %d Prepare statements; want 1", prepares)
615 }
616 }
617
func TestPoolExhaustOnCancel(t *testing.T) {
	synctest.Test(t, testPoolExhaustOnCancel)
}

// testPoolExhaustOnCancel checks that canceling requests that are waiting for
// a connection does not permanently use up pool capacity. It:
//  1. fills all max connections with queries that block inside the waiter,
//  2. issues max PingContext calls, each of which must fail with
//     context.Canceled while the pool is full, then
//  3. releases the blocked queries and checks that an ordinary ping succeeds.
func testPoolExhaustOnCancel(t *testing.T) {
	max := 3
	// saturate counts queries that have reached the waiter (and so hold a
	// connection); saturateDone counts queries that have finished.
	var saturate, saturateDone sync.WaitGroup
	saturate.Add(max)
	saturateDone.Add(max)

	// Closing donePing releases the queries blocked in the waiter.
	donePing := make(chan bool)
	// state selects the waiter's behavior: 0 during fixture setup, 1 while
	// the pool is being saturated, 2 once it is saturated.
	state := 0

	// waiter is handed to the fake connector. The fake driver presumably
	// calls it during every query on its connections, including the setup
	// queries newTestDBConnector runs while state is 0 — confirm in
	// fakedb_test.go.
	// NOTE(review): state is a plain int read from other goroutines; the
	// writes rely on happens-before edges (goroutine start for state = 1,
	// saturate.Wait for state = 2) rather than on a lock or atomic.
	waiter := func(ctx context.Context) {
		switch state {
		case 0:
			// Fixture setup: return immediately.
		case 1:
			// Report that this query holds a connection, then keep holding
			// it until the waiter's context ends or donePing is closed.
			saturate.Done()
			select {
			case <-ctx.Done():
			case <-donePing:
			}
		case 2:
			// After saturation: return immediately.
		}
	}
	db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")

	db.SetMaxOpenConns(max)

	// Occupy every connection the pool may open with a query that blocks
	// in the waiter.
	// NOTE(review): a goroutine that fails its Query returns without
	// calling saturateDone.Done (and never reaches saturate.Done), so the
	// Wait calls below would block; synctest then reports the deadlock.
	state = 1
	for i := 0; i < max; i++ {
		go func() {
			rows, err := db.Query("SELECT|people|name,photo|")
			if err != nil {
				t.Errorf("Query: %v", err)
				return
			}
			rows.Close()
			saturateDone.Done()
		}()
	}

	saturate.Wait()
	if t.Failed() {
		t.FailNow()
	}
	state = 2

	// The pool is full, so each ping has to wait for a connection and must
	// give up with context.Canceled when its request context is canceled
	// 100ms later.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < max; i++ {
		ctxReq, cancelReq := context.WithCancel(ctx)
		go func() {
			time.Sleep(100 * time.Millisecond)
			cancelReq()
		}()
		err := db.PingContext(ctxReq)
		if err != context.Canceled {
			t.Fatalf("PingContext (Exhaust): %v", err)
		}
	}
	close(donePing)
	saturateDone.Wait()

	// The canceled pings must not have leaked pool capacity, so an
	// ordinary ping now has to succeed.
	err := db.PingContext(ctx)
	if err != nil {
		t.Fatalf("PingContext (Normal): %v", err)
	}
}
699
700 func TestRowsColumns(t *testing.T) {
701 synctest.Test(t, testRowsColumns)
702 }
703 func testRowsColumns(t *testing.T) {
704 db := newTestDB(t, "people")
705 rows, err := db.Query("SELECT|people|age,name|")
706 if err != nil {
707 t.Fatalf("Query: %v", err)
708 }
709 cols, err := rows.Columns()
710 if err != nil {
711 t.Fatalf("Columns: %v", err)
712 }
713 want := []string{"age", "name"}
714 if !slices.Equal(cols, want) {
715 t.Errorf("got %#v; want %#v", cols, want)
716 }
717 if err := rows.Close(); err != nil {
718 t.Errorf("error closing rows: %s", err)
719 }
720 }
721
722 func TestRowsColumnTypes(t *testing.T) {
723 synctest.Test(t, testRowsColumnTypes)
724 }
725 func testRowsColumnTypes(t *testing.T) {
726 db := newTestDB(t, "people")
727 defer closeDB(t, db)
728 rows, err := db.Query("SELECT|people|age,name|")
729 if err != nil {
730 t.Fatalf("Query: %v", err)
731 }
732 tt, err := rows.ColumnTypes()
733 if err != nil {
734 t.Fatalf("ColumnTypes: %v", err)
735 }
736
737 types := make([]reflect.Type, len(tt))
738 for i, tp := range tt {
739 st := tp.ScanType()
740 if st == nil {
741 t.Errorf("scantype is null for column %q", tp.Name())
742 continue
743 }
744 types[i] = st
745 }
746 values := make([]any, len(tt))
747 for i := range values {
748 values[i] = reflect.New(types[i]).Interface()
749 }
750 ct := 0
751 for rows.Next() {
752 err = rows.Scan(values...)
753 if err != nil {
754 t.Fatalf("failed to scan values in %v", err)
755 }
756 if ct == 1 {
757 if age := *values[0].(*int32); age != 2 {
758 t.Errorf("Expected 2, got %v", age)
759 }
760 if name := *values[1].(*string); name != "Bob" {
761 t.Errorf("Expected Bob, got %v", name)
762 }
763 }
764 ct++
765 }
766 if ct != 3 {
767 t.Errorf("expected 3 rows, got %d", ct)
768 }
769
770 if err := rows.Close(); err != nil {
771 t.Errorf("error closing rows: %s", err)
772 }
773 }
774
775 func TestQueryRow(t *testing.T) {
776 synctest.Test(t, testQueryRow)
777 }
778 func testQueryRow(t *testing.T) {
779 db := newTestDB(t, "people")
780 defer closeDB(t, db)
781 var name string
782 var age int
783 var birthday time.Time
784
785 err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
786 if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
787 t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
788 }
789
790 err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
791 if err != nil || !birthday.Equal(chrisBirthday) {
792 t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
793 }
794
795 err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
796 if err != nil {
797 t.Fatalf("age QueryRow+Scan: %v", err)
798 }
799 if name != "Bob" {
800 t.Errorf("expected name Bob, got %q", name)
801 }
802 if age != 2 {
803 t.Errorf("expected age 2, got %d", age)
804 }
805
806 err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name)
807 if err != nil {
808 t.Fatalf("name QueryRow+Scan: %v", err)
809 }
810 if name != "Alice" {
811 t.Errorf("expected name Alice, got %q", name)
812 }
813 if age != 1 {
814 t.Errorf("expected age 1, got %d", age)
815 }
816
817 var photo []byte
818 err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
819 if err != nil {
820 t.Fatalf("photo QueryRow+Scan: %v", err)
821 }
822 want := []byte("APHOTO")
823 if !slices.Equal(photo, want) {
824 t.Errorf("photo = %q; want %q", photo, want)
825 }
826 }
827
828 func TestRowErr(t *testing.T) {
829 synctest.Test(t, testRowErr)
830 }
831 func testRowErr(t *testing.T) {
832 db := newTestDB(t, "people")
833
834 row := db.QueryRowContext(context.Background(), "SELECT|people|bdate|age=?", 3)
835 if err := row.Err(); err != nil {
836 t.Errorf("Unexpected err = %v; want %v", err, nil)
837 }
838 row.Scan()
839
840 ctx, cancel := context.WithCancel(context.Background())
841 cancel()
842
843 err := db.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 3).Err()
844 exp := "context canceled"
845 if err == nil || !strings.Contains(err.Error(), exp) {
846 t.Errorf("Expected err = %v; got %v", exp, err)
847 }
848 }
849
850 func TestTxRollbackCommitErr(t *testing.T) {
851 synctest.Test(t, testTxRollbackCommitErr)
852 }
853 func testTxRollbackCommitErr(t *testing.T) {
854 db := newTestDB(t, "people")
855 defer closeDB(t, db)
856
857 tx, err := db.Begin()
858 if err != nil {
859 t.Fatal(err)
860 }
861 err = tx.Rollback()
862 if err != nil {
863 t.Errorf("expected nil error from Rollback; got %v", err)
864 }
865 err = tx.Commit()
866 if err != ErrTxDone {
867 t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
868 }
869
870 tx, err = db.Begin()
871 if err != nil {
872 t.Fatal(err)
873 }
874 err = tx.Commit()
875 if err != nil {
876 t.Errorf("expected nil error from Commit; got %v", err)
877 }
878 err = tx.Rollback()
879 if err != ErrTxDone {
880 t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
881 }
882 }
883
884 func TestStatementErrorAfterClose(t *testing.T) {
885 synctest.Test(t, testStatementErrorAfterClose)
886 }
887 func testStatementErrorAfterClose(t *testing.T) {
888 db := newTestDB(t, "people")
889 defer closeDB(t, db)
890 stmt, err := db.Prepare("SELECT|people|age|name=?")
891 if err != nil {
892 t.Fatalf("Prepare: %v", err)
893 }
894 err = stmt.Close()
895 if err != nil {
896 t.Fatalf("Close: %v", err)
897 }
898 var name string
899 err = stmt.QueryRow("foo").Scan(&name)
900 if err == nil {
901 t.Errorf("expected error from QueryRow.Scan after Stmt.Close")
902 }
903 }
904
905 func TestStatementQueryRow(t *testing.T) {
906 synctest.Test(t, testStatementQueryRow)
907 }
908 func testStatementQueryRow(t *testing.T) {
909 db := newTestDB(t, "people")
910 defer closeDB(t, db)
911 stmt, err := db.Prepare("SELECT|people|age|name=?")
912 if err != nil {
913 t.Fatalf("Prepare: %v", err)
914 }
915 defer stmt.Close()
916 var age int
917 for n, tt := range []struct {
918 name string
919 want int
920 }{
921 {"Alice", 1},
922 {"Bob", 2},
923 {"Chris", 3},
924 } {
925 if err := stmt.QueryRow(tt.name).Scan(&age); err != nil {
926 t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err)
927 } else if age != tt.want {
928 t.Errorf("%d: age=%d, want %d", n, age, tt.want)
929 }
930 }
931 }
932
// stubDriverStmt is a minimal driver.Stmt. Close returns the configured err;
// every other method returns zero values.
type stubDriverStmt struct {
	err error
}

// Close reports s.err unchanged.
func (s stubDriverStmt) Close() error { return s.err }

// NumInput returns -1, which the driver.Stmt contract defines as "number of
// placeholders unknown", so package sql skips its argument-count check.
func (s stubDriverStmt) NumInput() int { return -1 }

// Exec does nothing and returns a nil Result and nil error.
func (s stubDriverStmt) Exec(args []driver.Value) (driver.Result, error) { return nil, nil }

// Query does nothing and returns nil Rows and nil error.
func (s stubDriverStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, nil }
952
953
954 func TestStatementClose(t *testing.T) {
955 synctest.Test(t, testStatementClose)
956 }
957 func testStatementClose(t *testing.T) {
958 want := errors.New("STMT ERROR")
959
960 tests := []struct {
961 stmt *Stmt
962 msg string
963 }{
964 {&Stmt{stickyErr: want}, "stickyErr not propagated"},
965 {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
966 }
967 for _, test := range tests {
968 if err := test.stmt.Close(); err != want {
969 t.Errorf("%s. Got stmt.Close() = %v, want = %v", test.msg, err, want)
970 }
971 }
972 }
973
974
975 func TestStatementQueryRowConcurrent(t *testing.T) {
976 synctest.Test(t, testStatementQueryRowConcurrent)
977 }
978 func testStatementQueryRowConcurrent(t *testing.T) {
979 db := newTestDB(t, "people")
980 defer closeDB(t, db)
981 stmt, err := db.Prepare("SELECT|people|age|name=?")
982 if err != nil {
983 t.Fatalf("Prepare: %v", err)
984 }
985 defer stmt.Close()
986
987 const n = 10
988 ch := make(chan error, n)
989 for i := 0; i < n; i++ {
990 go func() {
991 var age int
992 err := stmt.QueryRow("Alice").Scan(&age)
993 if err == nil && age != 1 {
994 err = fmt.Errorf("unexpected age %d", age)
995 }
996 ch <- err
997 }()
998 }
999 for i := 0; i < n; i++ {
1000 if err := <-ch; err != nil {
1001 t.Error(err)
1002 }
1003 }
1004 }
1005
1006
1007 func TestBogusPreboundParameters(t *testing.T) {
1008 synctest.Test(t, testBogusPreboundParameters)
1009 }
1010 func testBogusPreboundParameters(t *testing.T) {
1011 db := newTestDB(t, "foo")
1012 defer closeDB(t, db)
1013 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1014 _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
1015 if err == nil {
1016 t.Fatalf("expected error")
1017 }
1018 if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` {
1019 t.Errorf("unexpected error: %v", err)
1020 }
1021 }
1022
1023 func TestExec(t *testing.T) {
1024 synctest.Test(t, testExec)
1025 }
1026 func testExec(t *testing.T) {
1027 db := newTestDB(t, "foo")
1028 defer closeDB(t, db)
1029 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1030 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1031 if err != nil {
1032 t.Errorf("Stmt, err = %v, %v", stmt, err)
1033 }
1034 defer stmt.Close()
1035
1036 type execTest struct {
1037 args []any
1038 wantErr string
1039 }
1040 execTests := []execTest{
1041
1042 {[]any{"Brad", 31}, ""},
1043 {[]any{"Brad", int64(31)}, ""},
1044 {[]any{"Bob", "32"}, ""},
1045 {[]any{7, 9}, ""},
1046
1047
1048 {[]any{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
1049 {[]any{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
1050
1051
1052 {[]any{}, "sql: expected 2 arguments, got 0"},
1053 {[]any{1, 2, 3}, "sql: expected 2 arguments, got 3"},
1054 }
1055 for n, et := range execTests {
1056 _, err := stmt.Exec(et.args...)
1057 errStr := ""
1058 if err != nil {
1059 errStr = err.Error()
1060 }
1061 if errStr != et.wantErr {
1062 t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q",
1063 n, et.args, errStr, et.wantErr)
1064 }
1065 }
1066 }
1067
1068 func TestTxPrepare(t *testing.T) {
1069 synctest.Test(t, testTxPrepare)
1070 }
1071 func testTxPrepare(t *testing.T) {
1072 db := newTestDB(t, "")
1073 defer closeDB(t, db)
1074 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1075 tx, err := db.Begin()
1076 if err != nil {
1077 t.Fatalf("Begin = %v", err)
1078 }
1079 stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
1080 if err != nil {
1081 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1082 }
1083 defer stmt.Close()
1084 _, err = stmt.Exec("Bobby", 7)
1085 if err != nil {
1086 t.Fatalf("Exec = %v", err)
1087 }
1088 err = tx.Commit()
1089 if err != nil {
1090 t.Fatalf("Commit = %v", err)
1091 }
1092
1093 if !stmt.closed {
1094 t.Fatal("Stmt not closed after Commit")
1095 }
1096 }
1097
1098 func TestTxStmt(t *testing.T) {
1099 synctest.Test(t, testTxStmt)
1100 }
1101 func testTxStmt(t *testing.T) {
1102 db := newTestDB(t, "")
1103 defer closeDB(t, db)
1104 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1105 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1106 if err != nil {
1107 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1108 }
1109 defer stmt.Close()
1110 tx, err := db.Begin()
1111 if err != nil {
1112 t.Fatalf("Begin = %v", err)
1113 }
1114 txs := tx.Stmt(stmt)
1115 defer txs.Close()
1116 _, err = txs.Exec("Bobby", 7)
1117 if err != nil {
1118 t.Fatalf("Exec = %v", err)
1119 }
1120 err = tx.Commit()
1121 if err != nil {
1122 t.Fatalf("Commit = %v", err)
1123 }
1124
1125 if !txs.closed {
1126 t.Fatal("Stmt not closed after Commit")
1127 }
1128 }
1129
1130 func TestTxStmtPreparedOnce(t *testing.T) {
1131 synctest.Test(t, testTxStmtPreparedOnce)
1132 }
1133 func testTxStmtPreparedOnce(t *testing.T) {
1134 db := newTestDB(t, "")
1135 defer closeDB(t, db)
1136 exec(t, db, "CREATE|t1|name=string,age=int32")
1137
1138 prepares0 := numPrepares(t, db)
1139
1140
1141 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1142 if err != nil {
1143 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1144 }
1145 defer stmt.Close()
1146
1147 tx, err := db.Begin()
1148 if err != nil {
1149 t.Fatalf("Begin = %v", err)
1150 }
1151
1152 txs1 := tx.Stmt(stmt)
1153 txs2 := tx.Stmt(stmt)
1154
1155 _, err = txs1.Exec("Go", 7)
1156 if err != nil {
1157 t.Fatalf("Exec = %v", err)
1158 }
1159 txs1.Close()
1160
1161 _, err = txs2.Exec("Gopher", 8)
1162 if err != nil {
1163 t.Fatalf("Exec = %v", err)
1164 }
1165 txs2.Close()
1166
1167 err = tx.Commit()
1168 if err != nil {
1169 t.Fatalf("Commit = %v", err)
1170 }
1171
1172 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
1173 t.Errorf("executed %d Prepare statements; want 1", prepares)
1174 }
1175 }
1176
1177 func TestTxStmtClosedRePrepares(t *testing.T) {
1178 synctest.Test(t, testTxStmtClosedRePrepares)
1179 }
1180 func testTxStmtClosedRePrepares(t *testing.T) {
1181 db := newTestDB(t, "")
1182 defer closeDB(t, db)
1183 exec(t, db, "CREATE|t1|name=string,age=int32")
1184
1185 prepares0 := numPrepares(t, db)
1186
1187
1188 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1189 if err != nil {
1190 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1191 }
1192 tx, err := db.Begin()
1193 if err != nil {
1194 t.Fatalf("Begin = %v", err)
1195 }
1196 err = stmt.Close()
1197 if err != nil {
1198 t.Fatalf("stmt.Close() = %v", err)
1199 }
1200
1201 txs := tx.Stmt(stmt)
1202 if txs.stickyErr != nil {
1203 t.Fatal(txs.stickyErr)
1204 }
1205 if txs.parentStmt != nil {
1206 t.Fatal("expected nil parentStmt")
1207 }
1208 _, err = txs.Exec(`Eric`, 82)
1209 if err != nil {
1210 t.Fatalf("txs.Exec = %v", err)
1211 }
1212
1213 err = txs.Close()
1214 if err != nil {
1215 t.Fatalf("txs.Close = %v", err)
1216 }
1217
1218 tx.Rollback()
1219
1220 if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
1221 t.Errorf("executed %d Prepare statements; want 2", prepares)
1222 }
1223 }
1224
1225 func TestParentStmtOutlivesTxStmt(t *testing.T) {
1226 synctest.Test(t, testParentStmtOutlivesTxStmt)
1227 }
1228 func testParentStmtOutlivesTxStmt(t *testing.T) {
1229 db := newTestDB(t, "")
1230 defer closeDB(t, db)
1231 exec(t, db, "CREATE|t1|name=string,age=int32")
1232
1233
1234 db.SetMaxOpenConns(1)
1235
1236 prepares0 := numPrepares(t, db)
1237
1238
1239 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1240 if err != nil {
1241 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1242 }
1243 defer stmt.Close()
1244 tx, err := db.Begin()
1245 if err != nil {
1246 t.Fatalf("Begin = %v", err)
1247 }
1248 txs := tx.Stmt(stmt)
1249 if len(stmt.css) != 1 {
1250 t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
1251 }
1252 err = txs.Close()
1253 if err != nil {
1254 t.Fatalf("txs.Close() = %v", err)
1255 }
1256 err = tx.Rollback()
1257 if err != nil {
1258 t.Fatalf("tx.Rollback() = %v", err)
1259 }
1260
1261 _, err = txs.Exec("Suzan", 30)
1262 if err == nil {
1263 t.Fatalf("txs.Exec(), expected err")
1264 }
1265
1266 _, err = stmt.Exec("Janina", 25)
1267 if err != nil {
1268 t.Fatalf("stmt.Exec() = %v", err)
1269 }
1270
1271 if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
1272 t.Errorf("executed %d Prepare statements; want 1", prepares)
1273 }
1274 }
1275
1276
1277
1278
1279 func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
1280 synctest.Test(t, testTxStmtFromTxStmtRePrepares)
1281 }
1282 func testTxStmtFromTxStmtRePrepares(t *testing.T) {
1283 db := newTestDB(t, "")
1284 defer closeDB(t, db)
1285 exec(t, db, "CREATE|t1|name=string,age=int32")
1286 prepares0 := numPrepares(t, db)
1287
1288 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1289 if err != nil {
1290 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1291 }
1292 defer stmt.Close()
1293
1294 tx, err := db.Begin()
1295 if err != nil {
1296 t.Fatalf("Begin = %v", err)
1297 }
1298 txs1 := tx.Stmt(stmt)
1299
1300
1301
1302 txs2 := tx.Stmt(txs1)
1303 if txs2.stickyErr != nil {
1304 t.Fatal(txs2.stickyErr)
1305 }
1306 if txs2.parentStmt != nil {
1307 t.Fatal("expected nil parentStmt")
1308 }
1309 _, err = txs2.Exec(`Eric`, 82)
1310 if err != nil {
1311 t.Fatal(err)
1312 }
1313
1314 err = txs1.Close()
1315 if err != nil {
1316 t.Fatalf("txs1.Close = %v", err)
1317 }
1318 err = txs2.Close()
1319 if err != nil {
1320 t.Fatalf("txs1.Close = %v", err)
1321 }
1322 err = tx.Rollback()
1323 if err != nil {
1324 t.Fatalf("tx.Rollback = %v", err)
1325 }
1326
1327 if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
1328 t.Errorf("executed %d Prepare statements; want 2", prepares)
1329 }
1330 }
1331
1332
1333
1334
1335 func TestTxQuery(t *testing.T) {
1336 synctest.Test(t, testTxQuery)
1337 }
1338 func testTxQuery(t *testing.T) {
1339 db := newTestDB(t, "")
1340 defer closeDB(t, db)
1341 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1342 exec(t, db, "INSERT|t1|name=Alice")
1343
1344 tx, err := db.Begin()
1345 if err != nil {
1346 t.Fatal(err)
1347 }
1348 defer tx.Rollback()
1349
1350 r, err := tx.Query("SELECT|t1|name|")
1351 if err != nil {
1352 t.Fatal(err)
1353 }
1354 defer r.Close()
1355
1356 if !r.Next() {
1357 if r.Err() != nil {
1358 t.Fatal(r.Err())
1359 }
1360 t.Fatal("expected one row")
1361 }
1362
1363 var x string
1364 err = r.Scan(&x)
1365 if err != nil {
1366 t.Fatal(err)
1367 }
1368 }
1369
1370 func TestTxQueryInvalid(t *testing.T) {
1371 synctest.Test(t, testTxQueryInvalid)
1372 }
1373 func testTxQueryInvalid(t *testing.T) {
1374 db := newTestDB(t, "")
1375 defer closeDB(t, db)
1376
1377 tx, err := db.Begin()
1378 if err != nil {
1379 t.Fatal(err)
1380 }
1381 defer tx.Rollback()
1382
1383 _, err = tx.Query("SELECT|t1|name|")
1384 if err == nil {
1385 t.Fatal("Error expected")
1386 }
1387 }
1388
1389
1390
1391 func TestTxErrBadConn(t *testing.T) {
1392 synctest.Test(t, testTxErrBadConn)
1393 }
1394 func testTxErrBadConn(t *testing.T) {
1395 db, err := Open("test", fakeDBName+";badConn")
1396 if err != nil {
1397 t.Fatalf("Open: %v", err)
1398 }
1399 if _, err := db.Exec("WIPE"); err != nil {
1400 t.Fatalf("exec wipe: %v", err)
1401 }
1402 defer closeDB(t, db)
1403 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
1404 stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
1405 if err != nil {
1406 t.Fatalf("Stmt, err = %v, %v", stmt, err)
1407 }
1408 defer stmt.Close()
1409 tx, err := db.Begin()
1410 if err != nil {
1411 t.Fatalf("Begin = %v", err)
1412 }
1413 txs := tx.Stmt(stmt)
1414 defer txs.Close()
1415 _, err = txs.Exec("Bobby", 7)
1416 if err != nil {
1417 t.Fatalf("Exec = %v", err)
1418 }
1419 err = tx.Commit()
1420 if err != nil {
1421 t.Fatalf("Commit = %v", err)
1422 }
1423 }
1424
1425 func TestConnQuery(t *testing.T) {
1426 synctest.Test(t, testConnQuery)
1427 }
1428 func testConnQuery(t *testing.T) {
1429 db := newTestDB(t, "people")
1430 defer closeDB(t, db)
1431
1432 ctx := t.Context()
1433 conn, err := db.Conn(ctx)
1434 if err != nil {
1435 t.Fatal(err)
1436 }
1437 conn.dc.ci.(*fakeConn).skipDirtySession = true
1438 defer conn.Close()
1439
1440 var name string
1441 err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name)
1442 if err != nil {
1443 t.Fatal(err)
1444 }
1445 if name != "Chris" {
1446 t.Fatalf("unexpected result, got %q want Chris", name)
1447 }
1448
1449 err = conn.PingContext(ctx)
1450 if err != nil {
1451 t.Fatal(err)
1452 }
1453 }
1454
1455 func TestConnRaw(t *testing.T) {
1456 synctest.Test(t, testConnRaw)
1457 }
1458 func testConnRaw(t *testing.T) {
1459 db := newTestDB(t, "people")
1460 defer closeDB(t, db)
1461
1462 ctx := t.Context()
1463 conn, err := db.Conn(ctx)
1464 if err != nil {
1465 t.Fatal(err)
1466 }
1467 conn.dc.ci.(*fakeConn).skipDirtySession = true
1468 defer conn.Close()
1469
1470 sawFunc := false
1471 err = conn.Raw(func(dc any) error {
1472 sawFunc = true
1473 if _, ok := dc.(*fakeConn); !ok {
1474 return fmt.Errorf("got %T want *fakeConn", dc)
1475 }
1476 return nil
1477 })
1478 if err != nil {
1479 t.Fatal(err)
1480 }
1481 if !sawFunc {
1482 t.Fatal("Raw func not called")
1483 }
1484
1485 func() {
1486 defer func() {
1487 x := recover()
1488 if x == nil {
1489 t.Fatal("expected panic")
1490 }
1491 conn.closemu.Lock()
1492 closed := conn.dc == nil
1493 conn.closemu.Unlock()
1494 if !closed {
1495 t.Fatal("expected connection to be closed after panic")
1496 }
1497 }()
1498 err = conn.Raw(func(dc any) error {
1499 panic("Conn.Raw panic should return an error")
1500 })
1501 t.Fatal("expected panic from Raw func")
1502 }()
1503 }
1504
1505 func TestCursorFake(t *testing.T) {
1506 synctest.Test(t, testCursorFake)
1507 }
1508 func testCursorFake(t *testing.T) {
1509 db := newTestDB(t, "people")
1510 defer closeDB(t, db)
1511
1512 ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
1513 defer cancel()
1514
1515 exec(t, db, "CREATE|peoplecursor|list=table")
1516 exec(t, db, "INSERT|peoplecursor|list=people!name!age")
1517
1518 rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
1519 if err != nil {
1520 t.Fatal(err)
1521 }
1522 defer rows.Close()
1523
1524 if !rows.Next() {
1525 t.Fatal("no rows")
1526 }
1527 var cursor = &Rows{}
1528 err = rows.Scan(cursor)
1529 if err != nil {
1530 t.Fatal(err)
1531 }
1532 defer cursor.Close()
1533
1534 const expectedRows = 3
1535 var currentRow int64
1536
1537 var n int64
1538 var s string
1539 for cursor.Next() {
1540 currentRow++
1541 err = cursor.Scan(&s, &n)
1542 if err != nil {
1543 t.Fatal(err)
1544 }
1545 if n != currentRow {
1546 t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
1547 }
1548 }
1549 if currentRow != expectedRows {
1550 t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
1551 }
1552 }
1553
1554 func TestInvalidNilValues(t *testing.T) {
1555 var date1 time.Time
1556 var date2 int
1557
1558 tests := []struct {
1559 name string
1560 input any
1561 expectedError string
1562 }{
1563 {
1564 name: "time.Time",
1565 input: &date1,
1566 expectedError: `sql: Scan error on column index 0, name "bdate": unsupported Scan, storing driver.Value type <nil> into type *time.Time`,
1567 },
1568 {
1569 name: "int",
1570 input: &date2,
1571 expectedError: `sql: Scan error on column index 0, name "bdate": converting NULL to int is unsupported`,
1572 },
1573 }
1574
1575 for _, tt := range tests {
1576 synctestSubtest(t, tt.name, func(t *testing.T) {
1577 db := newTestDB(t, "people")
1578 defer closeDB(t, db)
1579
1580 ctx := t.Context()
1581 conn, err := db.Conn(ctx)
1582 if err != nil {
1583 t.Fatal(err)
1584 }
1585 conn.dc.ci.(*fakeConn).skipDirtySession = true
1586 defer conn.Close()
1587
1588 err = conn.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 1).Scan(tt.input)
1589 if err == nil {
1590 t.Fatal("expected error when querying nil column, but succeeded")
1591 }
1592 if err.Error() != tt.expectedError {
1593 t.Fatalf("Expected error: %s\nReceived: %s", tt.expectedError, err.Error())
1594 }
1595
1596 err = conn.PingContext(ctx)
1597 if err != nil {
1598 t.Fatal(err)
1599 }
1600 })
1601 }
1602 }
1603
1604 func TestConnTx(t *testing.T) {
1605 synctest.Test(t, testConnTx)
1606 }
1607 func testConnTx(t *testing.T) {
1608 db := newTestDB(t, "people")
1609 defer closeDB(t, db)
1610
1611 ctx := t.Context()
1612 conn, err := db.Conn(ctx)
1613 if err != nil {
1614 t.Fatal(err)
1615 }
1616 conn.dc.ci.(*fakeConn).skipDirtySession = true
1617 defer conn.Close()
1618
1619 tx, err := conn.BeginTx(ctx, nil)
1620 if err != nil {
1621 t.Fatal(err)
1622 }
1623 insertName, insertAge := "Nancy", 33
1624 _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge)
1625 if err != nil {
1626 t.Fatal(err)
1627 }
1628 err = tx.Commit()
1629 if err != nil {
1630 t.Fatal(err)
1631 }
1632
1633 var selectName string
1634 err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName)
1635 if err != nil {
1636 t.Fatal(err)
1637 }
1638 if selectName != insertName {
1639 t.Fatalf("got %q want %q", selectName, insertName)
1640 }
1641 }
1642
1643
1644
1645
1646 func TestConnIsValid(t *testing.T) {
1647 synctest.Test(t, testConnIsValid)
1648 }
1649 func testConnIsValid(t *testing.T) {
1650 db := newTestDB(t, "people")
1651 defer closeDB(t, db)
1652
1653 db.SetMaxOpenConns(1)
1654
1655 ctx := context.Background()
1656
1657 c, err := db.Conn(ctx)
1658 if err != nil {
1659 t.Fatal(err)
1660 }
1661
1662 err = c.Raw(func(raw any) error {
1663 dc := raw.(*fakeConn)
1664 dc.stickyBad = true
1665 return nil
1666 })
1667 if err != nil {
1668 t.Fatal(err)
1669 }
1670 c.Close()
1671
1672 if len(db.freeConn) > 0 && db.freeConn[0].ci.(*fakeConn).stickyBad {
1673 t.Fatal("bad connection returned to pool; expected bad connection to be discarded")
1674 }
1675 }
1676
1677
1678
1679 func TestIssue2542Deadlock(t *testing.T) {
1680 synctest.Test(t, testIssue2542Deadlock)
1681 }
1682 func testIssue2542Deadlock(t *testing.T) {
1683 db := newTestDB(t, "people")
1684 closeDB(t, db)
1685 for i := 0; i < 2; i++ {
1686 _, err := db.Query("SELECT|people|age,name|")
1687 if err == nil {
1688 t.Fatalf("expected error")
1689 }
1690 }
1691 }
1692
1693
1694 func TestCloseStmtBeforeRows(t *testing.T) {
1695 synctest.Test(t, testCloseStmtBeforeRows)
1696 }
1697 func testCloseStmtBeforeRows(t *testing.T) {
1698 db := newTestDB(t, "people")
1699 defer closeDB(t, db)
1700
1701 s, err := db.Prepare("SELECT|people|name|")
1702 if err != nil {
1703 t.Fatal(err)
1704 }
1705
1706 r, err := s.Query()
1707 if err != nil {
1708 s.Close()
1709 t.Fatal(err)
1710 }
1711
1712 err = s.Close()
1713 if err != nil {
1714 t.Fatal(err)
1715 }
1716
1717 r.Close()
1718 }
1719
1720
1721
1722 func TestNullByteSlice(t *testing.T) {
1723 synctest.Test(t, testNullByteSlice)
1724 }
1725 func testNullByteSlice(t *testing.T) {
1726 db := newTestDB(t, "")
1727 defer closeDB(t, db)
1728 exec(t, db, "CREATE|t|id=int32,name=nullstring")
1729 exec(t, db, "INSERT|t|id=10,name=?", nil)
1730
1731 var name []byte
1732
1733 err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
1734 if err != nil {
1735 t.Fatal(err)
1736 }
1737 if name != nil {
1738 t.Fatalf("name []byte should be nil for null column value, got: %#v", name)
1739 }
1740
1741 exec(t, db, "INSERT|t|id=11,name=?", "bob")
1742 err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name)
1743 if err != nil {
1744 t.Fatal(err)
1745 }
1746 if string(name) != "bob" {
1747 t.Fatalf("name []byte should be bob, got: %q", string(name))
1748 }
1749 }
1750
1751 func TestPointerParamsAndScans(t *testing.T) {
1752 synctest.Test(t, testPointerParamsAndScans)
1753 }
1754 func testPointerParamsAndScans(t *testing.T) {
1755 db := newTestDB(t, "")
1756 defer closeDB(t, db)
1757 exec(t, db, "CREATE|t|id=int32,name=nullstring")
1758
1759 bob := "bob"
1760 var name *string
1761
1762 name = &bob
1763 exec(t, db, "INSERT|t|id=10,name=?", name)
1764 name = nil
1765 exec(t, db, "INSERT|t|id=20,name=?", name)
1766
1767 err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
1768 if err != nil {
1769 t.Fatalf("querying id 10: %v", err)
1770 }
1771 if name == nil {
1772 t.Errorf("id 10's name = nil; want bob")
1773 } else if *name != "bob" {
1774 t.Errorf("id 10's name = %q; want bob", *name)
1775 }
1776
1777 err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name)
1778 if err != nil {
1779 t.Fatalf("querying id 20: %v", err)
1780 }
1781 if name != nil {
1782 t.Errorf("id 20 = %q; want nil", *name)
1783 }
1784 }
1785
1786 func TestQueryRowClosingStmt(t *testing.T) {
1787 synctest.Test(t, testQueryRowClosingStmt)
1788 }
1789 func testQueryRowClosingStmt(t *testing.T) {
1790 db := newTestDB(t, "people")
1791 defer closeDB(t, db)
1792 var name string
1793 var age int
1794 err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
1795 if err != nil {
1796 t.Fatal(err)
1797 }
1798 if len(db.freeConn) != 1 {
1799 t.Fatalf("expected 1 free conn")
1800 }
1801 fakeConn := db.freeConn[0].ci.(*fakeConn)
1802 if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
1803 t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
1804 }
1805 }
1806
// atomicRowsCloseHook holds the func(*Rows, *error) returned by the
// rowsCloseHook installed in init below; tests change it through
// setRowsCloseHook.
var atomicRowsCloseHook atomic.Value
1808
1809 func init() {
1810 rowsCloseHook = func() func(*Rows, *error) {
1811 fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
1812 return fn
1813 }
1814 }
1815
1816 func setRowsCloseHook(fn func(*Rows, *error)) {
1817 if fn == nil {
1818
1819
1820 fn = func(*Rows, *error) {}
1821 }
1822 atomicRowsCloseHook.Store(fn)
1823 }
1824
1825
1826 func TestIssue6651(t *testing.T) {
1827 synctest.Test(t, testIssue6651)
1828 }
1829 func testIssue6651(t *testing.T) {
1830 db := newTestDB(t, "people")
1831 defer closeDB(t, db)
1832
1833 var v string
1834
1835 want := "error in rows.Next"
1836 rowsCursorNextHook = func(dest []driver.Value) error {
1837 return errors.New(want)
1838 }
1839 defer func() { rowsCursorNextHook = nil }()
1840
1841 err := db.QueryRow("SELECT|people|name|").Scan(&v)
1842 if err == nil || err.Error() != want {
1843 t.Errorf("error = %q; want %q", err, want)
1844 }
1845 rowsCursorNextHook = nil
1846
1847 want = "error in rows.Close"
1848 setRowsCloseHook(func(rows *Rows, err *error) {
1849 *err = errors.New(want)
1850 })
1851 defer setRowsCloseHook(nil)
1852 err = db.QueryRow("SELECT|people|name|").Scan(&v)
1853 if err == nil || err.Error() != want {
1854 t.Errorf("error = %q; want %q", err, want)
1855 }
1856 }
1857
// nullTestRow is one row of a nullTestSpec table. Its fields are, in order:
//   - the argument bound to the nullable column (nullf);
//   - the argument bound to the non-nullable column (notnullf);
//   - the value expected when nullf is scanned back.
// Field order matters: tables build rows with positional literals.
type nullTestRow struct {
	nullParam, notNullParam, scanNullVal any
}
1863
1864 type nullTestSpec struct {
1865 nullType string
1866 notNullType string
1867 rows [6]nullTestRow
1868 }
1869
1870 func TestNullStringParam(t *testing.T) {
1871 spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
1872 {NullString{"aqua", true}, "", NullString{"aqua", true}},
1873 {NullString{"brown", false}, "", NullString{"", false}},
1874 {"chartreuse", "", NullString{"chartreuse", true}},
1875 {NullString{"darkred", true}, "", NullString{"darkred", true}},
1876 {NullString{"eel", false}, "", NullString{"", false}},
1877 {"foo", NullString{"black", false}, nil},
1878 }}
1879 synctest.Test(t, func(t *testing.T) {
1880 nullTestRun(t, spec)
1881 })
1882 }
1883
1884 func TestGenericNullStringParam(t *testing.T) {
1885 spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
1886 {Null[string]{"aqua", true}, "", Null[string]{"aqua", true}},
1887 {Null[string]{"brown", false}, "", Null[string]{"", false}},
1888 {"chartreuse", "", Null[string]{"chartreuse", true}},
1889 {Null[string]{"darkred", true}, "", Null[string]{"darkred", true}},
1890 {Null[string]{"eel", false}, "", Null[string]{"", false}},
1891 {"foo", Null[string]{"black", false}, nil},
1892 }}
1893 synctest.Test(t, func(t *testing.T) {
1894 nullTestRun(t, spec)
1895 })
1896 }
1897
1898 func TestNullInt64Param(t *testing.T) {
1899 spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
1900 {NullInt64{31, true}, 1, NullInt64{31, true}},
1901 {NullInt64{-22, false}, 1, NullInt64{0, false}},
1902 {22, 1, NullInt64{22, true}},
1903 {NullInt64{33, true}, 1, NullInt64{33, true}},
1904 {NullInt64{222, false}, 1, NullInt64{0, false}},
1905 {0, NullInt64{31, false}, nil},
1906 }}
1907 synctest.Test(t, func(t *testing.T) {
1908 nullTestRun(t, spec)
1909 })
1910 }
1911
1912 func TestNullInt32Param(t *testing.T) {
1913 spec := nullTestSpec{"nullint32", "int32", [6]nullTestRow{
1914 {NullInt32{31, true}, 1, NullInt32{31, true}},
1915 {NullInt32{-22, false}, 1, NullInt32{0, false}},
1916 {22, 1, NullInt32{22, true}},
1917 {NullInt32{33, true}, 1, NullInt32{33, true}},
1918 {NullInt32{222, false}, 1, NullInt32{0, false}},
1919 {0, NullInt32{31, false}, nil},
1920 }}
1921 synctest.Test(t, func(t *testing.T) {
1922 nullTestRun(t, spec)
1923 })
1924 }
1925
1926 func TestNullInt16Param(t *testing.T) {
1927 spec := nullTestSpec{"nullint16", "int16", [6]nullTestRow{
1928 {NullInt16{31, true}, 1, NullInt16{31, true}},
1929 {NullInt16{-22, false}, 1, NullInt16{0, false}},
1930 {22, 1, NullInt16{22, true}},
1931 {NullInt16{33, true}, 1, NullInt16{33, true}},
1932 {NullInt16{222, false}, 1, NullInt16{0, false}},
1933 {0, NullInt16{31, false}, nil},
1934 }}
1935 synctest.Test(t, func(t *testing.T) {
1936 nullTestRun(t, spec)
1937 })
1938 }
1939
1940 func TestNullByteParam(t *testing.T) {
1941 spec := nullTestSpec{"nullbyte", "byte", [6]nullTestRow{
1942 {NullByte{31, true}, 1, NullByte{31, true}},
1943 {NullByte{0, false}, 1, NullByte{0, false}},
1944 {22, 1, NullByte{22, true}},
1945 {NullByte{33, true}, 1, NullByte{33, true}},
1946 {NullByte{222, false}, 1, NullByte{0, false}},
1947 {0, NullByte{31, false}, nil},
1948 }}
1949 synctest.Test(t, func(t *testing.T) {
1950 nullTestRun(t, spec)
1951 })
1952 }
1953
1954 func TestNullFloat64Param(t *testing.T) {
1955 spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
1956 {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
1957 {NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
1958 {-22.9, 1, NullFloat64{-22.9, true}},
1959 {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
1960 {NullFloat64{222, false}, 1, NullFloat64{0, false}},
1961 {10, NullFloat64{31.2, false}, nil},
1962 }}
1963 synctest.Test(t, func(t *testing.T) {
1964 nullTestRun(t, spec)
1965 })
1966 }
1967
1968 func TestNullBoolParam(t *testing.T) {
1969 spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
1970 {NullBool{false, true}, true, NullBool{false, true}},
1971 {NullBool{true, false}, false, NullBool{false, false}},
1972 {true, true, NullBool{true, true}},
1973 {NullBool{true, true}, false, NullBool{true, true}},
1974 {NullBool{true, false}, true, NullBool{false, false}},
1975 {true, NullBool{true, false}, nil},
1976 }}
1977 synctest.Test(t, func(t *testing.T) {
1978 nullTestRun(t, spec)
1979 })
1980 }
1981
1982 func TestNullTimeParam(t *testing.T) {
1983 t0 := time.Time{}
1984 t1 := time.Date(2000, 1, 1, 8, 9, 10, 11, time.UTC)
1985 t2 := time.Date(2010, 1, 1, 8, 9, 10, 11, time.UTC)
1986 spec := nullTestSpec{"nulldatetime", "datetime", [6]nullTestRow{
1987 {NullTime{t1, true}, t2, NullTime{t1, true}},
1988 {NullTime{t1, false}, t2, NullTime{t0, false}},
1989 {t1, t2, NullTime{t1, true}},
1990 {NullTime{t1, true}, t2, NullTime{t1, true}},
1991 {NullTime{t1, false}, t2, NullTime{t0, false}},
1992 {t2, NullTime{t1, false}, nil},
1993 }}
1994 synctest.Test(t, func(t *testing.T) {
1995 nullTestRun(t, spec)
1996 })
1997 }
1998
1999 func nullTestRun(t *testing.T, spec nullTestSpec) {
2000 db := newTestDB(t, "")
2001 defer closeDB(t, db)
2002 exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType))
2003
2004
2005 exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam)
2006 exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam)
2007
2008
2009 stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?")
2010 if err != nil {
2011 t.Fatalf("prepare: %v", err)
2012 }
2013 defer stmt.Close()
2014 if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
2015 t.Errorf("exec insert chris: %v", err)
2016 }
2017 if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil {
2018 t.Errorf("exec insert dave: %v", err)
2019 }
2020 if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil {
2021 t.Errorf("exec insert eleanor: %v", err)
2022 }
2023
2024
2025 row5 := spec.rows[5]
2026 if _, err := stmt.Exec(6, "bob", row5.nullParam, row5.notNullParam); err == nil {
2027 t.Errorf("expected error inserting nil val with prepared statement Exec: NULL=%#v, NOT-NULL=%#v", row5.nullParam, row5.notNullParam)
2028 }
2029
2030 _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil)
2031 if err == nil {
2032
2033
2034
2035
2036
2037
2038 }
2039
2040 paramtype := reflect.TypeOf(spec.rows[0].nullParam)
2041 bindVal := reflect.New(paramtype).Interface()
2042
2043 for i := 0; i < 5; i++ {
2044 id := i + 1
2045 if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil {
2046 t.Errorf("id=%d Scan: %v", id, err)
2047 }
2048 bindValDeref := reflect.ValueOf(bindVal).Elem().Interface()
2049 if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) {
2050 t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal)
2051 }
2052 }
2053 }
2054
2055
2056 func TestQueryRowNilScanDest(t *testing.T) {
2057 synctest.Test(t, testQueryRowNilScanDest)
2058 }
2059 func testQueryRowNilScanDest(t *testing.T) {
2060 db := newTestDB(t, "people")
2061 defer closeDB(t, db)
2062 var name *string
2063 err := db.QueryRow("SELECT|people|name|").Scan(name)
2064 want := `sql: Scan error on column index 0, name "name": destination pointer is nil`
2065 if err == nil || err.Error() != want {
2066 t.Errorf("error = %q; want %q", err.Error(), want)
2067 }
2068 }
2069
2070 func TestIssue4902(t *testing.T) {
2071 synctest.Test(t, testIssue4902)
2072 }
2073 func testIssue4902(t *testing.T) {
2074 db := newTestDB(t, "people")
2075 defer closeDB(t, db)
2076
2077 driver := db.Driver().(*fakeDriver)
2078 opens0 := driver.openCount
2079
2080 var stmt *Stmt
2081 var err error
2082 for i := 0; i < 10; i++ {
2083 stmt, err = db.Prepare("SELECT|people|name|")
2084 if err != nil {
2085 t.Fatal(err)
2086 }
2087 err = stmt.Close()
2088 if err != nil {
2089 t.Fatal(err)
2090 }
2091 }
2092
2093 opens := driver.openCount - opens0
2094 if opens > 1 {
2095 t.Errorf("opens = %d; want <= 1", opens)
2096 t.Logf("db = %#v", db)
2097 t.Logf("driver = %#v", driver)
2098 t.Logf("stmt = %#v", stmt)
2099 }
2100 }
2101
2102
2103
2104 func TestSimultaneousQueries(t *testing.T) {
2105 synctest.Test(t, testSimultaneousQueries)
2106 }
2107 func testSimultaneousQueries(t *testing.T) {
2108 db := newTestDB(t, "people")
2109 defer closeDB(t, db)
2110
2111 tx, err := db.Begin()
2112 if err != nil {
2113 t.Fatal(err)
2114 }
2115 defer tx.Rollback()
2116
2117 r1, err := tx.Query("SELECT|people|name|")
2118 if err != nil {
2119 t.Fatal(err)
2120 }
2121 defer r1.Close()
2122
2123 r2, err := tx.Query("SELECT|people|name|")
2124 if err != nil {
2125 t.Fatal(err)
2126 }
2127 defer r2.Close()
2128 }
2129
2130 func TestMaxIdleConns(t *testing.T) {
2131 synctest.Test(t, testMaxIdleConns)
2132 }
2133 func testMaxIdleConns(t *testing.T) {
2134 db := newTestDB(t, "people")
2135 defer closeDB(t, db)
2136
2137 tx, err := db.Begin()
2138 if err != nil {
2139 t.Fatal(err)
2140 }
2141 tx.Commit()
2142 if got := len(db.freeConn); got != 1 {
2143 t.Errorf("freeConns = %d; want 1", got)
2144 }
2145
2146 db.SetMaxIdleConns(0)
2147
2148 if got := len(db.freeConn); got != 0 {
2149 t.Errorf("freeConns after set to zero = %d; want 0", got)
2150 }
2151
2152 tx, err = db.Begin()
2153 if err != nil {
2154 t.Fatal(err)
2155 }
2156 tx.Commit()
2157 if got := len(db.freeConn); got != 0 {
2158 t.Errorf("freeConns = %d; want 0", got)
2159 }
2160 }
2161
2162 func TestMaxOpenConns(t *testing.T) {
2163 synctest.Test(t, testMaxOpenConns)
2164 }
2165 func testMaxOpenConns(t *testing.T) {
2166 if testing.Short() {
2167 t.Skip("skipping in short mode")
2168 }
2169 defer setHookpostCloseConn(nil)
2170 setHookpostCloseConn(func(_ *fakeConn, err error) {
2171 if err != nil {
2172 t.Errorf("Error closing fakeConn: %v", err)
2173 }
2174 })
2175
2176 db := newTestDB(t, "magicquery")
2177 defer closeDB(t, db)
2178
2179 driver := db.Driver().(*fakeDriver)
2180
2181
2182
2183 db.clearAllConns(t)
2184
2185 driver.mu.Lock()
2186 opens0 := driver.openCount
2187 closes0 := driver.closeCount
2188 driver.mu.Unlock()
2189
2190 db.SetMaxIdleConns(10)
2191 db.SetMaxOpenConns(10)
2192
2193 stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
2194 if err != nil {
2195 t.Fatal(err)
2196 }
2197
2198
2199 const (
2200 nquery = 50
2201 sleepMillis = 25
2202 nbatch = 2
2203 )
2204 var wg sync.WaitGroup
2205 for batch := 0; batch < nbatch; batch++ {
2206 for i := 0; i < nquery; i++ {
2207 wg.Add(1)
2208 go func() {
2209 defer wg.Done()
2210 var op string
2211 if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
2212 t.Error(err)
2213 }
2214 }()
2215 }
2216
2217 wg.Wait()
2218 }
2219
2220 if g, w := db.numFreeConns(), 10; g != w {
2221 t.Errorf("free conns = %d; want %d", g, w)
2222 }
2223
2224 if n := db.numDeps(); n > 20 {
2225 t.Errorf("number of dependencies = %d; expected <= 20", n)
2226 db.dumpDeps(t)
2227 }
2228
2229 driver.mu.Lock()
2230 opens := driver.openCount - opens0
2231 closes := driver.closeCount - closes0
2232 driver.mu.Unlock()
2233
2234 if opens > 10 {
2235 t.Logf("open calls = %d", opens)
2236 t.Logf("close calls = %d", closes)
2237 t.Errorf("db connections opened = %d; want <= 10", opens)
2238 db.dumpDeps(t)
2239 }
2240
2241 if err := stmt.Close(); err != nil {
2242 t.Fatal(err)
2243 }
2244
2245 if g, w := db.numFreeConns(), 10; g != w {
2246 t.Errorf("free conns = %d; want %d", g, w)
2247 }
2248
2249 if n := db.numDeps(); n > 10 {
2250 t.Errorf("number of dependencies = %d; expected <= 10", n)
2251 db.dumpDeps(t)
2252 }
2253
2254 db.SetMaxOpenConns(5)
2255
2256 if g, w := db.numFreeConns(), 5; g != w {
2257 t.Errorf("free conns = %d; want %d", g, w)
2258 }
2259
2260 if n := db.numDeps(); n > 5 {
2261 t.Errorf("number of dependencies = %d; expected 0", n)
2262 db.dumpDeps(t)
2263 }
2264
2265 db.SetMaxOpenConns(0)
2266
2267 if g, w := db.numFreeConns(), 5; g != w {
2268 t.Errorf("free conns = %d; want %d", g, w)
2269 }
2270
2271 if n := db.numDeps(); n > 5 {
2272 t.Errorf("number of dependencies = %d; expected 0", n)
2273 db.dumpDeps(t)
2274 }
2275
2276 db.clearAllConns(t)
2277 }
2278
2279
2280
2281 func TestMaxOpenConnsOnBusy(t *testing.T) {
2282 synctest.Test(t, testMaxOpenConnsOnBusy)
2283 }
2284 func testMaxOpenConnsOnBusy(t *testing.T) {
2285 defer setHookpostCloseConn(nil)
2286 setHookpostCloseConn(func(_ *fakeConn, err error) {
2287 if err != nil {
2288 t.Errorf("Error closing fakeConn: %v", err)
2289 }
2290 })
2291
2292 db := newTestDB(t, "magicquery")
2293 defer closeDB(t, db)
2294
2295 db.SetMaxOpenConns(3)
2296
2297 ctx := context.Background()
2298
2299 conn0, err := db.conn(ctx, cachedOrNewConn)
2300 if err != nil {
2301 t.Fatalf("db open conn fail: %v", err)
2302 }
2303
2304 conn1, err := db.conn(ctx, cachedOrNewConn)
2305 if err != nil {
2306 t.Fatalf("db open conn fail: %v", err)
2307 }
2308
2309 conn2, err := db.conn(ctx, cachedOrNewConn)
2310 if err != nil {
2311 t.Fatalf("db open conn fail: %v", err)
2312 }
2313
2314 if g, w := db.numOpen, 3; g != w {
2315 t.Errorf("free conns = %d; want %d", g, w)
2316 }
2317
2318 db.SetMaxOpenConns(2)
2319 if g, w := db.numOpen, 3; g != w {
2320 t.Errorf("free conns = %d; want %d", g, w)
2321 }
2322
2323 conn0.releaseConn(nil)
2324 conn1.releaseConn(nil)
2325 if g, w := db.numOpen, 2; g != w {
2326 t.Errorf("free conns = %d; want %d", g, w)
2327 }
2328
2329 conn2.releaseConn(nil)
2330 if g, w := db.numOpen, 2; g != w {
2331 t.Errorf("free conns = %d; want %d", g, w)
2332 }
2333 }
2334
2335
2336
2337 func TestPendingConnsAfterErr(t *testing.T) {
2338 synctest.Test(t, testPendingConnsAfterErr)
2339 }
// testPendingConnsAfterErr checks that when opening a driver connection fails,
// every goroutine waiting for a connection receives the open error and none
// is left waiting. The setup is:
//   - tryOpen (6) goroutines each call db.Exec;
//   - the pool allows at most maxOpen (2) open and 0 idle connections;
//   - the open hook fails every attempt with errOffline.
// Afterwards the test waits, within the same timeout, for numOpen to drop
// back to zero.
func testPendingConnsAfterErr(t *testing.T) {
	const (
		maxOpen = 2
		tryOpen = maxOpen*2 + 2
	)

	// Open the DB directly rather than through newTestDB.
	db, err := Open("test", fakeDBName)
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer closeDB(t, db)
	// On exit, log the contents of db.lastPut for debugging.
	defer func() {
		for k, v := range db.lastPut {
			t.Logf("%p: %v", k, v)
		}
	}()

	db.SetMaxOpenConns(maxOpen)
	db.SetMaxIdleConns(0)

	errOffline := errors.New("db offline")

	defer func() { setHookOpenErr(nil) }()

	// Buffered so no goroutine blocks when reporting its result.
	errs := make(chan error, tryOpen)

	// opening is released once every worker goroutine has called Done.
	var opening sync.WaitGroup
	opening.Add(tryOpen)

	setHookOpenErr(func() error {
		// Hold each connection attempt until every goroutine has started,
		// then fail it.
		opening.Wait()
		return errOffline
	})

	for i := 0; i < tryOpen; i++ {
		go func() {
			opening.Done()
			_, err := db.Exec("will never run")
			errs <- err
		}()
	}

	opening.Wait()

	const timeout = 5 * time.Second
	to := time.NewTimer(timeout)
	defer to.Stop()

	// Every goroutine must come back with errOffline before the timer fires.
	for i := 0; i < tryOpen; i++ {
		select {
		case err := <-errs:
			if got, want := err, errOffline; got != want {
				t.Errorf("unexpected err: got %v, want %v", got, want)
			}
		case <-to.C:
			t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
		}
	}

	// Poll every 3ms until numOpen (read under db.mu) drops to zero. This reuses
	// timer to, so it only has what is left of the timeout budget.
	tick := time.NewTicker(3 * time.Millisecond)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			db.mu.Lock()
			if db.numOpen == 0 {
				db.mu.Unlock()
				return
			}
			db.mu.Unlock()
		case <-to.C:
			// Timing out here does not fail the test directly.
			// NOTE(review): presumably the deferred closeDB reports any
			// connections still open — confirm.
			return
		}
	}
}
2420
2421 func TestSingleOpenConn(t *testing.T) {
2422 synctest.Test(t, testSingleOpenConn)
2423 }
2424 func testSingleOpenConn(t *testing.T) {
2425 db := newTestDB(t, "people")
2426 defer closeDB(t, db)
2427
2428 db.SetMaxOpenConns(1)
2429
2430 rows, err := db.Query("SELECT|people|name|")
2431 if err != nil {
2432 t.Fatal(err)
2433 }
2434 if err = rows.Close(); err != nil {
2435 t.Fatal(err)
2436 }
2437
2438 rows, err = db.Query("SELECT|people|name|")
2439 if err != nil {
2440 t.Fatal(err)
2441 }
2442 if err = rows.Close(); err != nil {
2443 t.Fatal(err)
2444 }
2445 }
2446
2447 func TestStats(t *testing.T) {
2448 synctest.Test(t, testStats)
2449 }
2450 func testStats(t *testing.T) {
2451 db := newTestDB(t, "people")
2452 stats := db.Stats()
2453 if got := stats.OpenConnections; got != 1 {
2454 t.Errorf("stats.OpenConnections = %d; want 1", got)
2455 }
2456
2457 tx, err := db.Begin()
2458 if err != nil {
2459 t.Fatal(err)
2460 }
2461 tx.Commit()
2462
2463 closeDB(t, db)
2464 stats = db.Stats()
2465 if got := stats.OpenConnections; got != 0 {
2466 t.Errorf("stats.OpenConnections = %d; want 0", got)
2467 }
2468 }
2469
2470 func TestConnMaxLifetime(t *testing.T) {
2471 synctest.Test(t, testConnMaxLifetime)
2472 }
2473 func testConnMaxLifetime(t *testing.T) {
2474 db := newTestDB(t, "magicquery")
2475 defer closeDB(t, db)
2476
2477 driver := db.Driver().(*fakeDriver)
2478
2479
2480
2481 db.clearAllConns(t)
2482
2483 driver.mu.Lock()
2484 opens0 := driver.openCount
2485 closes0 := driver.closeCount
2486 driver.mu.Unlock()
2487
2488 db.SetMaxIdleConns(10)
2489 db.SetMaxOpenConns(10)
2490
2491 tx, err := db.Begin()
2492 if err != nil {
2493 t.Fatal(err)
2494 }
2495
2496 synctest.Sleep(1 * time.Second)
2497 tx2, err := db.Begin()
2498 if err != nil {
2499 t.Fatal(err)
2500 }
2501
2502 tx.Commit()
2503 tx2.Commit()
2504
2505 driver.mu.Lock()
2506 opens := driver.openCount - opens0
2507 closes := driver.closeCount - closes0
2508 driver.mu.Unlock()
2509
2510 if opens != 2 {
2511 t.Errorf("opens = %d; want 2", opens)
2512 }
2513 if closes != 0 {
2514 t.Errorf("closes = %d; want 0", closes)
2515 }
2516 if g, w := db.numFreeConns(), 2; g != w {
2517 t.Errorf("free conns = %d; want %d", g, w)
2518 }
2519
2520
2521 synctest.Sleep(10 * time.Second)
2522 db.SetConnMaxLifetime(10 * time.Second)
2523
2524 tx, err = db.Begin()
2525 if err != nil {
2526 t.Fatal(err)
2527 }
2528 tx2, err = db.Begin()
2529 if err != nil {
2530 t.Fatal(err)
2531 }
2532 tx.Commit()
2533 tx2.Commit()
2534
2535
2536 synctest.Wait()
2537 driver.mu.Lock()
2538 opens = driver.openCount - opens0
2539 closes = driver.closeCount - closes0
2540 driver.mu.Unlock()
2541
2542 if opens != 3 {
2543 t.Errorf("opens = %d; want 3", opens)
2544 }
2545 if closes != 1 {
2546 t.Errorf("closes = %d; want 1", closes)
2547 }
2548
2549 if s := db.Stats(); s.MaxLifetimeClosed != 1 {
2550 t.Errorf("MaxLifetimeClosed = %d; want 1 %#v", s.MaxLifetimeClosed, s)
2551 }
2552 }
2553
2554
2555 func TestStmtCloseDeps(t *testing.T) {
2556 synctest.Test(t, testStmtCloseDeps)
2557 }
// testStmtCloseDeps runs nbatch batches of nquery concurrent QueryRow calls on
// one prepared statement, using the pool's existing limits. It then checks:
//   - the free-connection count;
//   - the dependency count;
//   - the net number of driver connections opened;
//   - the size of stmt.css.
// It checks free conns and dependencies again after stmt.Close. Every
// fakeConn close must succeed.
func testStmtCloseDeps(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v", err)
		}
	})

	db := newTestDB(t, "magicquery")
	defer closeDB(t, db)

	driver := db.Driver().(*fakeDriver)

	// Snapshot the driver counters; openDelta0 is the number of connections
	// open (opened minus closed) before the test's own work.
	driver.mu.Lock()
	opens0 := driver.openCount
	closes0 := driver.closeCount
	driver.mu.Unlock()
	openDelta0 := opens0 - closes0

	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): the "sleep" op presumably makes the fake driver hold the
	// connection for sleepMillis — confirm in fakedb.
	const (
		nquery      = 50
		sleepMillis = 25
		nbatch      = 2
	)
	var wg sync.WaitGroup
	for batch := 0; batch < nbatch; batch++ {
		for i := 0; i < nquery; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				var op string
				if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
					t.Error(err)
				}
			}()
		}
		// Finish each batch before starting the next.
		wg.Wait()
	}

	if g, w := db.numFreeConns(), 2; g != w {
		t.Errorf("free conns = %d; want %d", g, w)
	}

	if n := db.numDeps(); n > 4 {
		t.Errorf("number of dependencies = %d; expected <= 4", n)
		db.dumpDeps(t)
	}

	driver.mu.Lock()
	opens := driver.openCount - opens0
	closes := driver.closeCount - closes0
	openDelta := (driver.openCount - driver.closeCount) - openDelta0
	driver.mu.Unlock()

	// Judge on the net change in open connections, not raw open calls.
	if openDelta > 2 {
		t.Logf("open calls = %d", opens)
		t.Logf("close calls = %d", closes)
		t.Logf("open delta = %d", openDelta)
		t.Errorf("db connections opened = %d; want <= 2", openDelta)
		db.dumpDeps(t)
	}

	// Wait for the bubble to go idle before inspecting stmt.css.
	synctest.Wait()
	if len(stmt.css) > nquery {
		t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
	}

	if err := stmt.Close(); err != nil {
		t.Fatal(err)
	}

	if g, w := db.numFreeConns(), 2; g != w {
		t.Errorf("free conns = %d; want %d", g, w)
	}

	if n := db.numDeps(); n > 2 {
		t.Errorf("number of dependencies = %d; expected <= 2", n)
		db.dumpDeps(t)
	}

	db.clearAllConns(t)
}
2650
2651
// TestCloseConnBeforeStmts runs testCloseConnBeforeStmts inside a synctest bubble.
func TestCloseConnBeforeStmts(t *testing.T) {
	synctest.Test(t, testCloseConnBeforeStmts)
}

// testCloseConnBeforeStmts closes the DB while a prepared statement still
// exists on its only idle connection, and only then closes the statement.
// It checks that:
//   - db.Close closes the connection and empties its openStmt list;
//   - closing the orphaned statement afterwards returns no error;
//   - the connection's driver Conn (dc.ci) is nil at the end.
func testCloseConnBeforeStmts(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// Any error from closing a fake connection fails the test. The stack and
	// a dependency dump are included to help locate the offending close.
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v; from %s", err, stack())
			db.dumpDeps(t)
			t.Errorf("DB = %#v", db)
		}
	})

	stmt, err := db.Prepare("SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}

	// After Prepare there should be exactly one idle connection, and it
	// should hold the one statement just prepared.
	if len(db.freeConn) != 1 {
		t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
	}
	dc := db.freeConn[0]
	if dc.closed {
		t.Errorf("conn shouldn't be closed")
	}

	if n := len(dc.openStmt); n != 1 {
		t.Errorf("driverConn num openStmt = %d; want 1", n)
	}
	err = db.Close()
	if err != nil {
		t.Errorf("db Close = %v", err)
	}
	if !dc.closed {
		t.Errorf("after db.Close, driverConn should be closed")
	}
	if n := len(dc.openStmt); n != 0 {
		t.Errorf("driverConn num openStmt = %d; want 0", n)
	}

	// Closing the statement after its connection is gone must still succeed.
	err = stmt.Close()
	if err != nil {
		t.Errorf("Stmt close = %v", err)
	}

	if !dc.closed {
		t.Errorf("conn should be closed")
	}
	if dc.ci != nil {
		t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
	}
}
2707
2708
2709
2710 func TestRowsCloseOrder(t *testing.T) {
2711 synctest.Test(t, testRowsCloseOrder)
2712 }
2713 func testRowsCloseOrder(t *testing.T) {
2714 db := newTestDB(t, "people")
2715 defer closeDB(t, db)
2716
2717 db.SetMaxIdleConns(0)
2718 setStrictFakeConnClose(t)
2719 defer setStrictFakeConnClose(nil)
2720
2721 rows, err := db.Query("SELECT|people|age,name|")
2722 if err != nil {
2723 t.Fatal(err)
2724 }
2725 err = rows.Close()
2726 if err != nil {
2727 t.Fatal(err)
2728 }
2729 }
2730
2731 func TestRowsImplicitClose(t *testing.T) {
2732 synctest.Test(t, testRowsImplicitClose)
2733 }
2734 func testRowsImplicitClose(t *testing.T) {
2735 db := newTestDB(t, "people")
2736 defer closeDB(t, db)
2737
2738 rows, err := db.Query("SELECT|people|age,name|")
2739 if err != nil {
2740 t.Fatal(err)
2741 }
2742
2743 want, fail := 2, errors.New("fail")
2744 r := rows.rowsi.(*rowsCursor)
2745 r.errPos, r.err = want, fail
2746
2747 got := 0
2748 for rows.Next() {
2749 got++
2750 }
2751 if got != want {
2752 t.Errorf("got %d rows, want %d", got, want)
2753 }
2754 if err := rows.Err(); err != fail {
2755 t.Errorf("got error %v, want %v", err, fail)
2756 }
2757 if !r.closed {
2758 t.Errorf("r.closed is false, want true")
2759 }
2760 }
2761
2762 func TestRowsCloseError(t *testing.T) {
2763 synctest.Test(t, testRowsCloseError)
2764 }
2765 func testRowsCloseError(t *testing.T) {
2766 db := newTestDB(t, "people")
2767 defer db.Close()
2768 rows, err := db.Query("SELECT|people|age,name|")
2769 if err != nil {
2770 t.Fatalf("Query: %v", err)
2771 }
2772 type row struct {
2773 age int
2774 name string
2775 }
2776 got := []row{}
2777
2778 rc, ok := rows.rowsi.(*rowsCursor)
2779 if !ok {
2780 t.Fatal("not using *rowsCursor")
2781 }
2782 rc.closeErr = errors.New("rowsCursor: failed to close")
2783
2784 for rows.Next() {
2785 var r row
2786 err = rows.Scan(&r.age, &r.name)
2787 if err != nil {
2788 t.Fatalf("Scan: %v", err)
2789 }
2790 got = append(got, r)
2791 }
2792 err = rows.Err()
2793 if err != rc.closeErr {
2794 t.Fatalf("unexpected err: got %v, want %v", err, rc.closeErr)
2795 }
2796 }
2797
2798 func TestStmtCloseOrder(t *testing.T) {
2799 synctest.Test(t, testStmtCloseOrder)
2800 }
2801 func testStmtCloseOrder(t *testing.T) {
2802 db := newTestDB(t, "people")
2803 defer closeDB(t, db)
2804
2805 db.SetMaxIdleConns(0)
2806 setStrictFakeConnClose(t)
2807 defer setStrictFakeConnClose(nil)
2808
2809 _, err := db.Query("SELECT|non_existent|name|")
2810 if err == nil {
2811 t.Fatal("Querying non-existent table should fail")
2812 }
2813 }
2814
2815
2816
// TestManyErrBadConn runs testManyErrBadConn inside a synctest bubble.
func TestManyErrBadConn(t *testing.T) {
	synctest.Test(t, testManyErrBadConn)
}

// testManyErrBadConn fills the pool with maxBadConnRetries+1 idle
// connections, all flagged stickyBad. It then checks that each public entry
// point still succeeds: Query, Exec, Begin, Prepare, Stmt.Exec, Stmt.Query,
// Conn and PingContext. Every scenario uses a freshly set-up DB.
func testManyErrBadConn(t *testing.T) {
	// manyErrBadConnSetup does the following:
	//   1. opens a "people" DB and runs each function in first against it;
	//   2. holds nconn result sets open at once so the pool ends up with
	//      nconn open, idle connections;
	//   3. sets stickyBad on every one of them.
	// NOTE(review): stickyBad presumably makes every later use of that
	// connection report a bad connection; confirm in the fake driver.
	manyErrBadConnSetup := func(first ...func(db *DB)) *DB {
		db := newTestDB(t, "people")

		for _, f := range first {
			f(db)
		}

		nconn := maxBadConnRetries + 1
		db.SetMaxIdleConns(nconn)
		db.SetMaxOpenConns(nconn)

		// The deferred Closes run when this closure returns, so all nconn
		// result sets are open at the same time and each needs its own
		// connection.
		func() {
			for i := 0; i < nconn; i++ {
				rows, err := db.Query("SELECT|people|age,name|")
				if err != nil {
					t.Fatal(err)
				}
				defer rows.Close()
			}
		}()

		db.mu.Lock()
		defer db.mu.Unlock()
		if db.numOpen != nconn {
			t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
		} else if len(db.freeConn) != nconn {
			t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
		}
		for _, conn := range db.freeConn {
			conn.Lock()
			conn.ci.(*fakeConn).stickyBad = true
			conn.Unlock()
		}
		return db
	}

	// Query
	db := manyErrBadConnSetup()
	defer closeDB(t, db)
	rows, err := db.Query("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}

	// Exec
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	_, err = db.Exec("INSERT|people|name=Julia,age=19")
	if err != nil {
		t.Fatal(err)
	}

	// Begin
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	if err = tx.Rollback(); err != nil {
		t.Fatal(err)
	}

	// Prepare
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	stmt, err := db.Prepare("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Exec: the statement is prepared before the pool's connections
	// are flagged stickyBad.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("INSERT|people|name=Julia,age=19")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	_, err = stmt.Exec()
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Query: as above, prepared before the connections are flagged.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("SELECT|people|age,name|")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	rows, err = stmt.Query()
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Conn
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	ctx := t.Context()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	// NOTE(review): presumably stops the fake driver from flagging the
	// session as dirty on close; confirm in the fake driver.
	conn.dc.ci.(*fakeConn).skipDirtySession = true
	err = conn.Close()
	if err != nil {
		t.Fatal(err)
	}

	// Ping
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	err = db.PingContext(ctx)
	if err != nil {
		t.Fatal(err)
	}
}
2955
2956
// TestTxCannotCommitAfterRollback runs testTxCannotCommitAfterRollback inside
// a synctest bubble.
func TestTxCannotCommitAfterRollback(t *testing.T) {
	synctest.Test(t, testTxCannotCommitAfterRollback)
}

// testTxCannotCommitAfterRollback races a query on a transaction against a
// rollback triggered by cancelling the transaction's context. The query must
// still run inside the transaction: tx_status must read "transaction", not
// "autocommit".
func testTxCannotCommitAfterRollback(t *testing.T) {
	db := newTestDB(t, "tx_status")
	defer closeDB(t, db)

	// Outside any transaction the fake table reports "autocommit".
	var txStatus string
	err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "autocommit"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): presumably keeps the deliberate race below from marking
	// the fake connection's session dirty; confirm in the fake driver.
	tx.txi.(*fakeTx).c.skipDirtySession = true

	defer tx.Rollback()

	// Inside the transaction the fake table reports "transaction".
	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	// Force the interleaving:
	//   1. When the next query grabs the tx's connection, hookTxGrabConn
	//      cancels the tx context.
	//   2. The hook then blocks until rollbackHook closes sendQuery, i.e.
	//      until the rollback triggered by that cancellation reaches
	//      rollbackHook.
	sendQuery := make(chan struct{})

	// NOTE(review): bypassRowsAwaitDone presumably disables the Rows'
	// context watcher for this query; confirm.
	bypassRowsAwaitDone = true
	hookTxGrabConn = func() {
		cancel()
		<-sendQuery
	}
	rollbackHook = func() {
		close(sendQuery)
	}
	defer func() {
		hookTxGrabConn = nil
		rollbackHook = nil
		bypassRowsAwaitDone = false
	}()

	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		// NOTE(review): a failure here may relate to the skipDirtySession
		// setting above; confirm.
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}
}
3026
3027
// TestTxStmtDeadlock runs testTxStmtDeadlock inside a synctest bubble.
func TestTxStmtDeadlock(t *testing.T) {
	synctest.Test(t, testTxStmtDeadlock)
}

// testTxStmtDeadlock prepares a statement on a transaction, cancels the
// transaction's context, and then queries through the statement up to 1000
// times, stopping at the first error. The test has no assertions; per its
// name it guards against these calls deadlocking.
func testTxStmtDeadlock(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	stmt, err := tx.Prepare("SELECT|people|name,age|age=?")
	if err != nil {
		t.Fatal(err)
	}
	cancel()

	for i := 0; i < 1e3; i++ {
		// Errors are expected once the cancellation takes effect. The
		// returned Rows are intentionally not closed.
		_, err = stmt.Query(1)
		if err != nil {
			break
		}
	}
	_ = tx.Rollback()
}
3058
3059 func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
3060 t.Run(name, func(t *testing.T) {
3061 synctest.Test(t, f)
3062 })
3063 }
3064
3065
3066
3067
3068 func TestConnExpiresFreshOutOfPool(t *testing.T) {
3069 execCases := []struct {
3070 expired bool
3071 badReset bool
3072 }{
3073 {false, false},
3074 {true, false},
3075 {false, true},
3076 }
3077
3078 for _, ec := range execCases {
3079 name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset)
3080 synctestSubtest(t, name, func(t *testing.T) {
3081 ctx := t.Context()
3082
3083 db := newTestDB(t, "magicquery")
3084
3085 db.SetMaxOpenConns(1)
3086
3087 db.clearAllConns(t)
3088
3089 db.SetMaxIdleConns(1)
3090 db.SetConnMaxLifetime(10 * time.Second)
3091
3092 conn, err := db.conn(ctx, alwaysNewConn)
3093 if err != nil {
3094 t.Fatal(err)
3095 }
3096
3097 afterPutConn := make(chan struct{})
3098
3099 go func() {
3100 defer close(afterPutConn)
3101
3102 conn, err := db.conn(ctx, alwaysNewConn)
3103 if err == nil {
3104 db.putConn(conn, err, false)
3105 } else {
3106 t.Errorf("db.conn: %v", err)
3107 }
3108 }()
3109 synctest.Wait()
3110
3111 if t.Failed() {
3112 return
3113 }
3114
3115 synctest.Sleep(11 * time.Second)
3116
3117 conn.ci.(*fakeConn).stickyBad = ec.badReset
3118
3119 db.putConn(conn, err, true)
3120
3121 <-afterPutConn
3122 })
3123 }
3124 }
3125
3126
3127
3128 func TestIssue20575(t *testing.T) {
3129 synctest.Test(t, testIssue20575)
3130 }
3131 func testIssue20575(t *testing.T) {
3132 db := newTestDB(t, "people")
3133 defer closeDB(t, db)
3134
3135 tx, err := db.Begin()
3136 if err != nil {
3137 t.Fatal(err)
3138 }
3139 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
3140 defer cancel()
3141 _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
3142 if err != nil {
3143 t.Fatal(err)
3144 }
3145
3146 err = tx.Rollback()
3147 if err != nil {
3148 t.Fatal(err)
3149 }
3150 select {
3151 default:
3152 case <-ctx.Done():
3153 t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
3154 }
3155 }
3156
3157
3158
// TestIssue20622 runs testIssue20622 inside a synctest bubble.
func TestIssue20622(t *testing.T) {
	synctest.Test(t, testIssue20622)
}

// testIssue20622 cancels a transaction's context after the first row of a
// query has been read, then keeps iterating. Every Scan of a row returned by
// Next must still succeed.
func testIssue20622(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	rows, err := tx.Query("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}

	count := 0
	for rows.Next() {
		count++
		var age int
		var name string
		if err := rows.Scan(&age, &name); err != nil {
			t.Fatal("scan failed", err)
		}

		if count == 1 {
			cancel()
		}
		// Give the cancellation time to take effect between rows. This is
		// fake time inside the synctest bubble.
		time.Sleep(100 * time.Millisecond)
	}
	// Errors are ignored on purpose: after the cancel, these calls are not
	// expected to succeed.
	rows.Close()
	tx.Commit()
}
3196
3197
3198 func TestErrBadConnReconnect(t *testing.T) {
3199 synctest.Test(t, testErrBadConnReconnect)
3200 }
3201 func testErrBadConnReconnect(t *testing.T) {
3202 db := newTestDB(t, "foo")
3203 defer closeDB(t, db)
3204 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
3205
3206 simulateBadConn := func(name string, hook *func() bool, op func() error) {
3207 broken, retried := false, false
3208 numOpen := db.numOpen
3209
3210
3211 *hook = func() bool {
3212 if !broken {
3213 broken = true
3214 return true
3215 }
3216 retried = true
3217 return false
3218 }
3219
3220 if err := op(); err != nil {
3221 t.Errorf(name+": %v", err)
3222 return
3223 }
3224
3225 if !broken || !retried {
3226 t.Error(name + ": Failed to simulate broken connection")
3227 }
3228 *hook = nil
3229
3230 if numOpen != db.numOpen {
3231 t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
3232 numOpen = db.numOpen
3233 }
3234 }
3235
3236
3237 dbExec := func() error {
3238 _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
3239 return err
3240 }
3241 simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
3242 simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
3243
3244
3245 dbQuery := func() error {
3246 rows, err := db.Query("SELECT|t1|age,name|")
3247 if err == nil {
3248 err = rows.Close()
3249 }
3250 return err
3251 }
3252 simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
3253 simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
3254
3255
3256 simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
3257 stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
3258 if err != nil {
3259 return err
3260 }
3261 stmt.Close()
3262 return nil
3263 })
3264
3265
3266 forcePrepare := func(stmt *Stmt) {
3267 stmt.css = nil
3268 }
3269
3270
3271 stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
3272 if err != nil {
3273 t.Fatalf("prepare: %v", err)
3274 }
3275 defer stmt1.Close()
3276
3277 forcePrepare(stmt1)
3278
3279 stmtExec := func() error {
3280 _, err := stmt1.Exec("Gopher", 3, false)
3281 return err
3282 }
3283 simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
3284 simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
3285
3286
3287 stmt2, err := db.Prepare("SELECT|t1|age,name|")
3288 if err != nil {
3289 t.Fatalf("prepare: %v", err)
3290 }
3291 defer stmt2.Close()
3292
3293 forcePrepare(stmt2)
3294
3295 stmtQuery := func() error {
3296 rows, err := stmt2.Query()
3297 if err == nil {
3298 err = rows.Close()
3299 }
3300 return err
3301 }
3302 simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
3303 simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
3304 }
3305
3306
3307 func TestTxEndBadConn(t *testing.T) {
3308 synctest.Test(t, testTxEndBadConn)
3309 }
3310 func testTxEndBadConn(t *testing.T) {
3311 db := newTestDB(t, "foo")
3312 defer closeDB(t, db)
3313 db.SetMaxIdleConns(0)
3314 exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
3315 db.SetMaxIdleConns(1)
3316
3317 simulateBadConn := func(name string, hook *func() bool, op func() error) {
3318 broken := false
3319 numOpen := db.numOpen
3320
3321 *hook = func() bool {
3322 if !broken {
3323 broken = true
3324 }
3325 return broken
3326 }
3327
3328 if err := op(); !errors.Is(err, driver.ErrBadConn) {
3329 t.Errorf(name+": %v", err)
3330 return
3331 }
3332
3333 if !broken {
3334 t.Error(name + ": Failed to simulate broken connection")
3335 }
3336 *hook = nil
3337
3338 if numOpen != db.numOpen {
3339 t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
3340 }
3341 }
3342
3343
3344 dbExec := func(endTx func(tx *Tx) error) func() error {
3345 return func() error {
3346 tx, err := db.Begin()
3347 if err != nil {
3348 return err
3349 }
3350 _, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
3351 if err != nil {
3352 return err
3353 }
3354 return endTx(tx)
3355 }
3356 }
3357 simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit))
3358 simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback))
3359
3360
3361 dbQuery := func(endTx func(tx *Tx) error) func() error {
3362 return func() error {
3363 tx, err := db.Begin()
3364 if err != nil {
3365 return err
3366 }
3367 rows, err := tx.Query("SELECT|t1|age,name|")
3368 if err == nil {
3369 err = rows.Close()
3370 } else {
3371 return err
3372 }
3373 return endTx(tx)
3374 }
3375 }
3376 simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit))
3377 simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback))
3378 }
3379
// concurrentTest is one workload driven by doConcurrentTest.
//   - init prepares per-run state against db.
//   - test performs a single request. doConcurrentTest calls it from several
//     goroutines at once; a non-nil error means the request failed.
//   - finish releases whatever init set up.
type concurrentTest interface {
	init(t testing.TB, db *DB)
	finish(t testing.TB)
	test(t testing.TB) error
}
3385
3386 type concurrentDBQueryTest struct {
3387 db *DB
3388 }
3389
3390 func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
3391 c.db = db
3392 }
3393
3394 func (c *concurrentDBQueryTest) finish(t testing.TB) {
3395 c.db = nil
3396 }
3397
3398 func (c *concurrentDBQueryTest) test(t testing.TB) error {
3399 rows, err := c.db.Query("SELECT|people|name|")
3400 if err != nil {
3401 t.Error(err)
3402 return err
3403 }
3404 var name string
3405 for rows.Next() {
3406 rows.Scan(&name)
3407 }
3408 rows.Close()
3409 return nil
3410 }
3411
3412 type concurrentDBExecTest struct {
3413 db *DB
3414 }
3415
3416 func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
3417 c.db = db
3418 }
3419
3420 func (c *concurrentDBExecTest) finish(t testing.TB) {
3421 c.db = nil
3422 }
3423
3424 func (c *concurrentDBExecTest) test(t testing.TB) error {
3425 _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
3426 if err != nil {
3427 t.Error(err)
3428 return err
3429 }
3430 return nil
3431 }
3432
3433 type concurrentStmtQueryTest struct {
3434 db *DB
3435 stmt *Stmt
3436 }
3437
3438 func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
3439 c.db = db
3440 var err error
3441 c.stmt, err = db.Prepare("SELECT|people|name|")
3442 if err != nil {
3443 t.Fatal(err)
3444 }
3445 }
3446
3447 func (c *concurrentStmtQueryTest) finish(t testing.TB) {
3448 if c.stmt != nil {
3449 c.stmt.Close()
3450 c.stmt = nil
3451 }
3452 c.db = nil
3453 }
3454
3455 func (c *concurrentStmtQueryTest) test(t testing.TB) error {
3456 rows, err := c.stmt.Query()
3457 if err != nil {
3458 t.Errorf("error on query: %v", err)
3459 return err
3460 }
3461
3462 var name string
3463 for rows.Next() {
3464 rows.Scan(&name)
3465 }
3466 rows.Close()
3467 return nil
3468 }
3469
3470 type concurrentStmtExecTest struct {
3471 db *DB
3472 stmt *Stmt
3473 }
3474
3475 func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
3476 c.db = db
3477 var err error
3478 c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
3479 if err != nil {
3480 t.Fatal(err)
3481 }
3482 }
3483
3484 func (c *concurrentStmtExecTest) finish(t testing.TB) {
3485 if c.stmt != nil {
3486 c.stmt.Close()
3487 c.stmt = nil
3488 }
3489 c.db = nil
3490 }
3491
3492 func (c *concurrentStmtExecTest) test(t testing.TB) error {
3493 _, err := c.stmt.Exec(3, chrisBirthday)
3494 if err != nil {
3495 t.Errorf("error on exec: %v", err)
3496 return err
3497 }
3498 return nil
3499 }
3500
3501 type concurrentTxQueryTest struct {
3502 db *DB
3503 tx *Tx
3504 }
3505
3506 func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
3507 c.db = db
3508 var err error
3509 c.tx, err = c.db.Begin()
3510 if err != nil {
3511 t.Fatal(err)
3512 }
3513 }
3514
3515 func (c *concurrentTxQueryTest) finish(t testing.TB) {
3516 if c.tx != nil {
3517 c.tx.Rollback()
3518 c.tx = nil
3519 }
3520 c.db = nil
3521 }
3522
3523 func (c *concurrentTxQueryTest) test(t testing.TB) error {
3524 rows, err := c.db.Query("SELECT|people|name|")
3525 if err != nil {
3526 t.Error(err)
3527 return err
3528 }
3529 var name string
3530 for rows.Next() {
3531 rows.Scan(&name)
3532 }
3533 rows.Close()
3534 return nil
3535 }
3536
3537 type concurrentTxExecTest struct {
3538 db *DB
3539 tx *Tx
3540 }
3541
3542 func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
3543 c.db = db
3544 var err error
3545 c.tx, err = c.db.Begin()
3546 if err != nil {
3547 t.Fatal(err)
3548 }
3549 }
3550
3551 func (c *concurrentTxExecTest) finish(t testing.TB) {
3552 if c.tx != nil {
3553 c.tx.Rollback()
3554 c.tx = nil
3555 }
3556 c.db = nil
3557 }
3558
3559 func (c *concurrentTxExecTest) test(t testing.TB) error {
3560 _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
3561 if err != nil {
3562 t.Error(err)
3563 return err
3564 }
3565 return nil
3566 }
3567
3568 type concurrentTxStmtQueryTest struct {
3569 db *DB
3570 tx *Tx
3571 stmt *Stmt
3572 }
3573
3574 func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
3575 c.db = db
3576 var err error
3577 c.tx, err = c.db.Begin()
3578 if err != nil {
3579 t.Fatal(err)
3580 }
3581 c.stmt, err = c.tx.Prepare("SELECT|people|name|")
3582 if err != nil {
3583 t.Fatal(err)
3584 }
3585 }
3586
3587 func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
3588 if c.stmt != nil {
3589 c.stmt.Close()
3590 c.stmt = nil
3591 }
3592 if c.tx != nil {
3593 c.tx.Rollback()
3594 c.tx = nil
3595 }
3596 c.db = nil
3597 }
3598
3599 func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
3600 rows, err := c.stmt.Query()
3601 if err != nil {
3602 t.Errorf("error on query: %v", err)
3603 return err
3604 }
3605
3606 var name string
3607 for rows.Next() {
3608 rows.Scan(&name)
3609 }
3610 rows.Close()
3611 return nil
3612 }
3613
3614 type concurrentTxStmtExecTest struct {
3615 db *DB
3616 tx *Tx
3617 stmt *Stmt
3618 }
3619
3620 func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
3621 c.db = db
3622 var err error
3623 c.tx, err = c.db.Begin()
3624 if err != nil {
3625 t.Fatal(err)
3626 }
3627 c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
3628 if err != nil {
3629 t.Fatal(err)
3630 }
3631 }
3632
3633 func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
3634 if c.stmt != nil {
3635 c.stmt.Close()
3636 c.stmt = nil
3637 }
3638 if c.tx != nil {
3639 c.tx.Rollback()
3640 c.tx = nil
3641 }
3642 c.db = nil
3643 }
3644
3645 func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
3646 _, err := c.stmt.Exec(3, chrisBirthday)
3647 if err != nil {
3648 t.Errorf("error on exec: %v", err)
3649 return err
3650 }
3651 return nil
3652 }
3653
3654 type concurrentRandomTest struct {
3655 tests []concurrentTest
3656 }
3657
3658 func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
3659 c.tests = []concurrentTest{
3660 new(concurrentDBQueryTest),
3661 new(concurrentDBExecTest),
3662 new(concurrentStmtQueryTest),
3663 new(concurrentStmtExecTest),
3664 new(concurrentTxQueryTest),
3665 new(concurrentTxExecTest),
3666 new(concurrentTxStmtQueryTest),
3667 new(concurrentTxStmtExecTest),
3668 }
3669 for _, ct := range c.tests {
3670 ct.init(t, db)
3671 }
3672 }
3673
3674 func (c *concurrentRandomTest) finish(t testing.TB) {
3675 for _, ct := range c.tests {
3676 ct.finish(t)
3677 }
3678 }
3679
3680 func (c *concurrentRandomTest) test(t testing.TB) error {
3681 ct := c.tests[rand.Intn(len(c.tests))]
3682 return ct.test(t)
3683 }
3684
3685 func doConcurrentTest(t testing.TB, ct concurrentTest) {
3686 maxProcs, numReqs := 1, 500
3687 if testing.Short() {
3688 maxProcs, numReqs = 4, 50
3689 }
3690 defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
3691
3692 db := newTestDB(t, "people")
3693 defer closeDB(t, db)
3694
3695 ct.init(t, db)
3696 defer ct.finish(t)
3697
3698 var wg sync.WaitGroup
3699 wg.Add(numReqs)
3700
3701 reqs := make(chan bool)
3702 defer close(reqs)
3703
3704 for i := 0; i < maxProcs*2; i++ {
3705 go func() {
3706 for range reqs {
3707 err := ct.test(t)
3708 if err != nil {
3709 wg.Done()
3710 continue
3711 }
3712 wg.Done()
3713 }
3714 }()
3715 }
3716
3717 for i := 0; i < numReqs; i++ {
3718 reqs <- true
3719 }
3720
3721 wg.Wait()
3722 }
3723
// TestIssue6081 runs testIssue6081 inside a synctest bubble.
func TestIssue6081(t *testing.T) {
	synctest.Test(t, testIssue6081)
}

// testIssue6081 forces every Rows.Close to report driver.ErrBadConn while
// querying 10 times through one prepared statement. It checks that:
//   - the statement keeps at most one entry in stmt.css;
//   - stmt.css is empty after Close;
//   - the fake driver saw at least 9 opens and 9 closes during the run.
func testIssue6081(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	drv := db.Driver().(*fakeDriver)
	drv.mu.Lock()
	opens0 := drv.openCount
	closes0 := drv.closeCount
	drv.mu.Unlock()

	stmt, err := db.Prepare("SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	// Make every Rows.Close report a bad connection.
	setRowsCloseHook(func(rows *Rows, err *error) {
		*err = driver.ErrBadConn
	})
	defer setRowsCloseHook(nil)
	for i := 0; i < 10; i++ {
		rows, err := stmt.Query()
		if err != nil {
			t.Fatal(err)
		}
		rows.Close()
	}
	if n := len(stmt.css); n > 1 {
		t.Errorf("len(css slice) = %d; want <= 1", n)
	}
	stmt.Close()
	if n := len(stmt.css); n != 0 {
		t.Errorf("len(css slice) after Close = %d; want 0", n)
	}

	drv.mu.Lock()
	opens := drv.openCount - opens0
	closes := drv.closeCount - closes0
	drv.mu.Unlock()
	if opens < 9 {
		t.Errorf("opens = %d; want >= 9", opens)
	}
	if closes < 9 {
		t.Errorf("closes = %d; want >= 9", closes)
	}
}
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
// TestIssue18429 runs testIssue18429 inside a synctest bubble.
func TestIssue18429(t *testing.T) {
	synctest.Test(t, testIssue18429)
}

// testIssue18429 starts 100 goroutines, at most 20 running at a time
// (bounded by sem). Each goroutine:
//   1. begins a transaction whose context times out after a random
//      0-29ms;
//   2. runs a WAIT query with an independent random 0-29ms delay;
//   3. drains and closes any rows, then rolls back.
//
// Errors are deliberately ignored. Per the issue this reproduces, the test
// presumably relies on the race detector or a panic to surface failures.
func testIssue18429(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx := context.Background()
	sem := make(chan bool, 20)
	var wg sync.WaitGroup

	const milliWait = 30

	for i := 0; i < 100; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer func() {
				<-sem
				wg.Done()
			}()
			qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()

			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
			defer cancel()

			tx, err := db.BeginTx(ctx, nil)
			if err != nil {
				return
			}

			// The query races the context deadline, so it may or may not
			// return rows.
			rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
			if rows != nil {
				var name string
				// Iterate even while cancellation may be in progress.
				for rows.Next() {
					// Read each value so concurrent access to it is exercised.
					rows.Scan(&name)
				}
				rows.Close()
			}

			// May run concurrently with the context-triggered rollback.
			tx.Rollback()
		}()
	}
	wg.Wait()
}
3833
3834
// TestIssue20160 runs testIssue20160 inside a synctest bubble.
func TestIssue20160(t *testing.T) {
	synctest.Test(t, testIssue20160)
}

// testIssue20160 shares one prepared statement across 100 goroutines, at
// most 20 running at a time. Each goroutine queries with a context that
// times out after a random 0-29ms and closes any rows it gets. Query errors
// are deliberately ignored; the test checks that this mix of concurrent use
// and cancellation completes.
func testIssue20160(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx := context.Background()
	sem := make(chan bool, 20)
	var wg sync.WaitGroup

	const milliWait = 30

	stmt, err := db.PrepareContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	for i := 0; i < 100; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer func() {
				<-sem
				wg.Done()
			}()
			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
			defer cancel()

			// The query races its context deadline; either outcome is fine.
			rows, _ := stmt.QueryContext(ctx)
			if rows != nil {
				rows.Close()
			}
		}()
	}
	wg.Wait()
}
3876
3877
3878
3879
3880
3881
// TestIssue18719 runs testIssue18719 inside a synctest bubble.
func TestIssue18719(t *testing.T) {
	synctest.Test(t, testIssue18719)
}

// testIssue18719 cancels a transaction's context from inside
// hookTxGrabConn, while a query on the transaction is in progress. The hook
// then waits for the bubble to settle and requires the transaction to be
// done. The query itself must still return no error.
func testIssue18719(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): hookTxGrabConn presumably runs after the tx has grabbed
	// its connection for the query below; confirm at its call site.
	hookTxGrabConn = func() {
		cancel()
		// Let the cancellation-driven work run until the bubble is idle.
		synctest.Wait()
		if !tx.isDone() {
			t.Errorf("tx is not done")
		}
	}
	defer func() { hookTxGrabConn = nil }()

	_, err = tx.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatalf("expected error %v but got %v", nil, err)
	}

	// Explicit cancel in addition to the deferred one above.
	cancel()
}
3920
// TestIssue20647 runs testIssue20647 inside a synctest bubble.
func TestIssue20647(t *testing.T) {
	synctest.Test(t, testIssue20647)
}

// testIssue20647 checks that two concurrent result sets from a statement
// prepared on a *Conn both run on that same driver connection.
func testIssue20647(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx := t.Context()

	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	// NOTE(review): presumably stops the fake driver from flagging the
	// session dirty on release; confirm.
	conn.dc.ci.(*fakeConn).skipDirtySession = true
	// Deferred closes run in reverse: rows2, rows1, stmt, then conn.
	defer conn.Close()

	stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	rows1, err := stmt.QueryContext(ctx)
	if err != nil {
		t.Fatal("rows1", err)
	}
	defer rows1.Close()

	rows2, err := stmt.QueryContext(ctx)
	if err != nil {
		t.Fatal("rows2", err)
	}
	defer rows2.Close()

	if rows1.dc != rows2.dc {
		t.Fatal("stmt prepared on Conn does not use same connection")
	}
}
3959
3960 func TestConcurrency(t *testing.T) {
3961 list := []struct {
3962 name string
3963 ct concurrentTest
3964 }{
3965 {"Query", new(concurrentDBQueryTest)},
3966 {"Exec", new(concurrentDBExecTest)},
3967 {"StmtQuery", new(concurrentStmtQueryTest)},
3968 {"StmtExec", new(concurrentStmtExecTest)},
3969 {"TxQuery", new(concurrentTxQueryTest)},
3970 {"TxExec", new(concurrentTxExecTest)},
3971 {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
3972 {"TxStmtExec", new(concurrentTxStmtExecTest)},
3973 {"Random", new(concurrentRandomTest)},
3974 }
3975 for _, item := range list {
3976 synctestSubtest(t, item.name, func(t *testing.T) {
3977 doConcurrentTest(t, item.ct)
3978 })
3979 }
3980 }
3981
// TestConnectionLeak sets up this scenario: every pooled connection is busy,
// a new connection is being opened, and the busy connections are released
// before that Open completes. Leftover connections are presumably detected
// by the deferred closeDB / test-DB cleanup.
func TestConnectionLeak(t *testing.T) {
	synctest.Test(t, testConnectionLeak)
}
func testConnectionLeak(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// Keep defaultMaxIdleConns result sets open, each holding a connection.
	rows := make([]*Rows, defaultMaxIdleConns)
	// Allow one more open connection than the idle limit, so the query
	// below can open a new connection while all of these are busy.
	db.SetMaxOpenConns(len(rows) + 1)
	for ii := range rows {
		r, err := db.Query("SELECT|people|name|")
		if err != nil {
			t.Fatal(err)
		}
		r.Next()
		if err := r.Err(); err != nil {
			t.Fatal(err)
		}
		rows[ii] = r
	}

	// Hold back the next driver Open. As used below, the fake driver is
	// expected to signal waitingCh once Open starts and then wait for a value
	// on waitCh (NOTE(review): semantics live in fakeDriver — confirm).
	drv := db.Driver().(*fakeDriver)
	drv.waitCh = make(chan struct{}, 1)
	drv.waitingCh = make(chan struct{}, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Deferred so that wg.Wait below also returns when the query fails;
		// otherwise the error path would leave the test blocked forever.
		defer wg.Done()
		r, err := db.Query("SELECT|people|name|")
		if err != nil {
			t.Error(err)
			return
		}
		r.Close()
	}()

	// Wait until the new connection's Open is in progress.
	<-drv.waitingCh

	// Release every busy connection while that Open is still pending.
	for _, v := range rows {
		v.Close()
	}

	// Let the pending Open complete, then wait for its query to finish.
	drv.waitCh <- struct{}{}
	wg.Wait()
}
4036
// TestStatsMaxIdleClosedZero checks that when the single allowed connection
// may also stay idle, running queries never increments
// DBStats.MaxIdleClosed.
func TestStatsMaxIdleClosedZero(t *testing.T) {
	synctest.Test(t, testStatsMaxIdleClosedZero)
}
func testStatsMaxIdleClosedZero(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// One connection, with room for one idle connection.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	db.SetConnMaxLifetime(0)

	before := db.Stats().MaxIdleClosed

	for range 10 {
		rows, err := db.Query("SELECT|people|name|")
		if err != nil {
			t.Fatal(err)
		}
		rows.Close()
	}

	delta := db.Stats().MaxIdleClosed - before
	t.Logf("MaxIdleClosed: %d", delta)
	if delta != 0 {
		t.Fatal("expected 0 max idle closed conns, got: ", delta)
	}
}
4065
4066 func TestStatsMaxIdleClosedTen(t *testing.T) {
4067 synctest.Test(t, testStatsMaxIdleClosedTen)
4068 }
4069 func testStatsMaxIdleClosedTen(t *testing.T) {
4070 db := newTestDB(t, "people")
4071 defer closeDB(t, db)
4072
4073 db.SetMaxOpenConns(1)
4074 db.SetMaxIdleConns(0)
4075 db.SetConnMaxLifetime(0)
4076
4077 preMaxIdleClosed := db.Stats().MaxIdleClosed
4078
4079 for i := 0; i < 10; i++ {
4080 rows, err := db.Query("SELECT|people|name|")
4081 if err != nil {
4082 t.Fatal(err)
4083 }
4084 rows.Close()
4085 }
4086
4087 st := db.Stats()
4088 maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
4089 t.Logf("MaxIdleClosed: %d", maxIdleClosed)
4090 if maxIdleClosed != 10 {
4091 t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
4092 }
4093 }
4094
4095
// testUseConns checks out count connections from db one at a time, then
// returns them in reverse order. A 1ns sleep precedes each checkout and each
// release. Callers run it inside a synctest bubble, so these sleeps
// presumably give each returned connection a distinct returnedAt.
//
// If a checkout fails, the error is reported and only the connections
// actually obtained are released.
func testUseConns(t *testing.T, count int, db *DB) {
	t.Helper()
	conns := make([]*Conn, count)
	ctx := context.Background()
	for i := range conns {
		time.Sleep(1 * time.Nanosecond)
		c, err := db.Conn(ctx)
		if err != nil {
			t.Errorf("db.Conn #%d: %v", i, err)
			// Drop the unfilled tail; calling Close on a nil *Conn would panic.
			conns = conns[:i]
			break
		}
		conns[i] = c
	}

	for i := len(conns) - 1; i >= 0; i-- {
		time.Sleep(1 * time.Nanosecond)
		if err := conns[i].Close(); err != nil {
			t.Error(err)
		}
	}
}
4115
// TestMaxIdleTime exercises the connection cleaner's max-idle-time and
// max-lifetime handling. For each table entry it checks out and returns
// usedConns connections, sleeps until timeOffset after the subtest start,
// reuses reusedConns connections, sleeps secondTimeOffset, then runs
// connectionCleanerRunLocked once. It compares the returned next-check
// duration, the number of connections selected for closing, and the increase
// in Stats().MaxIdleTimeClosed with the expectations. It also checks that
// db.freeConn is ordered by returnedAt.
func TestMaxIdleTime(t *testing.T) {
	usedConns := 5
	reusedConns := 2
	list := []struct {
		wantMaxIdleTime   time.Duration // passed to SetConnMaxIdleTime
		wantMaxLifetime   time.Duration // passed to SetConnMaxLifetime
		wantNextCheck     time.Duration // expected first result of connectionCleanerRunLocked
		wantIdleClosed    int64         // expected number of connections returned for closing
		wantMaxIdleClosed int64         // expected increase in Stats().MaxIdleTimeClosed
		timeOffset        time.Duration // delay, relative to subtest start, before the reuse phase
		secondTimeOffset  time.Duration // extra sleep after the reuse phase, before the cleaner runs
	}{
		{
			// Idle limit only: the usedConns-reusedConns connections that are
			// not reused are expected to exceed the 1ms idle time.
			time.Millisecond,
			0,
			time.Millisecond - time.Nanosecond,
			int64(usedConns - reusedConns),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			0,
		},
		{
			// Idle limit plus a max lifetime.
			time.Millisecond,
			// The lifetime is tuned to the offsets below (10ms + 100ns) plus
			// what appear to be a few nanoseconds of sleeps from testUseConns,
			// so that presumably exactly one extra connection expires by age.
			// NOTE(review): this arithmetic is timing-sensitive — confirm
			// before changing any offset.
			10*time.Millisecond + 100*time.Nanosecond + 3*time.Nanosecond,
			time.Nanosecond,
			// One more connection is selected for closing than
			// MaxIdleTimeClosed accounts for (presumably the lifetime expiry).
			int64(usedConns - reusedConns + 1),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			// Extra delay after the reuse phase, before the cleaner runs.
			100 * time.Nanosecond,
		},
		{
			// Long idle limit: nothing is closed, and the next check equals
			// the time.Second passed to connectionCleanerRunLocked below.
			time.Hour,
			0,
			time.Second,
			0,
			0,
			10 * time.Millisecond,
			0},
	}
	for _, item := range list {
		synctestSubtest(t, fmt.Sprintf("%v", item.wantMaxIdleTime), func(t *testing.T) {
			startTime := time.Now()
			db := newTestDB(t, "people")
			defer closeDB(t, db)

			db.SetMaxOpenConns(usedConns)
			db.SetMaxIdleConns(usedConns)
			db.SetConnMaxIdleTime(item.wantMaxIdleTime)
			db.SetConnMaxLifetime(item.wantMaxLifetime)

			preMaxIdleClosed := db.Stats().MaxIdleTimeClosed

			// Check out and return usedConns connections, filling the idle pool.
			testUseConns(t, usedConns, db)

			time.Sleep(time.Until(startTime.Add(item.timeOffset)))

			// Reuse a few connections so they are recently returned (not
			// idle) when the cleaner runs; this also reorders their
			// returnedAt values within freeConn.
			testUseConns(t, reusedConns, db)

			time.Sleep(item.secondTimeOffset)

			// The cleaner's *Locked entry point is called with db.mu held.
			db.mu.Lock()
			nc, closing := db.connectionCleanerRunLocked(time.Second)
			if nc != item.wantNextCheck {
				t.Errorf("got %v; want %v next check duration", nc, item.wantNextCheck)
			}

			// freeConn must stay ordered by returnedAt, oldest first.
			var last time.Time
			for _, c := range db.freeConn {
				if last.After(c.returnedAt) {
					t.Error("freeConn is not ordered by returnedAt")
					break
				}
				last = c.returnedAt
			}

			db.mu.Unlock()
			// The cleaner only returns the connections to close; closing them
			// is done here, outside db.mu.
			for _, c := range closing {
				c.Close()
			}
			if g, w := int64(len(closing)), item.wantIdleClosed; g != w {
				t.Errorf("got: %d; want %d closed conns", g, w)
			}

			st := db.Stats()
			maxIdleClosed := st.MaxIdleTimeClosed - preMaxIdleClosed
			if g, w := maxIdleClosed, item.wantMaxIdleClosed; g != w {
				t.Errorf("got: %d; want %d max idle closed conns", g, w)
			}
		})
	}
}
4217
4218 type nvcDriver struct {
4219 fakeDriver
4220 skipNamedValueCheck bool
4221 }
4222
4223 func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
4224 c, err := d.fakeDriver.Open(dsn)
4225 fc := c.(*fakeConn)
4226 fc.db.allowAny = true
4227 return &nvcConn{fc, d.skipNamedValueCheck}, err
4228 }
4229
4230 type nvcConn struct {
4231 *fakeConn
4232 skipNamedValueCheck bool
4233 }
4234
4235 type decimalInt struct {
4236 value int
4237 }
4238
4239 type doNotInclude struct{}
4240
4241 var _ driver.NamedValueChecker = &nvcConn{}
4242
// CheckNamedValue implements driver.NamedValueChecker for the tests:
//   - with skipNamedValueCheck set, every value returns driver.ErrSkip;
//   - decimalInt and []int64 values are accepted unchanged;
//   - doNotInclude values return driver.ErrRemoveArgument;
//   - Out values whose Dest is a *string get "from-server" written to Dest and
//     the argument replaced by the marker "OUT:*string"; other Out
//     destinations are an error;
//   - anything else returns driver.ErrSkip.
func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
	if c.skipNamedValueCheck {
		return driver.ErrSkip
	}
	switch v := nv.Value.(type) {
	case decimalInt, []int64:
		return nil
	case doNotInclude:
		return driver.ErrRemoveArgument
	case Out:
		dest, ok := v.Dest.(*string)
		if !ok {
			return errors.New("unknown NameValueCheck OUTPUT type")
		}
		*dest = "from-server"
		nv.Value = "OUT:*string"
		return nil
	default:
		return driver.ErrSkip
	}
}
4265
// TestNamedValueChecker checks that a connection's CheckNamedValue is
// consulted for every argument. Custom types are passed through, an Out
// parameter is filled, and an argument the checker rejects is dropped.
func TestNamedValueChecker(t *testing.T) {
	synctest.Test(t, testNamedValueChecker)
}
func testNamedValueChecker(t *testing.T) {
	Register("NamedValueCheck", &nvcDriver{})
	db, err := Open("NamedValueCheck", "")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	ctx := t.Context()
	if _, err := db.ExecContext(ctx, "WIPE"); err != nil {
		t.Fatal("exec wipe", err)
	}
	if _, err := db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any"); err != nil {
		t.Fatal("exec create", err)
	}

	var out1 string
	// The trailing doNotInclude{} has no placeholder; CheckNamedValue must
	// strip it for the insert to succeed.
	if _, err := db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?",
		Named("A", decimalInt{123}),
		"hello",
		Named("O1", Out{Dest: &out1}),
		[]int64{42, 128, 707},
		doNotInclude{},
	); err != nil {
		t.Fatal("exec insert", err)
	}

	var (
		gotStr string
		gotDec decimalInt
		gotArr []int64
	)
	if err := db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&gotDec, &gotStr, &gotArr); err != nil {
		t.Fatal("select", err)
	}

	checks := [...]struct{ got, want any }{
		{out1, "from-server"},
		{gotDec, decimalInt{123}},
		{gotStr, "hello"},
		{gotArr, []int64{42, 128, 707}},
	}
	for i, c := range checks {
		if !reflect.DeepEqual(c.got, c.want) {
			t.Errorf("got %#v wanted %#v for index %d", c.got, c.want, i)
		}
	}
}
4316
// TestNamedValueCheckerSkip checks that when CheckNamedValue returns
// driver.ErrSkip for everything, a custom argument type falls through to the
// default handling and the Exec fails.
func TestNamedValueCheckerSkip(t *testing.T) {
	synctest.Test(t, testNamedValueCheckerSkip)
}
func testNamedValueCheckerSkip(t *testing.T) {
	Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
	db, err := Open("NamedValueCheckSkip", "")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	ctx := t.Context()
	setup := []struct{ query, label string }{
		{"WIPE", "exec wipe"},
		{"CREATE|keys|dec1=any", "exec create"},
	}
	for _, step := range setup {
		if _, err := db.ExecContext(ctx, step.query); err != nil {
			t.Fatal(step.label, err)
		}
	}

	// decimalInt is not a valid driver.Value, so without the checker's help
	// the insert is expected to be rejected.
	_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimalInt{123}))
	if err == nil {
		t.Fatalf("expected error with bad argument, got %v", err)
	}
}
4344
// TestOpenConnector checks that Open with a driver providing a connector
// uses that *fakeConnector, and that closing the DB closes the connector.
func TestOpenConnector(t *testing.T) {
	synctest.Test(t, testOpenConnector)
}
func testOpenConnector(t *testing.T) {
	Register("testctx", &fakeDriverCtx{})
	db, err := Open("testctx", "people")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	fc, ok := db.connector.(*fakeConnector)
	if !ok {
		t.Fatal("not using *fakeConnector")
	}

	// Close explicitly (the deferred Close is a no-op afterwards) so the
	// connector's state can be inspected.
	if err := db.Close(); err != nil {
		t.Fatal(err)
	}
	if !fc.closed {
		t.Fatal("connector is not closed")
	}
}
4369
4370 type ctxOnlyDriver struct {
4371 fakeDriver
4372 }
4373
4374 func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
4375 conn, err := d.fakeDriver.Open(dsn)
4376 if err != nil {
4377 return nil, err
4378 }
4379 return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil
4380 }
4381
4382 var (
4383 _ driver.Conn = &ctxOnlyConn{}
4384 _ driver.QueryerContext = &ctxOnlyConn{}
4385 _ driver.ExecerContext = &ctxOnlyConn{}
4386 )
4387
4388 type ctxOnlyConn struct {
4389 fc *fakeConn
4390
4391 queryCtxCalled bool
4392 execCtxCalled bool
4393 }
4394
// Begin starts a transaction on the wrapped fakeConn.
func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
	return c.fc.Begin()
}

// Close closes the wrapped fakeConn.
func (c *ctxOnlyConn) Close() error {
	return c.fc.Close()
}

// Prepare is required to satisfy driver.Conn. The tests expect
// PrepareContext to be used instead, so reaching it panics.
func (c *ctxOnlyConn) Prepare(query string) (driver.Stmt, error) {
	panic("not used")
}

// PrepareContext prepares query on the wrapped fakeConn.
func (c *ctxOnlyConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	return c.fc.PrepareContext(ctx, query)
}

// QueryContext records that it was called, then delegates to the wrapped
// fakeConn.
func (c *ctxOnlyConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	c.queryCtxCalled = true
	return c.fc.QueryContext(ctx, query, args)
}

// ExecContext records that it was called, then delegates to the wrapped
// fakeConn.
func (c *ctxOnlyConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	c.execCtxCalled = true
	return c.fc.ExecContext(ctx, query, args)
}
4422
4423
4424
// TestQueryExecContextOnly checks that a driver.Conn implementing only
// QueryerContext/ExecerContext (not the legacy Queryer/Execer) is driven
// through those context methods by Conn.ExecContext and Conn.QueryContext.
func TestQueryExecContextOnly(t *testing.T) {
	synctest.Test(t, testQueryExecContextOnly)
}
func testQueryExecContextOnly(t *testing.T) {
	// Guard the premise: ctxOnlyConn must not implement the non-context
	// interfaces, or they could be used instead.
	var connType driver.Conn = &ctxOnlyConn{}
	if _, ok := connType.(driver.Execer); ok {
		t.Fatalf("%T must not implement driver.Execer", connType)
	}
	if _, ok := connType.(driver.Queryer); ok {
		t.Fatalf("%T must not implement driver.Queryer", connType)
	}

	Register("ContextOnly", &ctxOnlyDriver{})
	db, err := Open("ContextOnly", "")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	ctx := t.Context()

	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal("db.Conn", err)
	}
	defer conn.Close()
	coc := conn.dc.ci.(*ctxOnlyConn)
	coc.fc.skipDirtySession = true

	_, err = conn.ExecContext(ctx, "WIPE")
	if err != nil {
		t.Fatal("exec wipe", err)
	}

	_, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
	if err != nil {
		t.Fatal("exec create", err)
	}
	expectedValue := "value1"
	_, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
	if err != nil {
		t.Fatal("exec insert", err)
	}
	rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
	if err != nil {
		t.Fatal("query select", err)
	}
	// Make sure rows are closed on every exit path, including t.Fatal
	// (Close is idempotent). Open rows hold conn, and the deferred
	// conn.Close presumably waits for them.
	defer rows.Close()
	v1 := ""
	for rows.Next() {
		err = rows.Scan(&v1)
		if err != nil {
			t.Fatal("rows scan", err)
		}
	}
	// Next also returns false when iteration fails; distinguish that from a
	// cleanly exhausted result set.
	if err := rows.Err(); err != nil {
		t.Fatal("rows err", err)
	}
	rows.Close()

	if v1 != expectedValue {
		t.Fatalf("expected %q, got %q", expectedValue, v1)
	}

	if !coc.execCtxCalled {
		t.Error("ExecContext not called")
	}
	if !coc.queryCtxCalled {
		t.Error("QueryContext not called")
	}
}
4493
4494 type alwaysErrScanner struct{}
4495
4496 var errTestScanWrap = errors.New("errTestScanWrap")
4497
4498 func (alwaysErrScanner) Scan(any) error {
4499 return errTestScanWrap
4500 }
4501
4502
// TestRowsScanProperlyWrapsErrors checks that an error returned by a
// Scanner's Scan method reaches the caller of Rows.Scan wrapped: errors.Is
// still matches it and its text is preserved.
func TestRowsScanProperlyWrapsErrors(t *testing.T) {
	synctest.Test(t, testRowsScanProperlyWrapsErrors)
}
func testRowsScanProperlyWrapsErrors(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	rows, err := db.Query("SELECT|people|age|")
	if err != nil {
		t.Fatalf("Query: %v", err)
	}
	// Release the connection on every exit path, including t.Fatal.
	defer rows.Close()

	var res alwaysErrScanner

	scanned := 0
	for rows.Next() {
		scanned++
		err = rows.Scan(&res)
		if err == nil {
			t.Fatal("expecting back an error")
		}
		if !errors.Is(err, errTestScanWrap) {
			t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errTestScanWrap)
		}

		if !strings.Contains(err.Error(), errTestScanWrap.Error()) {
			t.Fatalf("Error %v does not contain %v", err, errTestScanWrap)
		}
	}
	if err := rows.Err(); err != nil {
		t.Fatalf("rows.Err: %v", err)
	}
	// With no rows, the loop above would check nothing and the test would
	// pass vacuously.
	if scanned == 0 {
		t.Fatal("query returned no rows; nothing was scanned")
	}
}
4531
4532 type alwaysErrValuer struct{}
4533
4534
4535 var errEmpty = errors.New("empty value")
4536
4537 func (v alwaysErrValuer) Value() (driver.Value, error) {
4538 return nil, errEmpty
4539 }
4540
4541
// TestDriverArgsWrapsErrors checks that an error returned by a driver.Valuer
// argument reaches the caller of Exec and Query wrapped: errors.Is still
// matches it and its text is preserved.
func TestDriverArgsWrapsErrors(t *testing.T) {
	cases := []struct {
		name string
		run  func(db *DB) error
	}{
		{"exec", func(db *DB) error {
			_, err := db.Exec("INSERT|keys|dec1=?", alwaysErrValuer{})
			return err
		}},
		{"query", func(db *DB) error {
			_, err := db.Query("INSERT|keys|dec1=?", alwaysErrValuer{})
			return err
		}},
	}
	for _, tc := range cases {
		synctestSubtest(t, tc.name, func(t *testing.T) {
			db := newTestDB(t, "people")
			defer closeDB(t, db)
			err := tc.run(db)
			if err == nil {
				t.Fatal("expecting back an error")
			}
			if !errors.Is(err, errEmpty) {
				t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errEmpty)
			}

			if !strings.Contains(err.Error(), errEmpty.Error()) {
				t.Fatalf("Error %v does not contain %v", err, errEmpty)
			}
		})
	}
}
4576
// TestContextCancelDuringRawBytesScan runs
// testContextCancelDuringRawBytesScan once per cancellation mode, each in its
// own synctest subtest.
func TestContextCancelDuringRawBytesScan(t *testing.T) {
	modes := [...]string{"nocancel", "top", "bottom", "go"}
	for _, mode := range modes {
		synctestSubtest(t, mode, func(t *testing.T) { testContextCancelDuringRawBytesScan(t, mode) })
	}
}
4584
4585
// testContextCancelDuringRawBytesScan iterates over the "people" names,
// scanning each into a RawBytes. It cancels the query's context at a point
// chosen by mode:
//   - "nocancel": never (only the deferred cancel runs);
//   - "top":      before scanning the third row, then synctest.Wait;
//   - "bottom":   after scanning the second row, then synctest.Wait;
//   - "go":       from a new goroutine after scanning the second row.
//
// It checks that Rows.closemuScanHold is set while a RawBytes scan result is
// live and cleared once iteration stops. It also checks that the rows-read
// count is consistent with Rows.Err.
func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// Cancellation is triggered manually via cancel, as selected by mode.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	numRows := 0
	var sink byte
	for r.Next() {
		if mode == "top" && numRows == 2 {
			// Cancel between Next and Scan of the third row; synctest.Wait
			// lets the cancellation be processed before Scan runs.
			cancel()
			synctest.Wait()
		}
		numRows++
		var s RawBytes
		err = r.Scan(&s)
		if numRows == 3 && err == context.Canceled {
			// A scan that failed because of cancellation must not leave
			// closemu held.
			if r.closemuScanHold {
				t.Errorf("expected closemu NOT to be held")
			}
			break
		}
		// A successful RawBytes scan keeps closemu held.
		if !r.closemuScanHold {
			t.Errorf("expected closemu to be held")
		}
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("read %q", s)
		if mode == "bottom" && numRows == 2 {
			// Cancel while the second row's RawBytes is still live.
			cancel()
			synctest.Wait()
		}
		if mode == "go" && numRows == 2 {
			// Cancel concurrently, without waiting for it to take effect.
			go cancel()
		}
		// Read every byte of s; presumably this lets the race detector flag
		// any concurrent reuse of the RawBytes buffer.
		for _, b := range s {
			sink += b
		}

		// Call another Rows method while the scan hold is active.
		_, _ = r.Columns()
	}
	if r.closemuScanHold {
		t.Errorf("closemu held; should not be")
	}

	// How many rows were read depends on when the cancellation was observed.
	// With exactly two rows, Err must report the cancellation.
	switch numRows {
	case 0, 1:
		t.Errorf("got %d rows; want 2+", numRows)
	case 2:
		if err := r.Err(); err != context.Canceled {
			t.Errorf("unexpected error: %v (%T)", err, err)
		}
	default:
		// Read past the cancellation point (or mode was "nocancel"): nothing
		// more to check.
	}

	if err := r.Close(); err != nil {
		t.Fatal(err)
	}
}
4659
// TestContextCancelBetweenNextAndErr checks that cancelling the context after
// a result set has been fully consumed does not surface as an error from
// Rows.Err.
func TestContextCancelBetweenNextAndErr(t *testing.T) {
	synctest.Test(t, testContextCancelBetweenNextAndErr)
}
func testContextCancelBetweenNextAndErr(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	// Drain every row without scanning.
	for more := r.Next(); more; more = r.Next() {
	}

	// Cancel only after iteration has finished, and let it be processed.
	cancel()
	synctest.Wait()
	if err := r.Err(); err != nil {
		t.Fatal(err)
	}
}
4681
// testScanner is a Scanner whose behavior is supplied by the scanf callback.
type testScanner struct {
	scanf func(src any) error
}

// Scan forwards src to the scanf callback and returns its result.
func (ts testScanner) Scan(src any) error {
	fn := ts.scanf
	return fn(src)
}
4687
// TestContextCancelDuringScan checks that cancelling a query's context while
// Rows.Scan is still inside a custom Scanner does not modify the []byte the
// Scanner was already given.
func TestContextCancelDuringScan(t *testing.T) {
	synctest.Test(t, testContextCancelDuringScan)
}
func testContextCancelDuringScan(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// scanner hands each src to the test through scanStart, then blocks,
	// keeping Scan in progress, until a value arrives on scanEnd or scanEnd
	// is closed (which yields a nil error).
	scanStart := make(chan any)
	scanEnd := make(chan error)
	scanner := &testScanner{
		scanf: func(src any) error {
			scanStart <- src
			return <-scanEnd
		},
	}

	// Select a single known row so that the scanned bytes are predictable.
	want := []byte("Alice")
	r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
	if err != nil {
		t.Fatal(err)
	}
	if !r.Next() {
		t.Fatalf("r.Next() = false, want true")
	}
	// Scan in the background; it stays blocked inside scanner until scanEnd
	// is closed by the deferred close below.
	go func() {
		r.Scan(scanner)
	}()
	got := <-scanStart
	defer close(scanEnd)
	gotBytes, ok := got.([]byte)
	if !ok {
		t.Fatalf("r.Scan returned %T, want []byte", got)
	}
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
	}

	// Cancel while Scan is still in progress, and let the cancellation be
	// processed.
	cancel()
	synctest.Wait()

	// The bytes previously handed to the scanner must be unchanged.
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
	}
}
4738
// TestNilErrorAfterClose checks that explicitly closing a result set whose
// context was never cancelled leaves Rows.Err returning nil.
func TestNilErrorAfterClose(t *testing.T) {
	synctest.Test(t, testNilErrorAfterClose)
}
func testNilErrorAfterClose(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// cancel runs only when the function returns, so the query's context is
	// never cancelled while the checks below run.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	rows, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	if err := rows.Close(); err != nil {
		t.Fatal(err)
	}

	// Let any background goroutines in the bubble settle before checking Err.
	synctest.Wait()
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}
}
4766
4767
4768
4769
4770
// TestRawBytesReuse checks that reusing a single RawBytes destination across
// queries does not corrupt the values read. Scanning the first name, then an
// age, then the first name again must yield the same name twice.
func TestRawBytesReuse(t *testing.T) {
	synctest.Test(t, testRawBytesReuse)
}
func testRawBytesReuse(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// raw is deliberately shared by every query below.
	var raw RawBytes

	// scanFirst runs query, scans the first column of its first row into raw
	// and returns a copy of it. The copy is taken before the rows are closed,
	// because RawBytes contents are only valid until then. Failures to read
	// a row are fatal, so the final comparison cannot pass vacuously.
	scanFirst := func(query string) string {
		t.Helper()
		rows, err := db.Query(query)
		if err != nil {
			t.Fatal(err)
		}
		defer rows.Close()
		if !rows.Next() {
			t.Fatalf("%s: no rows (err: %v)", query, rows.Err())
		}
		if err := rows.Scan(&raw); err != nil {
			t.Fatalf("%s: Scan: %v", query, err)
		}
		return string(raw)
	}

	name1 := scanFirst("SELECT|people|name|")
	// Refill raw from a different column between the two name reads.
	scanFirst("SELECT|people|age|")
	name2 := scanFirst("SELECT|people|name|")
	if name1 != name2 {
		t.Fatalf("Scan read name %q, want %q", name2, name1)
	}
}
4812
4813
4814
4815 type badConn struct{}
4816
4817 func (bc badConn) Prepare(query string) (driver.Stmt, error) {
4818 return nil, errors.New("badConn Prepare")
4819 }
4820
4821 func (bc badConn) Close() error {
4822 return nil
4823 }
4824
4825 func (bc badConn) Begin() (driver.Tx, error) {
4826 return nil, errors.New("badConn Begin")
4827 }
4828
4829 func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
4830 panic("badConn.Exec")
4831 }
4832
4833
4834 type badDriver struct{}
4835
4836 func (bd badDriver) Open(name string) (driver.Conn, error) {
4837 return badConn{}, nil
4838 }
4839
4840
// TestBadDriver checks that a panic raised inside a driver's Exec propagates
// unchanged to the caller of DB.Exec.
func TestBadDriver(t *testing.T) {
	synctest.Test(t, testBadDriver)
}
func testBadDriver(t *testing.T) {
	Register("bad", badDriver{})
	db, err := Open("bad", "ignored")
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		r := recover()
		if r == nil {
			t.Error("expected panic")
			return
		}
		// Compare without a type assertion. A non-string panic value (for
		// example a runtime error) must be reported as a mismatch, not cause
		// a second panic inside this deferred function.
		if want := "badConn.Exec"; r != want {
			t.Errorf("panic was %v, expected %v", r, want)
		}
	}()
	defer db.Close()
	db.Exec("ignored")
}
4862
4863 type pingDriver struct {
4864 fails bool
4865 }
4866
4867 type pingConn struct {
4868 badConn
4869 driver *pingDriver
4870 }
4871
4872 var pingError = errors.New("Ping failed")
4873
4874 func (pc pingConn) Ping(ctx context.Context) error {
4875 if pc.driver.fails {
4876 return pingError
4877 }
4878 return nil
4879 }
4880
4881 var _ driver.Pinger = pingConn{}
4882
4883 func (pd *pingDriver) Open(name string) (driver.Conn, error) {
4884 return pingConn{driver: pd}, nil
4885 }
4886
// TestPing checks that DB.Ping delegates to a driver.Pinger connection and
// returns its error.
func TestPing(t *testing.T) {
	synctest.Test(t, testPing)
}
func testPing(t *testing.T) {
	// Named pd rather than driver so that the database/sql/driver package
	// is not shadowed in this function.
	pd := &pingDriver{}
	Register("ping", pd)

	db, err := Open("ping", "ignored")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if err := db.Ping(); err != nil {
		t.Fatalf("err was %#v, expected nil", err)
	}

	// Make the already-open connection's Ping fail.
	pd.fails = true
	if err := db.Ping(); !errors.Is(err, pingError) {
		t.Errorf("err was %#v, expected pingError", err)
	}
}
4910
4911
// TestTypedString checks that a string column can be scanned into a named
// type whose underlying type is string.
func TestTypedString(t *testing.T) {
	synctest.Test(t, testTypedString)
}
func testTypedString(t *testing.T) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	type Str string
	var scanned Str
	if err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned); err != nil {
		t.Fatal(err)
	}
	if want := Str("Alice"); scanned != want {
		t.Errorf("expected %+v, got %+v", want, scanned)
	}
}
4931
// BenchmarkConcurrentDBExec reports the cost, including allocations, of one
// doConcurrentTest run driving concurrentDBExecTest.
func BenchmarkConcurrentDBExec(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentDBExecTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4939
// BenchmarkConcurrentStmtQuery reports the cost, including allocations, of
// one doConcurrentTest run driving concurrentStmtQueryTest.
func BenchmarkConcurrentStmtQuery(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentStmtQueryTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4947
// BenchmarkConcurrentStmtExec reports the cost, including allocations, of
// one doConcurrentTest run driving concurrentStmtExecTest.
func BenchmarkConcurrentStmtExec(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentStmtExecTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4955
// BenchmarkConcurrentTxQuery reports the cost, including allocations, of one
// doConcurrentTest run driving concurrentTxQueryTest.
func BenchmarkConcurrentTxQuery(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentTxQueryTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4963
// BenchmarkConcurrentTxExec reports the cost, including allocations, of one
// doConcurrentTest run driving concurrentTxExecTest.
func BenchmarkConcurrentTxExec(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentTxExecTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4971
// BenchmarkConcurrentTxStmtQuery reports the cost, including allocations, of
// one doConcurrentTest run driving concurrentTxStmtQueryTest.
func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentTxStmtQueryTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4979
// BenchmarkConcurrentTxStmtExec reports the cost, including allocations, of
// one doConcurrentTest run driving concurrentTxStmtExecTest.
func BenchmarkConcurrentTxStmtExec(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentTxStmtExecTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4987
// BenchmarkConcurrentRandom reports the cost, including allocations, of one
// doConcurrentTest run driving concurrentRandomTest.
func BenchmarkConcurrentRandom(b *testing.B) {
	b.ReportAllocs()
	ct := new(concurrentRandomTest)
	for range b.N {
		doConcurrentTest(b, ct)
	}
}
4995
// BenchmarkManyConcurrentQueries runs a prepared "magicquery" statement from
// parallelism*GOMAXPROCS goroutines at once. Each call asks the fake driver
// to sleep for 1ms.
func BenchmarkManyConcurrentQueries(b *testing.B) {
	b.ReportAllocs()

	// Goroutines per GOMAXPROCS that RunParallel starts (see SetParallelism).
	const parallelism = 16

	db := newTestDB(b, "magicquery")
	defer closeDB(b, db)
	// Leave room for every parallel goroutine's connection in the idle pool.
	db.SetMaxIdleConns(runtime.GOMAXPROCS(0) * parallelism)

	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
	if err != nil {
		b.Fatal(err)
	}
	defer stmt.Close()

	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			rows, err := stmt.Query("sleep", 1)
			if err != nil {
				b.Error(err)
				return
			}
			rows.Close()
		}
	})
}
5023
// TestGrabConnAllocs checks that grabbing and releasing a Conn's driver
// connection does not allocate.
func TestGrabConnAllocs(t *testing.T) {
	synctest.Test(t, testGrabConnAllocs)
}
func testGrabConnAllocs(t *testing.T) {
	testenv.SkipIfOptimizationOff(t)
	if race.Enabled {
		t.Skip("skipping allocation test when using race detector")
	}
	c := new(Conn)
	ctx := context.Background()
	// grabAndRelease is the operation whose allocations are measured.
	grabAndRelease := func() {
		_, release, err := c.grabConn(ctx)
		if err != nil {
			t.Fatal(err)
		}
		release(nil)
	}
	if n := int(testing.AllocsPerRun(1000, grabAndRelease)); n > 0 {
		t.Fatalf("Conn.grabConn allocated %v objects; want 0", n)
	}
}
5045
// BenchmarkGrabConn measures one grabConn/release round trip on a
// zero-valued Conn.
func BenchmarkGrabConn(b *testing.B) {
	b.ReportAllocs()
	c := new(Conn)
	ctx := context.Background()
	for range b.N {
		_, release, err := c.grabConn(ctx)
		if err != nil {
			b.Fatal(err)
		}
		release(nil)
	}
}
5058
// TestConnRequestSet exercises connRequestSet:
//   - Add/Delete bookkeeping;
//   - TakeRandom removing an entry, so a later Delete of its handle fails;
//   - TakeRandom returning every entry exactly once, not in insertion order;
//   - CloseAndRemoveAll emptying the set.
func TestConnRequestSet(t *testing.T) {
	var s connRequestSet
	// wantLen reports a failure on t unless s holds exactly want entries.
	// When want is 0 and t has not failed yet, it also requires TakeRandom
	// to report the set as empty.
	//
	// It takes the calling subtest's t instead of capturing the parent's.
	// t.Fatalf must run on the goroutine executing that test, and failures
	// should be attributed to the subtest that observed them.
	wantLen := func(t *testing.T, want int) {
		t.Helper()
		if got := s.Len(); got != want {
			t.Errorf("Len = %d; want %d", got, want)
		}
		if want == 0 && !t.Failed() {
			if _, ok := s.TakeRandom(); ok {
				t.Fatalf("TakeRandom returned result when empty")
			}
		}
	}
	reset := func() { s = connRequestSet{} }

	t.Run("add-delete", func(t *testing.T) {
		reset()
		wantLen(t, 0)
		dh := s.Add(nil)
		wantLen(t, 1)
		if !s.Delete(dh) {
			t.Fatal("failed to delete")
		}
		wantLen(t, 0)
		if s.Delete(dh) {
			t.Error("delete worked twice")
		}
		wantLen(t, 0)
	})
	t.Run("take-before-delete", func(t *testing.T) {
		reset()
		ch1 := make(chan connRequest)
		dh := s.Add(ch1)
		wantLen(t, 1)
		if got, ok := s.TakeRandom(); !ok || got != ch1 {
			t.Fatalf("wrong take; ok=%v", ok)
		}
		wantLen(t, 0)
		if s.Delete(dh) {
			t.Error("unexpected delete after take")
		}
	})
	t.Run("get-take-many", func(t *testing.T) {
		reset()
		m := map[chan connRequest]bool{}
		const N = 100
		var inOrder, backOut []chan connRequest
		for range N {
			c := make(chan connRequest)
			m[c] = true
			s.Add(c)
			inOrder = append(inOrder, c)
		}
		if s.Len() != N {
			t.Fatalf("Len = %v; want %v", s.Len(), N)
		}
		for s.Len() > 0 {
			c, ok := s.TakeRandom()
			if !ok {
				t.Fatal("failed to take when non-empty")
			}
			if !m[c] {
				t.Fatal("returned item not in remaining set")
			}
			delete(m, c)
			backOut = append(backOut, c)
		}
		if len(m) > 0 {
			t.Error("items remain in expected map")
		}
		// With 100 entries, a random draw matching insertion order is
		// practically impossible.
		if slices.Equal(inOrder, backOut) {
			t.Error("wasn't random")
		}
	})
	t.Run("close-delete", func(t *testing.T) {
		reset()
		ch := make(chan connRequest)
		dh := s.Add(ch)
		wantLen(t, 1)
		s.CloseAndRemoveAll()
		wantLen(t, 0)
		if s.Delete(dh) {
			t.Error("unexpected delete after CloseAndRemoveAll")
		}
	})
}
5145
// BenchmarkConnRequestSet drives a connRequestSet through a fixed pattern per
// iteration: add 16, take 8, add 8, take 16, then confirm the set is empty.
func BenchmarkConnRequestSet(b *testing.B) {
	const big, small = 16, 8
	var s connRequestSet
	for range b.N {
		for range big {
			s.Add(nil)
		}
		for range small {
			if _, ok := s.TakeRandom(); !ok {
				b.Fatal("want ok")
			}
		}
		for range small {
			s.Add(nil)
		}
		for range big {
			if _, ok := s.TakeRandom(); !ok {
				b.Fatal("want ok")
			}
		}
		// Everything added this iteration has been taken.
		if _, ok := s.TakeRandom(); ok {
			b.Fatal("unexpected ok")
		}
	}
}
5170
5171 func TestIssue69837(t *testing.T) {
5172 u := Null[uint]{V: 1, Valid: true}
5173 val, err := driver.DefaultParameterConverter.ConvertValue(u)
5174 if err != nil {
5175 t.Errorf("ConvertValue() error = %v, want nil", err)
5176 }
5177
5178 if v, ok := val.(int64); !ok {
5179 t.Errorf("val.(type): got %T, expected int64", val)
5180 } else if v != 1 {
5181 t.Errorf("val: got %d, expected 1", v)
5182 }
5183 }
5184
5185 type issue69728Type struct {
5186 ID int
5187 Name string
5188 }
5189
5190 func (t issue69728Type) Value() (driver.Value, error) {
5191 return []byte(fmt.Sprintf("%d, %s", t.ID, t.Name)), nil
5192 }
5193
// TestIssue69728 checks that a valid Null[T] whose T implements
// driver.Valuer produces the same value as calling T.Value directly.
func TestIssue69728(t *testing.T) {
	inner := issue69728Type{ID: 42, Name: "foobar"}
	wrapped := Null[issue69728Type]{V: inner, Valid: true}

	fromNull, err := wrapped.Value()
	if err != nil {
		t.Errorf("forValue.Value() error = %v, want nil", err)
	}

	fromInner, err := wrapped.V.Value()
	if err != nil {
		t.Errorf("forValue.V.Value() error = %v, want nil", err)
	}

	if !reflect.DeepEqual(fromNull, fromInner) {
		t.Errorf("not equal; v1 = %v, v2 = %v", fromNull, fromInner)
	}
}
5217
// TestColumnConverterWithUnknownInputCount checks that when a driver
// statement reports NumInput() == -1, Stmt.Exec still passes each argument
// through the statement's ColumnConverter. unknownInputsStmt.Exec only
// succeeds if the int argument arrives converted to a string.
func TestColumnConverterWithUnknownInputCount(t *testing.T) {
	db := OpenDB(&unknownInputsConnector{})
	// Release the DB and statement; the original version leaked both.
	defer db.Close()
	stmt, err := db.Prepare("SELECT ?")
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()
	if _, err := stmt.Exec(1); err != nil {
		t.Fatal(err)
	}
}
5229
5230 type unknownInputsConnector struct{}
5231
5232 func (unknownInputsConnector) Connect(context.Context) (driver.Conn, error) {
5233 return unknownInputsConn{}, nil
5234 }
5235
5236 func (unknownInputsConnector) Driver() driver.Driver { return nil }
5237
5238 type unknownInputsConn struct{}
5239
5240 func (unknownInputsConn) Prepare(string) (driver.Stmt, error) { return unknownInputsStmt{}, nil }
5241 func (unknownInputsConn) Close() error { return nil }
5242 func (unknownInputsConn) Begin() (driver.Tx, error) { return nil, nil }
5243
5244 type unknownInputsStmt struct{}
5245
5246 func (unknownInputsStmt) Close() error { return nil }
5247 func (unknownInputsStmt) NumInput() int { return -1 }
5248 func (unknownInputsStmt) Exec(args []driver.Value) (driver.Result, error) {
5249 if _, ok := args[0].(string); !ok {
5250 return nil, fmt.Errorf("Expected string, got %T", args[0])
5251 }
5252 return nil, nil
5253 }
5254 func (unknownInputsStmt) Query([]driver.Value) (driver.Rows, error) { return nil, nil }
5255 func (unknownInputsStmt) ColumnConverter(idx int) driver.ValueConverter {
5256 return unknownInputsValueConverter{}
5257 }
5258
5259 type unknownInputsValueConverter struct{}
5260
5261 func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
5262 return "string", nil
5263 }
5264
// TestNullTypeScanErrorConsistency checks that the Null* scanners keep Valid
// consistent with Scan's result. A failed Scan must leave Valid false, and a
// successful Scan of a non-nil value must set Valid true.
func TestNullTypeScanErrorConsistency(t *testing.T) {
	// validOf returns the Valid flag of the Null* scanner types covered here,
	// and false for any other Scanner.
	validOf := func(sc Scanner) bool {
		switch v := sc.(type) {
		case *NullInt32:
			return v.Valid
		case *NullInt64:
			return v.Valid
		case *NullFloat64:
			return v.Valid
		case *NullBool:
			return v.Valid
		case *NullString:
			return v.Valid
		case *NullTime:
			return v.Valid
		}
		return false
	}

	cases := []struct {
		name    string
		scanner Scanner
		input   any
		wantErr bool
	}{
		// Inputs that cannot be converted to the target type.
		{name: "NullInt32 with invalid input", scanner: &NullInt32{}, input: []byte("not_a_number"), wantErr: true},
		{name: "NullInt64 with invalid input", scanner: &NullInt64{}, input: []byte("not_a_number"), wantErr: true},
		{name: "NullFloat64 with invalid input", scanner: &NullFloat64{}, input: []byte("not_a_float"), wantErr: true},
		{name: "NullBool with invalid input", scanner: &NullBool{}, input: []byte("not_a_bool"), wantErr: true},

		// Inputs that convert cleanly.
		{name: "NullInt32 with valid input", scanner: &NullInt32{}, input: int64(42), wantErr: false},
		{name: "NullInt64 with valid input", scanner: &NullInt64{}, input: int64(42), wantErr: false},
		{name: "NullFloat64 with valid input", scanner: &NullFloat64{}, input: float64(3.14), wantErr: false},
		{name: "NullBool with valid input", scanner: &NullBool{}, input: true, wantErr: false},
		{name: "NullString with valid input", scanner: &NullString{}, input: "hello", wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.scanner.Scan(tc.input)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("Scan() error = %v, wantErr %v", err, tc.wantErr)
			}

			// Valid must mirror whether Scan succeeded.
			valid := validOf(tc.scanner)
			switch {
			case err != nil && valid:
				t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil")
			case err == nil && !valid:
				t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil")
			}
		})
	}
}
5369
5370
5371 func TestNullTypeScanNil(t *testing.T) {
5372 tests := []struct {
5373 name string
5374 scanner Scanner
5375 }{
5376 {"NullString", &NullString{String: "preset", Valid: true}},
5377 {"NullInt64", &NullInt64{Int64: 42, Valid: true}},
5378 {"NullInt32", &NullInt32{Int32: 42, Valid: true}},
5379 {"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}},
5380 {"NullBool", &NullBool{Bool: true, Valid: true}},
5381 {"NullTime", &NullTime{Time: time.Now(), Valid: true}},
5382 }
5383
5384 for _, tt := range tests {
5385 t.Run(tt.name, func(t *testing.T) {
5386 err := tt.scanner.Scan(nil)
5387 if err != nil {
5388 t.Errorf("Scan(nil) error = %v; want nil", err)
5389 }
5390
5391 var valid bool
5392 switch s := tt.scanner.(type) {
5393 case *NullString:
5394 valid = s.Valid
5395 case *NullInt64:
5396 valid = s.Valid
5397 case *NullInt32:
5398 valid = s.Valid
5399 case *NullFloat64:
5400 valid = s.Valid
5401 case *NullBool:
5402 valid = s.Valid
5403 case *NullTime:
5404 valid = s.Valid
5405 }
5406
5407 if valid {
5408 t.Errorf("Scan(nil) left Valid=true; want Valid=false")
5409 }
5410 })
5411 }
5412 }
5413
View as plain text