Source file src/database/sql/sql_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"database/sql/driver"
    11  	"errors"
    12  	"fmt"
    13  	"internal/race"
    14  	"internal/testenv"
    15  	"math/rand"
    16  	"reflect"
    17  	"runtime"
    18  	"slices"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  	"testing"
    23  	"testing/synctest"
    24  	"time"
    25  	"uuid"
    26  )
    27  
    28  type requireFeature string
    29  
    30  // testDatabase executes f in a synctest bubble.
    31  //
    32  // It executes several subtests, each with a database driver supporting
    33  // a different set of optional interfaces (QueryerContext, etc.).
    34  //
    35  // Limit a test to drivers implementing a certain feature by passing
    36  // a requireFeature option. For example:
    37  //
    38  //	// testFunc only executes with drivers which implement Validator.
    39  //	testDatabase(t, testFunc, requireFeature("Validator"))
    40  func testDatabase(t *testing.T, f func(t *testing.T, db *DB), opts ...any) {
    41  	var require []string
    42  	for _, o := range opts {
    43  		switch o := o.(type) {
    44  		case requireFeature:
    45  			require = append(require, string(o))
    46  		default:
    47  			t.Fatalf("unrecognized option %T", o)
    48  		}
    49  	}
    50  Test:
    51  	for _, test := range []struct {
    52  		name      string
    53  		connector driver.Connector
    54  		features  []string
    55  	}{
    56  		{
    57  			// Basic driver supporting none of the optional driver interfaces.
    58  			name:      "basic",
    59  			connector: &basicConnector{name: fakeDBName},
    60  		},
    61  		{
    62  			// Default test driver. Supports some but not all features.
    63  			// This is the "default" because this is the only driver we used
    64  			// before adding testDatabase.
    65  			name:      "default",
    66  			connector: &fakeConnector{name: fakeDBName},
    67  			features: []string{
    68  				"ConnBeginTx",
    69  				"NamedValue",
    70  				"Validator",
    71  			},
    72  		},
    73  		{
    74  			name:      "scancols",
    75  			connector: &rowsColumnScannerConnector{name: fakeDBName},
    76  			features: []string{
    77  				"ConnBeginTx",
    78  				"NamedValue",
    79  				"Validator",
    80  				"ScanColumn",
    81  			},
    82  		},
    83  	} {
    84  		for _, req := range require {
    85  			if !slices.Contains(test.features, req) {
    86  				continue Test
    87  			}
    88  		}
    89  		t.Run(test.name, func(t *testing.T) {
    90  			synctest.Test(t, func(t *testing.T) {
    91  				db := OpenDB(test.connector)
    92  				if _, err := db.Exec("WIPE"); err != nil {
    93  					t.Fatalf("exec wipe: %v", err)
    94  				}
    95  				t.Cleanup(func() {
    96  					closeDB(t, db)
    97  				})
    98  				f(t, db)
    99  			})
   100  		})
   101  	}
   102  }
   103  
   104  func init() {
   105  	type dbConn struct {
   106  		db *DB
   107  		c  *driverConn
   108  	}
   109  	freedFrom := make(map[dbConn]string)
   110  	var mu sync.Mutex
   111  	getFreedFrom := func(c dbConn) string {
   112  		mu.Lock()
   113  		defer mu.Unlock()
   114  		return freedFrom[c]
   115  	}
   116  	setFreedFrom := func(c dbConn, s string) {
   117  		mu.Lock()
   118  		defer mu.Unlock()
   119  		freedFrom[c] = s
   120  	}
   121  	putConnHook = func(db *DB, c *driverConn) {
   122  		if slices.Contains(db.freeConn, c) {
   123  			// print before panic, as panic may get lost due to conflicting panic
   124  			// (all goroutines asleep) elsewhere, since we might not unlock
   125  			// the mutex in freeConn here.
   126  			println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
   127  			panic("double free of conn.")
   128  		}
   129  		setFreedFrom(dbConn{db, c}, stack())
   130  	}
   131  }
   132  
   133  const fakeDBName = "foo"
   134  
   135  var chrisBirthday = time.Unix(123456789, 0)
   136  
   137  func newTestDB(t testing.TB, name string) *DB {
   138  	return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
   139  }
   140  
   141  func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
   142  	fc.name = fakeDBName
   143  	db := OpenDB(fc)
   144  	if _, err := db.Exec("WIPE"); err != nil {
   145  		t.Fatalf("exec wipe: %v", err)
   146  	}
   147  	t.Cleanup(func() {
   148  		closeDB(t, db)
   149  	})
   150  	if name != "" {
   151  		populate(t, db, name)
   152  	}
   153  	return db
   154  }
   155  
   156  func populate(t testing.TB, db *DB, name string) {
   157  	t.Helper()
   158  	switch name {
   159  	case "people":
   160  		exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
   161  		exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
   162  		exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
   163  		exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
   164  	case "magicquery":
   165  		// Magic table name and column, known by fakedb_test.go.
   166  		exec(t, db, "CREATE|magicquery|op=string,millis=int32")
   167  		exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
   168  	case "tx_status":
   169  		// Magic table name and column, known by fakedb_test.go.
   170  		exec(t, db, "CREATE|tx_status|tx_status=string")
   171  		exec(t, db, "INSERT|tx_status|tx_status=invalid")
   172  	default:
   173  		t.Fatalf("unknown database name %q", name)
   174  	}
   175  }
   176  
   177  // basicConn implements only the bare minimum of the driver.Conn interface.
   178  type basicConn struct {
   179  	driver.Conn
   180  }
   181  
   182  func (c *basicConn) getFakeConn() *fakeConn {
   183  	return c.Conn.(*fakeConn)
   184  }
   185  
   186  func (c *basicConn) Prepare(query string) (driver.Stmt, error) {
   187  	stmt, err := c.Conn.(*fakeConn).PrepareContext(context.Background(), query)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	return &basicStmt{fc: c.Conn.(*fakeConn), Stmt: stmt}, nil
   192  }
   193  
   194  func (c *basicConn) Begin() (driver.Tx, error) {
   195  	return c.Conn.(*fakeConn).Begin()
   196  }
   197  
   198  func (c *basicConn) Close() error {
   199  	return c.Conn.(*fakeConn).Close()
   200  }
   201  
   202  type basicStmt struct {
   203  	fc *fakeConn
   204  	driver.Stmt
   205  }
   206  
   207  func (s *basicStmt) valuesToNamedValues(args []driver.Value) ([]driver.NamedValue, error) {
   208  	nv := make([]driver.NamedValue, len(args))
   209  	for i, arg := range args {
   210  		val, err := s.Stmt.(*fakeStmt).ColumnConverter(i).ConvertValue(arg)
   211  		if err != nil {
   212  			return nil, fmt.Errorf("sql: converting argument $%v type: %w", i+1, err)
   213  		}
   214  		nv[i] = driver.NamedValue{
   215  			Ordinal: i + 1,
   216  			Value:   val,
   217  		}
   218  	}
   219  	return nv, nil
   220  }
   221  
   222  func (s *basicStmt) Exec(args []driver.Value) (driver.Result, error) {
   223  	nvs, err := s.valuesToNamedValues(args)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	return s.Stmt.(*fakeStmt).ExecContext(context.Background(), nvs)
   228  }
   229  
   230  func (s *basicStmt) Query(args []driver.Value) (driver.Rows, error) {
   231  	nvs, err := s.valuesToNamedValues(args)
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  	return s.Stmt.(*fakeStmt).QueryContext(context.Background(), nvs)
   236  }
   237  
   238  type basicConnector struct {
   239  	fakeConnector
   240  }
   241  
   242  func (c *basicConnector) Connect(ctx context.Context) (driver.Conn, error) {
   243  	conn, err := c.fakeConnector.Connect(ctx)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	fc := getFakeConn(conn)
   248  	fc.skipDirtySession = true // Conn won't implement ResetSession
   249  	return &basicConn{fc}, nil
   250  }
   251  
   252  func TestOpenDB(t *testing.T) {
   253  	synctest.Test(t, testOpenDB)
   254  }
   255  func testOpenDB(t *testing.T) {
   256  	db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver})
   257  	defer db.Close()
   258  	if db.Driver() != fdriver {
   259  		t.Fatalf("OpenDB should return the driver of the Connector")
   260  	}
   261  }
   262  
   263  func TestDriverPanic(t *testing.T) {
   264  	synctest.Test(t, testDriverPanic)
   265  }
   266  func testDriverPanic(t *testing.T) {
   267  	db, err := Open("test", fakeDBName)
   268  	if err != nil {
   269  		t.Fatalf("Open: %v", err)
   270  	}
   271  	defer db.Close()
   272  	// Test that if driver panics, database/sql does not deadlock.
   273  	expectPanic := func(name string, f func()) {
   274  		defer func() {
   275  			err := recover()
   276  			if err == nil {
   277  				t.Fatalf("%s did not panic", name)
   278  			}
   279  		}()
   280  		f()
   281  	}
   282  
   283  	expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
   284  	exec(t, db, "WIPE") // check not deadlocked
   285  	expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
   286  	exec(t, db, "WIPE") // check not deadlocked
   287  	expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
   288  	exec(t, db, "WIPE")             // check not deadlocked
   289  	exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query
   290  	exec(t, db, "WIPE")             // check not deadlocked
   291  
   292  	exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
   293  
   294  	expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
   295  	expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
   296  	expectPanic("Query Close", func() {
   297  		rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
   298  		if err != nil {
   299  			t.Fatal(err)
   300  		}
   301  		rows.Close()
   302  	})
   303  	db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec
   304  	exec(t, db, "WIPE")                            // check not deadlocked
   305  }
   306  
   307  func exec(t testing.TB, db *DB, query string, args ...any) {
   308  	t.Helper()
   309  	_, err := db.Exec(query, args...)
   310  	if err != nil {
   311  		t.Fatalf("Exec of %q: %v", query, err)
   312  	}
   313  }
   314  
// closeDB closes db at the end of a test and checks that it shut down
// cleanly. It reports idle connections that still have open statements,
// reports any error returned when a fakeConn is closed, fails the test if
// db.Close returns an error, and fails it if any connections remain open
// after the asynchronous connection closes have settled.
func closeDB(t testing.TB, db *DB) {
	t.Helper()

	// Per the Go spec, recover only returns non-nil when it is called
	// directly by a deferred function, i.e. when a test does
	// `defer closeDB(t, db)` while panicking; the panic is then printed and
	// re-raised. Called from a closure (such as a t.Cleanup callback),
	// recover returns nil and this is a no-op.
	if e := recover(); e != nil {
		fmt.Printf("Panic: %v\n", e)
		panic(e)
	}
	// While this function runs, report errors from closing fake
	// connections; the hook is reset to nil on return.
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v", err)
		}
	})
	db.mu.Lock()
	for i, dc := range db.freeConn {
		if n := len(dc.openStmt); n > 0 {
			// Just a sanity check. This is legal in
			// general, but if we make the tests clean up
			// their statements first, then we can safely
			// verify this is always zero here, and any
			// other value is a leak.
			t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
		}
	}
	db.mu.Unlock()

	err := db.Close()
	if err != nil {
		t.Fatalf("error closing DB: %v", err)
	}

	// Connections close asynchronously; wait for them to finish doing so.
	numOpenConns := func() int {
		db.mu.Lock()
		defer db.mu.Unlock()
		return db.numOpen
	}
	if _, ok := t.(*testing.B); ok {
		// We don't use synctest in benchmarks, so just poll.
		// Give up after 5 seconds; the check below reports any stragglers.
		deadline := time.Now().Add(5 * time.Second)
		for numOpenConns() > 0 && time.Now().Before(deadline) {
			time.Sleep(1 * time.Millisecond)
		}
	} else {
		// Inside a synctest bubble: wait until every other goroutine in the
		// bubble is durably blocked, so pending closes have run.
		synctest.Wait()
	}
	if numOpen := numOpenConns(); numOpen != 0 {
		t.Fatalf("%d connections still open after closing DB", numOpen)
	}
}
   365  
   366  // numPrepares assumes that db has exactly 1 idle conn and returns
   367  // its count of calls to Prepare
   368  func numPrepares(t *testing.T, db *DB) int {
   369  	if n := len(db.freeConn); n != 1 {
   370  		t.Fatalf("free conns = %d; want 1", n)
   371  	}
   372  	return getFakeConn(db.freeConn[0].ci).numPrepare
   373  }
   374  
   375  func (db *DB) numDeps() int {
   376  	synctest.Wait()
   377  	db.mu.Lock()
   378  	defer db.mu.Unlock()
   379  	return len(db.dep)
   380  }
   381  
   382  func (db *DB) numFreeConns() int {
   383  	synctest.Wait()
   384  	db.mu.Lock()
   385  	defer db.mu.Unlock()
   386  	return len(db.freeConn)
   387  }
   388  
   389  // clearAllConns closes all connections in db.
   390  func (db *DB) clearAllConns(t *testing.T) {
   391  	db.SetMaxIdleConns(0)
   392  
   393  	if g, w := db.numFreeConns(), 0; g != w {
   394  		t.Errorf("free conns = %d; want %d", g, w)
   395  	}
   396  
   397  	if n := db.numDeps(); n > 0 {
   398  		t.Errorf("number of dependencies = %d; expected 0", n)
   399  		db.dumpDeps(t)
   400  	}
   401  }
   402  
   403  func (db *DB) dumpDeps(t *testing.T) {
   404  	for fc := range db.dep {
   405  		db.dumpDep(t, 0, fc, map[finalCloser]bool{})
   406  	}
   407  }
   408  
   409  func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) {
   410  	seen[dep] = true
   411  	indent := strings.Repeat("  ", depth)
   412  	ds := db.dep[dep]
   413  	for k := range ds {
   414  		t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k)
   415  		if fc, ok := k.(finalCloser); ok {
   416  			if !seen[fc] {
   417  				db.dumpDep(t, depth+1, fc, seen)
   418  			}
   419  		}
   420  	}
   421  }
   422  
   423  func TestQuery(t *testing.T) {
   424  	testDatabase(t, testQuery)
   425  }
   426  func testQuery(t *testing.T, db *DB) {
   427  	populate(t, db, "people")
   428  	prepares0 := numPrepares(t, db)
   429  	rows, err := db.Query("SELECT|people|age,name|")
   430  	if err != nil {
   431  		t.Fatalf("Query: %v", err)
   432  	}
   433  	defer rows.Close()
   434  	type row struct {
   435  		age  int
   436  		name string
   437  	}
   438  	got := []row{}
   439  	for rows.Next() {
   440  		var r row
   441  		err = rows.Scan(&r.age, &r.name)
   442  		if err != nil {
   443  			t.Fatalf("Scan: %v", err)
   444  		}
   445  		got = append(got, r)
   446  	}
   447  	err = rows.Err()
   448  	if err != nil {
   449  		t.Fatalf("Err: %v", err)
   450  	}
   451  	want := []row{
   452  		{age: 1, name: "Alice"},
   453  		{age: 2, name: "Bob"},
   454  		{age: 3, name: "Chris"},
   455  	}
   456  	if !slices.Equal(got, want) {
   457  		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
   458  	}
   459  
   460  	// And verify that the final rows.Next() call, which hit EOF,
   461  	// also closed the rows connection.
   462  	if n := db.numFreeConns(); n != 1 {
   463  		t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
   464  	}
   465  	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
   466  		t.Errorf("executed %d Prepare statements; want 1", prepares)
   467  	}
   468  }
   469  
// TestQueryContext tests canceling the context while scanning the rows.
func TestQueryContext(t *testing.T) {
	testDatabase(t, testQueryContext)
}

// testQueryContext reads the "people" fixture through a cancelable context
// and cancels it while positioned on the third row. It expects:
//   - the first two rows to scan normally;
//   - the Scan of the third row to fail;
//   - the Rows to be closed;
//   - the connection to be back in the pool;
//   - exactly one statement to have been prepared.
func testQueryContext(t *testing.T, db *DB) {
	populate(t, db, "people")
	prepares0 := numPrepares(t, db)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
	if err != nil {
		t.Fatalf("Query: %v", err)
	}
	type row struct {
		age  int
		name string
	}
	got := []row{}
	index := 0
	for rows.Next() {
		if index == 2 {
			// Cancel while positioned on the third row, and wait until the
			// cancellation has closed rows.
			cancel()
			waitForRowsClose(t, rows)
		}
		var r row
		err = rows.Scan(&r.age, &r.name)
		if err != nil {
			// A Scan failure is expected only for the canceled row.
			if index == 2 {
				break
			}
			t.Fatalf("Scan: %v", err)
		}
		// Reaching here means Scan succeeded (err == nil); on the canceled
		// row that is itself a failure.
		if index == 2 && err != context.Canceled {
			t.Fatalf("Scan: %v; want context.Canceled", err)
		}
		got = append(got, r)
		index++
	}
	select {
	case <-ctx.Done():
		if err := ctx.Err(); err != context.Canceled {
			t.Fatalf("context err = %v; want context.Canceled", err)
		}
	default:
		t.Fatalf("context err = nil; want context.Canceled")
	}
	want := []row{
		{age: 1, name: "Alice"},
		{age: 2, name: "Bob"},
	}
	if !slices.Equal(got, want) {
		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
	}

	// The cancellation (not EOF) closed the rows; verify that they are
	// closed and that their connection was returned to the pool.
	waitForRowsClose(t, rows)
	waitForFree(t, db, 1)
	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
		t.Errorf("executed %d Prepare statements; want 1", prepares)
	}
}
   534  
   535  // waitForFree checks db.numFreeConns until either it equals want or
   536  // the maxWait time elapses.
   537  func waitForFree(t *testing.T, db *DB, want int) {
   538  	synctest.Wait()
   539  	numFree := db.numFreeConns()
   540  	if numFree != want {
   541  		t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want)
   542  	}
   543  }
   544  
   545  func waitForRowsClose(t *testing.T, rows *Rows) {
   546  	synctest.Wait()
   547  	rows.closemu.RLock()
   548  	defer rows.closemu.RUnlock()
   549  	if !rows.closed {
   550  		t.Fatal("failed to close rows")
   551  	}
   552  }
   553  
   554  // TestQueryContextWait ensures that rows and all internal statements are closed when
   555  // a query context is closed during execution.
   556  func TestQueryContextWait(t *testing.T) {
   557  	testDatabase(t, testQueryContextWait)
   558  }
   559  func testQueryContextWait(t *testing.T, db *DB) {
   560  	populate(t, db, "people")
   561  	prepares0 := numPrepares(t, db)
   562  
   563  	ctx, cancel := context.WithCancel(context.Background())
   564  	defer cancel()
   565  
   566  	// This will trigger the *fakeConn.Prepare method which will take time
   567  	// performing the query. The ctxDriverPrepare func will check the context
   568  	// after this and close the rows and return an error.
   569  	c, err := db.Conn(ctx)
   570  	if err != nil {
   571  		t.Fatal(err)
   572  	}
   573  
   574  	getFakeConn(c.dc.ci).waiter = func(c context.Context) {
   575  		cancel()
   576  		<-ctx.Done()
   577  	}
   578  	_, err = c.QueryContext(ctx, "SELECT|people|age,name|")
   579  	c.Close()
   580  	if err != context.Canceled {
   581  		t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
   582  	}
   583  
   584  	// Verify closed rows connection after error condition.
   585  	waitForFree(t, db, 1)
   586  	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
   587  		t.Fatalf("executed %d Prepare statements; want 1", prepares)
   588  	}
   589  }
   590  
   591  // TestTxContextWait tests the transaction behavior when the tx context is canceled
   592  // during execution of the query.
   593  func TestTxContextWait(t *testing.T) {
   594  	testDatabase(t, func(t *testing.T, db *DB) {
   595  		testContextWait(t, false, db)
   596  	})
   597  }
   598  
   599  // TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard
   600  // the final connection.
   601  func TestTxContextWaitNoDiscard(t *testing.T) {
   602  	testDatabase(t, func(t *testing.T, db *DB) {
   603  		testContextWait(t, true, db)
   604  	})
   605  }
   606  
   607  func testContextWait(t *testing.T, keepConnOnRollback bool, db *DB) {
   608  	populate(t, db, "people")
   609  
   610  	ctx, cancel := context.WithCancel(context.Background())
   611  
   612  	tx, err := db.BeginTx(ctx, nil)
   613  	if err != nil {
   614  		t.Fatal(err)
   615  	}
   616  	tx.keepConnOnRollback = keepConnOnRollback
   617  
   618  	getFakeConn(tx.dc.ci).waiter = func(c context.Context) {
   619  		cancel()
   620  		<-ctx.Done()
   621  	}
   622  	// This will trigger the *fakeConn.Prepare method which will take time
   623  	// performing the query. The ctxDriverPrepare func will check the context
   624  	// after this and close the rows and return an error.
   625  	_, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
   626  	if err != context.Canceled {
   627  		t.Fatalf("expected QueryContext to error with context canceled but returned %v", err)
   628  	}
   629  
   630  	if keepConnOnRollback {
   631  		waitForFree(t, db, 1)
   632  	} else {
   633  		waitForFree(t, db, 0)
   634  	}
   635  }
   636  
   637  // TestUnsupportedOptions checks that the database fails when a driver that
   638  // doesn't implement ConnBeginTx is used with non-default options and an
   639  // un-cancellable context.
   640  func TestUnsupportedOptions(t *testing.T) {
   641  	testDatabase(t, testUnsupportedOptions)
   642  }
   643  func testUnsupportedOptions(t *testing.T, db *DB) {
   644  	populate(t, db, "people")
   645  	_, err := db.BeginTx(context.Background(), &TxOptions{
   646  		Isolation: LevelSerializable, ReadOnly: true,
   647  	})
   648  	if err == nil {
   649  		t.Fatal("expected error when using unsupported options, got nil")
   650  	}
   651  }
   652  
   653  func TestMultiResultSetQuery(t *testing.T) {
   654  	testDatabase(t, testMultiResultSetQuery)
   655  }
   656  func testMultiResultSetQuery(t *testing.T, db *DB) {
   657  	populate(t, db, "people")
   658  	prepares0 := numPrepares(t, db)
   659  	rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
   660  	if err != nil {
   661  		t.Fatalf("Query: %v", err)
   662  	}
   663  	type row1 struct {
   664  		age  int
   665  		name string
   666  	}
   667  	type row2 struct {
   668  		name string
   669  	}
   670  	got1 := []row1{}
   671  	for rows.Next() {
   672  		var r row1
   673  		err = rows.Scan(&r.age, &r.name)
   674  		if err != nil {
   675  			t.Fatalf("Scan: %v", err)
   676  		}
   677  		got1 = append(got1, r)
   678  	}
   679  	err = rows.Err()
   680  	if err != nil {
   681  		t.Fatalf("Err: %v", err)
   682  	}
   683  	want1 := []row1{
   684  		{age: 1, name: "Alice"},
   685  		{age: 2, name: "Bob"},
   686  		{age: 3, name: "Chris"},
   687  	}
   688  	if !slices.Equal(got1, want1) {
   689  		t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
   690  	}
   691  
   692  	if !rows.NextResultSet() {
   693  		t.Errorf("expected another result set")
   694  	}
   695  
   696  	got2 := []row2{}
   697  	for rows.Next() {
   698  		var r row2
   699  		err = rows.Scan(&r.name)
   700  		if err != nil {
   701  			t.Fatalf("Scan: %v", err)
   702  		}
   703  		got2 = append(got2, r)
   704  	}
   705  	err = rows.Err()
   706  	if err != nil {
   707  		t.Fatalf("Err: %v", err)
   708  	}
   709  	want2 := []row2{
   710  		{name: "Alice"},
   711  		{name: "Bob"},
   712  		{name: "Chris"},
   713  	}
   714  	if !slices.Equal(got2, want2) {
   715  		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
   716  	}
   717  	if rows.NextResultSet() {
   718  		t.Errorf("expected no more result sets")
   719  	}
   720  
   721  	// And verify that the final rows.Next() call, which hit EOF,
   722  	// also closed the rows connection.
   723  	waitForFree(t, db, 1)
   724  	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
   725  		t.Errorf("executed %d Prepare statements; want 1", prepares)
   726  	}
   727  }
   728  
   729  func TestQueryNamedArg(t *testing.T) {
   730  	testDatabase(t, testQueryNamedArg, requireFeature("NamedValue"))
   731  }
   732  func testQueryNamedArg(t *testing.T, db *DB) {
   733  	populate(t, db, "people")
   734  	prepares0 := numPrepares(t, db)
   735  	rows, err := db.Query(
   736  		// Ensure the name and age parameters only match on placeholder name, not position.
   737  		"SELECT|people|age,name|name=?name,age=?age",
   738  		Named("age", 2),
   739  		Named("name", "Bob"),
   740  	)
   741  	if err != nil {
   742  		t.Fatalf("Query: %v", err)
   743  	}
   744  	type row struct {
   745  		age  int
   746  		name string
   747  	}
   748  	got := []row{}
   749  	for rows.Next() {
   750  		var r row
   751  		err = rows.Scan(&r.age, &r.name)
   752  		if err != nil {
   753  			t.Fatalf("Scan: %v", err)
   754  		}
   755  		got = append(got, r)
   756  	}
   757  	err = rows.Err()
   758  	if err != nil {
   759  		t.Fatalf("Err: %v", err)
   760  	}
   761  	want := []row{
   762  		{age: 2, name: "Bob"},
   763  	}
   764  	if !slices.Equal(got, want) {
   765  		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
   766  	}
   767  
   768  	// And verify that the final rows.Next() call, which hit EOF,
   769  	// also closed the rows connection.
   770  	if n := db.numFreeConns(); n != 1 {
   771  		t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
   772  	}
   773  	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
   774  		t.Errorf("executed %d Prepare statements; want 1", prepares)
   775  	}
   776  }
   777  
   778  func TestPoolExhaustOnCancel(t *testing.T) {
   779  	synctest.Test(t, testPoolExhaustOnCancel)
   780  }
   781  func testPoolExhaustOnCancel(t *testing.T) {
   782  	max := 3
   783  	var saturate, saturateDone sync.WaitGroup
   784  	saturate.Add(max)
   785  	saturateDone.Add(max)
   786  
   787  	donePing := make(chan bool)
   788  	state := 0
   789  
   790  	// waiter will be called for all queries, including
   791  	// initial setup queries. The state is only assigned when
   792  	// no queries are made.
   793  	//
   794  	// Only allow the first batch of queries to finish once the
   795  	// second batch of Ping queries have finished.
   796  	waiter := func(ctx context.Context) {
   797  		switch state {
   798  		case 0:
   799  			// Nothing. Initial database setup.
   800  		case 1:
   801  			saturate.Done()
   802  			select {
   803  			case <-ctx.Done():
   804  			case <-donePing:
   805  			}
   806  		case 2:
   807  		}
   808  	}
   809  	db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")
   810  
   811  	db.SetMaxOpenConns(max)
   812  
   813  	// First saturate the connection pool.
   814  	// Then start new requests for a connection that is canceled after it is requested.
   815  
   816  	state = 1
   817  	for i := 0; i < max; i++ {
   818  		go func() {
   819  			rows, err := db.Query("SELECT|people|name,photo|")
   820  			if err != nil {
   821  				t.Errorf("Query: %v", err)
   822  				return
   823  			}
   824  			rows.Close()
   825  			saturateDone.Done()
   826  		}()
   827  	}
   828  
   829  	saturate.Wait()
   830  	if t.Failed() {
   831  		t.FailNow()
   832  	}
   833  	state = 2
   834  
   835  	// Now cancel the request while it is waiting.
   836  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   837  	defer cancel()
   838  
   839  	for i := 0; i < max; i++ {
   840  		ctxReq, cancelReq := context.WithCancel(ctx)
   841  		go func() {
   842  			time.Sleep(100 * time.Millisecond)
   843  			cancelReq()
   844  		}()
   845  		err := db.PingContext(ctxReq)
   846  		if err != context.Canceled {
   847  			t.Fatalf("PingContext (Exhaust): %v", err)
   848  		}
   849  	}
   850  	close(donePing)
   851  	saturateDone.Wait()
   852  
   853  	// Now try to open a normal connection.
   854  	err := db.PingContext(ctx)
   855  	if err != nil {
   856  		t.Fatalf("PingContext (Normal): %v", err)
   857  	}
   858  }
   859  
   860  func TestRowsColumns(t *testing.T) {
   861  	testDatabase(t, testRowsColumns)
   862  }
   863  func testRowsColumns(t *testing.T, db *DB) {
   864  	populate(t, db, "people")
   865  	rows, err := db.Query("SELECT|people|age,name|")
   866  	if err != nil {
   867  		t.Fatalf("Query: %v", err)
   868  	}
   869  	cols, err := rows.Columns()
   870  	if err != nil {
   871  		t.Fatalf("Columns: %v", err)
   872  	}
   873  	want := []string{"age", "name"}
   874  	if !slices.Equal(cols, want) {
   875  		t.Errorf("got %#v; want %#v", cols, want)
   876  	}
   877  	if err := rows.Close(); err != nil {
   878  		t.Errorf("error closing rows: %s", err)
   879  	}
   880  }
   881  
   882  func TestRowsColumnTypes(t *testing.T) {
   883  	testDatabase(t, testRowsColumnTypes)
   884  }
   885  func testRowsColumnTypes(t *testing.T, db *DB) {
   886  	populate(t, db, "people")
   887  	rows, err := db.Query("SELECT|people|age,name|")
   888  	if err != nil {
   889  		t.Fatalf("Query: %v", err)
   890  	}
   891  	tt, err := rows.ColumnTypes()
   892  	if err != nil {
   893  		t.Fatalf("ColumnTypes: %v", err)
   894  	}
   895  
   896  	types := make([]reflect.Type, len(tt))
   897  	for i, tp := range tt {
   898  		st := tp.ScanType()
   899  		if st == nil {
   900  			t.Errorf("scantype is null for column %q", tp.Name())
   901  			continue
   902  		}
   903  		types[i] = st
   904  	}
   905  	values := make([]any, len(tt))
   906  	for i := range values {
   907  		values[i] = reflect.New(types[i]).Interface()
   908  	}
   909  	ct := 0
   910  	for rows.Next() {
   911  		err = rows.Scan(values...)
   912  		if err != nil {
   913  			t.Fatalf("failed to scan values in %v", err)
   914  		}
   915  		if ct == 1 {
   916  			if age := *values[0].(*int32); age != 2 {
   917  				t.Errorf("Expected 2, got %v", age)
   918  			}
   919  			if name := *values[1].(*string); name != "Bob" {
   920  				t.Errorf("Expected Bob, got %v", name)
   921  			}
   922  		}
   923  		ct++
   924  	}
   925  	if ct != 3 {
   926  		t.Errorf("expected 3 rows, got %d", ct)
   927  	}
   928  
   929  	if err := rows.Close(); err != nil {
   930  		t.Errorf("error closing rows: %s", err)
   931  	}
   932  }
   933  
   934  func TestQueryRow(t *testing.T) {
   935  	testDatabase(t, testQueryRow)
   936  }
   937  func testQueryRow(t *testing.T, db *DB) {
   938  	populate(t, db, "people")
   939  	var name string
   940  	var age int
   941  	var birthday time.Time
   942  
   943  	err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
   944  	if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
   945  		t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
   946  	}
   947  
   948  	err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
   949  	if err != nil || !birthday.Equal(chrisBirthday) {
   950  		t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
   951  	}
   952  
   953  	err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
   954  	if err != nil {
   955  		t.Fatalf("age QueryRow+Scan: %v", err)
   956  	}
   957  	if name != "Bob" {
   958  		t.Errorf("expected name Bob, got %q", name)
   959  	}
   960  	if age != 2 {
   961  		t.Errorf("expected age 2, got %d", age)
   962  	}
   963  
   964  	err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name)
   965  	if err != nil {
   966  		t.Fatalf("name QueryRow+Scan: %v", err)
   967  	}
   968  	if name != "Alice" {
   969  		t.Errorf("expected name Alice, got %q", name)
   970  	}
   971  	if age != 1 {
   972  		t.Errorf("expected age 1, got %d", age)
   973  	}
   974  
   975  	var photo []byte
   976  	err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
   977  	if err != nil {
   978  		t.Fatalf("photo QueryRow+Scan: %v", err)
   979  	}
   980  	want := []byte("APHOTO")
   981  	if !slices.Equal(photo, want) {
   982  		t.Errorf("photo = %q; want %q", photo, want)
   983  	}
   984  }
   985  
   986  func TestRowErr(t *testing.T) {
   987  	testDatabase(t, testRowErr)
   988  }
   989  func testRowErr(t *testing.T, db *DB) {
   990  	populate(t, db, "people")
   991  
   992  	row := db.QueryRowContext(context.Background(), "SELECT|people|bdate|age=?", 3)
   993  	if err := row.Err(); err != nil {
   994  		t.Errorf("Unexpected err = %v; want %v", err, nil)
   995  	}
   996  	row.Scan()
   997  
   998  	ctx, cancel := context.WithCancel(context.Background())
   999  	cancel()
  1000  
  1001  	err := db.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 3).Err()
  1002  	exp := "context canceled"
  1003  	if err == nil || !strings.Contains(err.Error(), exp) {
  1004  		t.Errorf("Expected err = %v; got %v", exp, err)
  1005  	}
  1006  }
  1007  
  1008  func TestTxRollbackCommitErr(t *testing.T) {
  1009  	testDatabase(t, testTxRollbackCommitErr)
  1010  }
  1011  func testTxRollbackCommitErr(t *testing.T, db *DB) {
  1012  	populate(t, db, "people")
  1013  
  1014  	tx, err := db.Begin()
  1015  	if err != nil {
  1016  		t.Fatal(err)
  1017  	}
  1018  	err = tx.Rollback()
  1019  	if err != nil {
  1020  		t.Errorf("expected nil error from Rollback; got %v", err)
  1021  	}
  1022  	err = tx.Commit()
  1023  	if err != ErrTxDone {
  1024  		t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
  1025  	}
  1026  
  1027  	tx, err = db.Begin()
  1028  	if err != nil {
  1029  		t.Fatal(err)
  1030  	}
  1031  	err = tx.Commit()
  1032  	if err != nil {
  1033  		t.Errorf("expected nil error from Commit; got %v", err)
  1034  	}
  1035  	err = tx.Rollback()
  1036  	if err != ErrTxDone {
  1037  		t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
  1038  	}
  1039  }
  1040  
  1041  func TestStatementErrorAfterClose(t *testing.T) {
  1042  	testDatabase(t, testStatementErrorAfterClose)
  1043  }
  1044  func testStatementErrorAfterClose(t *testing.T, db *DB) {
  1045  	populate(t, db, "people")
  1046  	stmt, err := db.Prepare("SELECT|people|age|name=?")
  1047  	if err != nil {
  1048  		t.Fatalf("Prepare: %v", err)
  1049  	}
  1050  	err = stmt.Close()
  1051  	if err != nil {
  1052  		t.Fatalf("Close: %v", err)
  1053  	}
  1054  	var name string
  1055  	err = stmt.QueryRow("foo").Scan(&name)
  1056  	if err == nil {
  1057  		t.Errorf("expected error from QueryRow.Scan after Stmt.Close")
  1058  	}
  1059  }
  1060  
  1061  func TestStatementQueryRow(t *testing.T) {
  1062  	testDatabase(t, testStatementQueryRow)
  1063  }
  1064  func testStatementQueryRow(t *testing.T, db *DB) {
  1065  	populate(t, db, "people")
  1066  	stmt, err := db.Prepare("SELECT|people|age|name=?")
  1067  	if err != nil {
  1068  		t.Fatalf("Prepare: %v", err)
  1069  	}
  1070  	defer stmt.Close()
  1071  	var age int
  1072  	for n, tt := range []struct {
  1073  		name string
  1074  		want int
  1075  	}{
  1076  		{"Alice", 1},
  1077  		{"Bob", 2},
  1078  		{"Chris", 3},
  1079  	} {
  1080  		if err := stmt.QueryRow(tt.name).Scan(&age); err != nil {
  1081  			t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err)
  1082  		} else if age != tt.want {
  1083  			t.Errorf("%d: age=%d, want %d", n, age, tt.want)
  1084  		}
  1085  	}
  1086  }
  1087  
  1088  type stubDriverStmt struct {
  1089  	err error
  1090  }
  1091  
  1092  func (s stubDriverStmt) Close() error {
  1093  	return s.err
  1094  }
  1095  
  1096  func (s stubDriverStmt) NumInput() int {
  1097  	return -1
  1098  }
  1099  
  1100  func (s stubDriverStmt) Exec(args []driver.Value) (driver.Result, error) {
  1101  	return nil, nil
  1102  }
  1103  
  1104  func (s stubDriverStmt) Query(args []driver.Value) (driver.Rows, error) {
  1105  	return nil, nil
  1106  }
  1107  
  1108  // golang.org/issue/12798
  1109  func TestStatementClose(t *testing.T) {
  1110  	synctest.Test(t, testStatementClose)
  1111  }
  1112  func testStatementClose(t *testing.T) {
  1113  	want := errors.New("STMT ERROR")
  1114  
  1115  	tests := []struct {
  1116  		stmt *Stmt
  1117  		msg  string
  1118  	}{
  1119  		{&Stmt{stickyErr: want}, "stickyErr not propagated"},
  1120  		{&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
  1121  	}
  1122  	for _, test := range tests {
  1123  		if err := test.stmt.Close(); err != want {
  1124  			t.Errorf("%s. Got stmt.Close() = %v, want = %v", test.msg, err, want)
  1125  		}
  1126  	}
  1127  }
  1128  
  1129  // golang.org/issue/3734
  1130  func TestStatementQueryRowConcurrent(t *testing.T) {
  1131  	testDatabase(t, testStatementQueryRowConcurrent)
  1132  }
  1133  func testStatementQueryRowConcurrent(t *testing.T, db *DB) {
  1134  	populate(t, db, "people")
  1135  	stmt, err := db.Prepare("SELECT|people|age|name=?")
  1136  	if err != nil {
  1137  		t.Fatalf("Prepare: %v", err)
  1138  	}
  1139  	defer stmt.Close()
  1140  
  1141  	const n = 10
  1142  	ch := make(chan error, n)
  1143  	for i := 0; i < n; i++ {
  1144  		go func() {
  1145  			var age int
  1146  			err := stmt.QueryRow("Alice").Scan(&age)
  1147  			if err == nil && age != 1 {
  1148  				err = fmt.Errorf("unexpected age %d", age)
  1149  			}
  1150  			ch <- err
  1151  		}()
  1152  	}
  1153  	for i := 0; i < n; i++ {
  1154  		if err := <-ch; err != nil {
  1155  			t.Error(err)
  1156  		}
  1157  	}
  1158  }
  1159  
  1160  // just a test of fakedb itself
  1161  func TestBogusPreboundParameters(t *testing.T) {
  1162  	testDatabase(t, testBogusPreboundParameters)
  1163  }
  1164  func testBogusPreboundParameters(t *testing.T, db *DB) {
  1165  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1166  	_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
  1167  	if err == nil {
  1168  		t.Fatalf("expected error")
  1169  	}
  1170  	if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` {
  1171  		t.Errorf("unexpected error: %v", err)
  1172  	}
  1173  }
  1174  
  1175  func TestExec(t *testing.T) {
  1176  	testDatabase(t, testExec)
  1177  }
  1178  func testExec(t *testing.T, db *DB) {
  1179  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1180  	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
  1181  	if err != nil {
  1182  		t.Errorf("Stmt, err = %v, %v", stmt, err)
  1183  	}
  1184  	defer stmt.Close()
  1185  
  1186  	type execTest struct {
  1187  		args    []any
  1188  		wantErr string
  1189  	}
  1190  	execTests := []execTest{
  1191  		// Okay:
  1192  		{[]any{"Brad", 31}, ""},
  1193  		{[]any{"Brad", int64(31)}, ""},
  1194  		{[]any{"Bob", "32"}, ""},
  1195  		{[]any{7, 9}, ""},
  1196  
  1197  		// Invalid conversions:
  1198  		{[]any{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
  1199  		{[]any{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
  1200  
  1201  		// Wrong number of args:
  1202  		{[]any{}, "sql: expected 2 arguments, got 0"},
  1203  		{[]any{1, 2, 3}, "sql: expected 2 arguments, got 3"},
  1204  	}
  1205  	for n, et := range execTests {
  1206  		_, err := stmt.Exec(et.args...)
  1207  		errStr := ""
  1208  		if err != nil {
  1209  			errStr = err.Error()
  1210  		}
  1211  		if errStr != et.wantErr {
  1212  			t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q",
  1213  				n, et.args, errStr, et.wantErr)
  1214  		}
  1215  	}
  1216  }
  1217  
  1218  func TestTxPrepare(t *testing.T) {
  1219  	testDatabase(t, testTxPrepare)
  1220  }
  1221  func testTxPrepare(t *testing.T, db *DB) {
  1222  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1223  	tx, err := db.Begin()
  1224  	if err != nil {
  1225  		t.Fatalf("Begin = %v", err)
  1226  	}
  1227  	stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
  1228  	if err != nil {
  1229  		t.Fatalf("Stmt, err = %v, %v", stmt, err)
  1230  	}
  1231  	defer stmt.Close()
  1232  	_, err = stmt.Exec("Bobby", 7)
  1233  	if err != nil {
  1234  		t.Fatalf("Exec = %v", err)
  1235  	}
  1236  	err = tx.Commit()
  1237  	if err != nil {
  1238  		t.Fatalf("Commit = %v", err)
  1239  	}
  1240  	// Commit() should have closed the statement
  1241  	if !stmt.closed {
  1242  		t.Fatal("Stmt not closed after Commit")
  1243  	}
  1244  }
  1245  
  1246  func TestTxStmt(t *testing.T) {
  1247  	testDatabase(t, testTxStmt)
  1248  }
  1249  func testTxStmt(t *testing.T, db *DB) {
  1250  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1251  	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
  1252  	if err != nil {
  1253  		t.Fatalf("Stmt, err = %v, %v", stmt, err)
  1254  	}
  1255  	defer stmt.Close()
  1256  	tx, err := db.Begin()
  1257  	if err != nil {
  1258  		t.Fatalf("Begin = %v", err)
  1259  	}
  1260  	txs := tx.Stmt(stmt)
  1261  	defer txs.Close()
  1262  	_, err = txs.Exec("Bobby", 7)
  1263  	if err != nil {
  1264  		t.Fatalf("Exec = %v", err)
  1265  	}
  1266  	err = tx.Commit()
  1267  	if err != nil {
  1268  		t.Fatalf("Commit = %v", err)
  1269  	}
  1270  	// Commit() should have closed the statement
  1271  	if !txs.closed {
  1272  		t.Fatal("Stmt not closed after Commit")
  1273  	}
  1274  }
  1275  
// TestTxStmtPreparedOnce runs testTxStmtPreparedOnce under each driver
// variant provided by testDatabase.
func TestTxStmtPreparedOnce(t *testing.T) {
	testDatabase(t, testTxStmtPreparedOnce)
}
// testTxStmtPreparedOnce calls tx.Stmt twice on the same DB-level statement
// and executes both results. It checks that this costs no extra Prepare: the
// only Prepare counted by numPrepares is the initial db.Prepare.
func testTxStmtPreparedOnce(t *testing.T, db *DB) {
	exec(t, db, "CREATE|t1|name=string,age=int32")

	// Baseline taken after setup, so only the calls below are counted.
	prepares0 := numPrepares(t, db)

	// db.Prepare increments numPrepares.
	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
	if err != nil {
		t.Fatalf("Stmt, err = %v, %v", stmt, err)
	}
	defer stmt.Close()

	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("Begin = %v", err)
	}

	// Two transaction-specific statements derived from the same parent.
	txs1 := tx.Stmt(stmt)
	txs2 := tx.Stmt(stmt)

	_, err = txs1.Exec("Go", 7)
	if err != nil {
		t.Fatalf("Exec = %v", err)
	}
	txs1.Close()

	// txs2 must still work after its sibling txs1 has been closed.
	_, err = txs2.Exec("Gopher", 8)
	if err != nil {
		t.Fatalf("Exec = %v", err)
	}
	txs2.Close()

	err = tx.Commit()
	if err != nil {
		t.Fatalf("Commit = %v", err)
	}

	// Neither tx.Stmt call may have triggered a Prepare of its own.
	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
		t.Errorf("executed %d Prepare statements; want 1", prepares)
	}
}
  1320  
// TestTxStmtClosedRePrepares runs testTxStmtClosedRePrepares under each
// driver variant provided by testDatabase.
func TestTxStmtClosedRePrepares(t *testing.T) {
	testDatabase(t, testTxStmtClosedRePrepares)
}
// testTxStmtClosedRePrepares checks that tx.Stmt, given an already-closed
// statement, prepares the query again (a second counted Prepare). The result
// must be a stand-alone statement: no sticky error and no parentStmt link.
func testTxStmtClosedRePrepares(t *testing.T, db *DB) {
	exec(t, db, "CREATE|t1|name=string,age=int32")

	prepares0 := numPrepares(t, db)

	// db.Prepare increments numPrepares.
	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
	if err != nil {
		t.Fatalf("Stmt, err = %v, %v", stmt, err)
	}
	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("Begin = %v", err)
	}
	// Close the parent before handing it to tx.Stmt.
	err = stmt.Close()
	if err != nil {
		t.Fatalf("stmt.Close() = %v", err)
	}
	// tx.Stmt increments numPrepares because stmt is closed.
	txs := tx.Stmt(stmt)
	if txs.stickyErr != nil {
		t.Fatal(txs.stickyErr)
	}
	if txs.parentStmt != nil {
		t.Fatal("expected nil parentStmt")
	}
	_, err = txs.Exec(`Eric`, 82)
	if err != nil {
		t.Fatalf("txs.Exec = %v", err)
	}

	err = txs.Close()
	if err != nil {
		t.Fatalf("txs.Close = %v", err)
	}

	// Rollback's error is not checked; only the Prepare count is asserted.
	tx.Rollback()

	if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
		t.Errorf("executed %d Prepare statements; want 2", prepares)
	}
}
  1366  
// TestParentStmtOutlivesTxStmt runs testParentStmtOutlivesTxStmt under each
// driver variant provided by testDatabase.
func TestParentStmtOutlivesTxStmt(t *testing.T) {
	testDatabase(t, testParentStmtOutlivesTxStmt)
}
// testParentStmtOutlivesTxStmt closes a transaction statement derived from
// stmt and rolls back its transaction. That must invalidate only the
// transaction statement: stmt itself keeps working, and no extra Prepare
// is needed.
func testParentStmtOutlivesTxStmt(t *testing.T, db *DB) {
	exec(t, db, "CREATE|t1|name=string,age=int32")

	// Make sure everything happens on the same connection.
	db.SetMaxOpenConns(1)

	prepares0 := numPrepares(t, db)

	// db.Prepare increments numPrepares.
	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
	if err != nil {
		t.Fatalf("Stmt, err = %v, %v", stmt, err)
	}
	defer stmt.Close()
	tx, err := db.Begin()
	if err != nil {
		t.Fatalf("Begin = %v", err)
	}
	txs := tx.Stmt(stmt)
	// NOTE(review): stmt.css presumably holds stmt's per-connection driver
	// statements; with a single connection, exactly one is expected here.
	if len(stmt.css) != 1 {
		t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
	}
	err = txs.Close()
	if err != nil {
		t.Fatalf("txs.Close() = %v", err)
	}
	err = tx.Rollback()
	if err != nil {
		t.Fatalf("tx.Rollback() = %v", err)
	}
	// txs must not be valid.
	_, err = txs.Exec("Suzan", 30)
	if err == nil {
		t.Fatalf("txs.Exec(), expected err")
	}
	// Stmt must still be valid.
	_, err = stmt.Exec("Janina", 25)
	if err != nil {
		t.Fatalf("stmt.Exec() = %v", err)
	}

	// Only the original db.Prepare is counted: tx.Stmt did not re-prepare.
	if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
		t.Errorf("executed %d Prepare statements; want 1", prepares)
	}
}
  1415  
  1416  // Test that tx.Stmt called with a statement already
  1417  // associated with tx as argument re-prepares the same
  1418  // statement again.
  1419  func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
  1420  	testDatabase(t, testTxStmtFromTxStmtRePrepares)
  1421  }
  1422  func testTxStmtFromTxStmtRePrepares(t *testing.T, db *DB) {
  1423  	exec(t, db, "CREATE|t1|name=string,age=int32")
  1424  	prepares0 := numPrepares(t, db)
  1425  	// db.Prepare increments numPrepares.
  1426  	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
  1427  	if err != nil {
  1428  		t.Fatalf("Stmt, err = %v, %v", stmt, err)
  1429  	}
  1430  	defer stmt.Close()
  1431  
  1432  	tx, err := db.Begin()
  1433  	if err != nil {
  1434  		t.Fatalf("Begin = %v", err)
  1435  	}
  1436  	txs1 := tx.Stmt(stmt)
  1437  
  1438  	// tx.Stmt(txs1) increments numPrepares because txs1 already
  1439  	// belongs to a transaction (albeit the same transaction).
  1440  	txs2 := tx.Stmt(txs1)
  1441  	if txs2.stickyErr != nil {
  1442  		t.Fatal(txs2.stickyErr)
  1443  	}
  1444  	if txs2.parentStmt != nil {
  1445  		t.Fatal("expected nil parentStmt")
  1446  	}
  1447  	_, err = txs2.Exec(`Eric`, 82)
  1448  	if err != nil {
  1449  		t.Fatal(err)
  1450  	}
  1451  
  1452  	err = txs1.Close()
  1453  	if err != nil {
  1454  		t.Fatalf("txs1.Close = %v", err)
  1455  	}
  1456  	err = txs2.Close()
  1457  	if err != nil {
  1458  		t.Fatalf("txs1.Close = %v", err)
  1459  	}
  1460  	err = tx.Rollback()
  1461  	if err != nil {
  1462  		t.Fatalf("tx.Rollback = %v", err)
  1463  	}
  1464  
  1465  	if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
  1466  		t.Errorf("executed %d Prepare statements; want 2", prepares)
  1467  	}
  1468  }
  1469  
  1470  // Issue: https://golang.org/issue/2784
  1471  // This test didn't fail before because we got lucky with the fakedb driver.
  1472  // It was failing, and now not, in github.com/bradfitz/go-sql-test
  1473  func TestTxQuery(t *testing.T) {
  1474  	testDatabase(t, testTxQuery)
  1475  }
  1476  func testTxQuery(t *testing.T, db *DB) {
  1477  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1478  	exec(t, db, "INSERT|t1|name=Alice")
  1479  
  1480  	tx, err := db.Begin()
  1481  	if err != nil {
  1482  		t.Fatal(err)
  1483  	}
  1484  	defer tx.Rollback()
  1485  
  1486  	r, err := tx.Query("SELECT|t1|name|")
  1487  	if err != nil {
  1488  		t.Fatal(err)
  1489  	}
  1490  	defer r.Close()
  1491  
  1492  	if !r.Next() {
  1493  		if r.Err() != nil {
  1494  			t.Fatal(r.Err())
  1495  		}
  1496  		t.Fatal("expected one row")
  1497  	}
  1498  
  1499  	var x string
  1500  	err = r.Scan(&x)
  1501  	if err != nil {
  1502  		t.Fatal(err)
  1503  	}
  1504  }
  1505  
  1506  func TestTxQueryInvalid(t *testing.T) {
  1507  	testDatabase(t, testTxQueryInvalid)
  1508  }
  1509  func testTxQueryInvalid(t *testing.T, db *DB) {
  1510  	defer closeDB(t, db)
  1511  
  1512  	tx, err := db.Begin()
  1513  	if err != nil {
  1514  		t.Fatal(err)
  1515  	}
  1516  	defer tx.Rollback()
  1517  
  1518  	_, err = tx.Query("SELECT|t1|name|")
  1519  	if err == nil {
  1520  		t.Fatal("Error expected")
  1521  	}
  1522  }
  1523  
  1524  // Tests fix for issue 4433, that retries in Begin happen when
  1525  // conn.Begin() returns ErrBadConn
  1526  func TestTxErrBadConn(t *testing.T) {
  1527  	synctest.Test(t, testTxErrBadConn)
  1528  }
  1529  func testTxErrBadConn(t *testing.T) {
  1530  	db, err := Open("test", fakeDBName+";badConn")
  1531  	if err != nil {
  1532  		t.Fatalf("Open: %v", err)
  1533  	}
  1534  	if _, err := db.Exec("WIPE"); err != nil {
  1535  		t.Fatalf("exec wipe: %v", err)
  1536  	}
  1537  	defer closeDB(t, db)
  1538  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  1539  	stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
  1540  	if err != nil {
  1541  		t.Fatalf("Stmt, err = %v, %v", stmt, err)
  1542  	}
  1543  	defer stmt.Close()
  1544  	tx, err := db.Begin()
  1545  	if err != nil {
  1546  		t.Fatalf("Begin = %v", err)
  1547  	}
  1548  	txs := tx.Stmt(stmt)
  1549  	defer txs.Close()
  1550  	_, err = txs.Exec("Bobby", 7)
  1551  	if err != nil {
  1552  		t.Fatalf("Exec = %v", err)
  1553  	}
  1554  	err = tx.Commit()
  1555  	if err != nil {
  1556  		t.Fatalf("Commit = %v", err)
  1557  	}
  1558  }
  1559  
  1560  func TestConnQuery(t *testing.T) {
  1561  	testDatabase(t, testConnQuery)
  1562  }
  1563  func testConnQuery(t *testing.T, db *DB) {
  1564  	populate(t, db, "people")
  1565  
  1566  	ctx := t.Context()
  1567  	conn, err := db.Conn(ctx)
  1568  	if err != nil {
  1569  		t.Fatal(err)
  1570  	}
  1571  	getFakeConn(conn.dc.ci).skipDirtySession = true
  1572  	defer conn.Close()
  1573  
  1574  	var name string
  1575  	err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name)
  1576  	if err != nil {
  1577  		t.Fatal(err)
  1578  	}
  1579  	if name != "Chris" {
  1580  		t.Fatalf("unexpected result, got %q want Chris", name)
  1581  	}
  1582  
  1583  	err = conn.PingContext(ctx)
  1584  	if err != nil {
  1585  		t.Fatal(err)
  1586  	}
  1587  }
  1588  
// TestConnRaw runs testConnRaw under each driver variant provided by
// testDatabase.
func TestConnRaw(t *testing.T) {
	testDatabase(t, testConnRaw)
}
// testConnRaw checks Conn.Raw. The callback must be invoked with the driver
// connection. A panic inside the callback must propagate to the caller and
// leave the Conn closed.
func testConnRaw(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx := t.Context()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	getFakeConn(conn.dc.ci).skipDirtySession = true
	defer conn.Close()

	sawFunc := false
	err = conn.Raw(func(dc any) error {
		sawFunc = true
		// The argument must be a driver.Conn that getFakeConn accepts.
		_ = getFakeConn(dc.(driver.Conn))
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	if !sawFunc {
		t.Fatal("Raw func not called")
	}

	// NOTE(review): despite the panic message's wording, this section expects
	// the panic to propagate out of Raw (recovered below), not to be turned
	// into an error.
	func() {
		defer func() {
			x := recover()
			if x == nil {
				t.Fatal("expected panic")
			}
			// After the panic the Conn must have dropped its driver
			// connection; dc is read under closemu.
			conn.closemu.Lock()
			closed := conn.dc == nil
			conn.closemu.Unlock()
			if !closed {
				t.Fatal("expected connection to be closed after panic")
			}
		}()
		err = conn.Raw(func(dc any) error {
			panic("Conn.Raw panic should return an error")
		})
		// Reached only if Raw returned instead of panicking.
		t.Fatal("expected panic from Raw func")
	}()
}
  1635  
  1636  func TestCursorFake(t *testing.T) {
  1637  	testDatabase(t, testCursorFake)
  1638  }
  1639  func testCursorFake(t *testing.T, db *DB) {
  1640  	populate(t, db, "people")
  1641  
  1642  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
  1643  	defer cancel()
  1644  
  1645  	exec(t, db, "CREATE|peoplecursor|list=table")
  1646  	exec(t, db, "INSERT|peoplecursor|list=people!name!age")
  1647  
  1648  	rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
  1649  	if err != nil {
  1650  		t.Fatal(err)
  1651  	}
  1652  	defer rows.Close()
  1653  
  1654  	if !rows.Next() {
  1655  		t.Fatal("no rows")
  1656  	}
  1657  	var cursor = &Rows{}
  1658  	err = rows.Scan(cursor)
  1659  	if err != nil {
  1660  		t.Fatal(err)
  1661  	}
  1662  	defer cursor.Close()
  1663  
  1664  	const expectedRows = 3
  1665  	var currentRow int64
  1666  
  1667  	var n int64
  1668  	var s string
  1669  	for cursor.Next() {
  1670  		currentRow++
  1671  		err = cursor.Scan(&s, &n)
  1672  		if err != nil {
  1673  			t.Fatal(err)
  1674  		}
  1675  		if n != currentRow {
  1676  			t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
  1677  		}
  1678  	}
  1679  	if currentRow != expectedRows {
  1680  		t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
  1681  	}
  1682  }
  1683  
  1684  func TestCursorDoubleRowsPointer(t *testing.T) {
  1685  	testDatabase(t, testCursorDoubleRowsPointer)
  1686  }
  1687  func testCursorDoubleRowsPointer(t *testing.T, db *DB) {
  1688  	exec(t, db, "CREATE|table1|col=string")
  1689  	exec(t, db, "INSERT|table1|col=value")
  1690  	exec(t, db, "CREATE|cursor|list=table")
  1691  	exec(t, db, "INSERT|cursor|list=table1!col")
  1692  
  1693  	rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
  1694  	if err != nil {
  1695  		t.Fatal(err)
  1696  	}
  1697  	defer rows.Close()
  1698  
  1699  	if !rows.Next() {
  1700  		t.Fatal("no rows")
  1701  	}
  1702  	var cursor *Rows
  1703  	if err := rows.Scan(&cursor); err != nil {
  1704  		t.Fatal(err)
  1705  	}
  1706  	defer cursor.Close()
  1707  
  1708  	if !cursor.Next() {
  1709  		t.Fatal("no child rows")
  1710  	}
  1711  	var col string
  1712  	if err := cursor.Scan(&col); err != nil {
  1713  		t.Fatal(err)
  1714  	}
  1715  	if got, want := col, "value"; got != want {
  1716  		t.Errorf("read col=%q, want %q", got, want)
  1717  	}
  1718  }
  1719  
  1720  func TestCursorNull(t *testing.T) {
  1721  	testDatabase(t, testCursorNull)
  1722  }
  1723  func testCursorNull(t *testing.T, db *DB) {
  1724  	exec(t, db, "CREATE|cursor|list=nulltable")
  1725  	exec(t, db, "INSERT|cursor|list=?", nil)
  1726  
  1727  	rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
  1728  	if err != nil {
  1729  		t.Fatal(err)
  1730  	}
  1731  	defer rows.Close()
  1732  
  1733  	if !rows.Next() {
  1734  		t.Fatal("no rows")
  1735  	}
  1736  
  1737  	var cursor *Rows
  1738  	if err := rows.Scan(&cursor); err != nil {
  1739  		t.Fatal(err)
  1740  	}
  1741  	if cursor != nil {
  1742  		t.Errorf("Scan returned cursor, expected nil")
  1743  	}
  1744  }
  1745  
// TestCursorCancel exercises calling Rows.Close at various places,
// including canceling a cursor (child Rows).
//
// Each case names the value (an outer row's name, or a child row's column)
// right after which the corresponding Rows is closed. It lists, in order,
// every value testCursorCancel is expected to read.
func TestCursorCancel(t *testing.T) {
	for _, test := range []struct {
		name     string
		cancelOn string
		want     []string
	}{{
		// don't cancel
		name: "no cancel",
		want: []string{
			"table1",
			"1.1",
			"1.2",
			"table2",
			"2.1",
			"2.2",
		},
	}, {
		// Closing the outer Rows after reading "table2" also ends that row's
		// child cursor, so no 2.x values are read.
		name:     "outer cancel",
		cancelOn: "table2",
		want: []string{
			"table1",
			"1.1",
			"1.2",
			"table2",
		},
	}, {
		// Closing the child cursor after "1.1" skips "1.2", but the outer
		// Rows carries on to table2.
		name:     "inner cancel",
		cancelOn: "1.1",
		want: []string{
			"table1",
			"1.1",
			"table2",
			"2.1",
			"2.2",
		},
	}} {
		t.Run(test.name, func(t *testing.T) {
			testDatabase(t, func(t *testing.T, db *DB) {
				testCursorCancel(t, db, test.cancelOn, test.want)
			})
		})
	}
}
// testCursorCancel walks a cursor table whose rows each reference a data
// table. It records every outer name and child column value it reads, and
// calls Close on the outer Rows or the child cursor right after reading
// cancelOn. The recorded sequence must equal want.
func testCursorCancel(t *testing.T, db *DB, cancelOn string, want []string) {
	// Two small data tables...
	exec(t, db, "CREATE|table1|col=string")
	exec(t, db, "INSERT|table1|col=1.1")
	exec(t, db, "INSERT|table1|col=1.2")
	exec(t, db, "CREATE|table2|col=string")
	exec(t, db, "INSERT|table2|col=2.1")
	exec(t, db, "INSERT|table2|col=2.2")

	// ...and a cursor table with one row per data table.
	exec(t, db, "CREATE|cursor|name=string,list=table")
	exec(t, db, "INSERT|cursor|name=table1,list=table1!col")
	exec(t, db, "INSERT|cursor|name=table2,list=table2!col")

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	rows, err := db.QueryContext(ctx, `SELECT|cursor|name,list|`)
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	var got []string
	for rows.Next() {
		var name string
		cursor := &Rows{}
		if err := rows.Scan(&name, cursor); err != nil {
			t.Fatal(err)
		}
		got = append(got, name)
		if name == cancelOn {
			// Close the outer Rows mid-iteration; the child cursor is
			// still iterated below.
			rows.Close()
		}
		for cursor.Next() {
			var col string
			if err := cursor.Scan(&col); err != nil {
				t.Fatal(err)
			}
			got = append(got, col)
			if col == cancelOn {
				// Close only the child cursor; the outer loop continues.
				cursor.Close()
			}
		}
	}

	if !slices.Equal(got, want) {
		t.Errorf("cancel after reading %q:\ngot:  %v\nwant: %v", cancelOn, got, want)
	}
}
  1839  
  1840  func TestInvalidNilValues(t *testing.T) {
  1841  	var date1 time.Time
  1842  	var date2 int
  1843  
  1844  	tests := []struct {
  1845  		name          string
  1846  		input         any
  1847  		expectedError string
  1848  	}{
  1849  		{
  1850  			name:          "time.Time",
  1851  			input:         &date1,
  1852  			expectedError: `sql: Scan error on column index 0, name "bdate": unsupported Scan, storing driver.Value type <nil> into type *time.Time`,
  1853  		},
  1854  		{
  1855  			name:          "int",
  1856  			input:         &date2,
  1857  			expectedError: `sql: Scan error on column index 0, name "bdate": converting NULL to int is unsupported`,
  1858  		},
  1859  	}
  1860  
  1861  	for _, tt := range tests {
  1862  		synctestSubtest(t, tt.name, func(t *testing.T) {
  1863  			db := newTestDB(t, "people")
  1864  			defer closeDB(t, db)
  1865  
  1866  			ctx := t.Context()
  1867  			conn, err := db.Conn(ctx)
  1868  			if err != nil {
  1869  				t.Fatal(err)
  1870  			}
  1871  			getFakeConn(conn.dc.ci).skipDirtySession = true
  1872  			defer conn.Close()
  1873  
  1874  			err = conn.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 1).Scan(tt.input)
  1875  			if err == nil {
  1876  				t.Fatal("expected error when querying nil column, but succeeded")
  1877  			}
  1878  			if err.Error() != tt.expectedError {
  1879  				t.Fatalf("Expected error: %s\nReceived: %s", tt.expectedError, err.Error())
  1880  			}
  1881  
  1882  			err = conn.PingContext(ctx)
  1883  			if err != nil {
  1884  				t.Fatal(err)
  1885  			}
  1886  		})
  1887  	}
  1888  }
  1889  
  1890  func TestConnTx(t *testing.T) {
  1891  	testDatabase(t, testConnTx)
  1892  }
  1893  func testConnTx(t *testing.T, db *DB) {
  1894  	populate(t, db, "people")
  1895  
  1896  	ctx := t.Context()
  1897  	conn, err := db.Conn(ctx)
  1898  	if err != nil {
  1899  		t.Fatal(err)
  1900  	}
  1901  	getFakeConn(conn.dc.ci).skipDirtySession = true
  1902  	defer conn.Close()
  1903  
  1904  	tx, err := conn.BeginTx(ctx, nil)
  1905  	if err != nil {
  1906  		t.Fatal(err)
  1907  	}
  1908  	insertName, insertAge := "Nancy", 33
  1909  	_, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge)
  1910  	if err != nil {
  1911  		t.Fatal(err)
  1912  	}
  1913  	err = tx.Commit()
  1914  	if err != nil {
  1915  		t.Fatal(err)
  1916  	}
  1917  
  1918  	var selectName string
  1919  	err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName)
  1920  	if err != nil {
  1921  		t.Fatal(err)
  1922  	}
  1923  	if selectName != insertName {
  1924  		t.Fatalf("got %q want %q", selectName, insertName)
  1925  	}
  1926  }
  1927  
// TestConnIsValid verifies that a database connection that should be discarded
// is actually discarded and does not re-enter the connection pool.
// If the IsValid method from *fakeConn is removed, this test will fail.
// It only runs with driver variants that implement the Validator feature.
func TestConnIsValid(t *testing.T) {
	testDatabase(t, testConnIsValid, requireFeature("Validator"))
}
// testConnIsValid marks the driver connection behind a Conn as stickyBad,
// releases the Conn, and checks that the connection was not put back in
// db.freeConn.
func testConnIsValid(t *testing.T, db *DB) {
	populate(t, db, "people")

	// At most one open connection, so the Conn below and any pooled free
	// connection are the same one.
	db.SetMaxOpenConns(1)

	ctx := context.Background()

	c, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): stickyBad presumably makes fakeConn report itself as
	// invalid (via IsValid); confirm in fakedb.
	err = c.Raw(func(raw any) error {
		dc := getFakeConn(raw.(driver.Conn))
		dc.stickyBad = true
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
	c.Close()

	if len(db.freeConn) > 0 && getFakeConn(db.freeConn[0].ci).stickyBad {
		t.Fatal("bad connection returned to pool; expected bad connection to be discarded")
	}
}
  1960  
  1961  // Tests fix for issue 2542, that we release a lock when querying on
  1962  // a closed connection.
  1963  func TestIssue2542Deadlock(t *testing.T) {
  1964  	testDatabase(t, testIssue2542Deadlock)
  1965  }
  1966  func testIssue2542Deadlock(t *testing.T, db *DB) {
  1967  	populate(t, db, "people")
  1968  	closeDB(t, db)
  1969  	for i := 0; i < 2; i++ {
  1970  		_, err := db.Query("SELECT|people|age,name|")
  1971  		if err == nil {
  1972  			t.Fatalf("expected error")
  1973  		}
  1974  	}
  1975  }
  1976  
  1977  // From golang.org/issue/3865
  1978  func TestCloseStmtBeforeRows(t *testing.T) {
  1979  	testDatabase(t, testCloseStmtBeforeRows)
  1980  }
  1981  func testCloseStmtBeforeRows(t *testing.T, db *DB) {
  1982  	populate(t, db, "people")
  1983  
  1984  	s, err := db.Prepare("SELECT|people|name|")
  1985  	if err != nil {
  1986  		t.Fatal(err)
  1987  	}
  1988  
  1989  	r, err := s.Query()
  1990  	if err != nil {
  1991  		s.Close()
  1992  		t.Fatal(err)
  1993  	}
  1994  
  1995  	err = s.Close()
  1996  	if err != nil {
  1997  		t.Fatal(err)
  1998  	}
  1999  
  2000  	r.Close()
  2001  }
  2002  
  2003  // Tests fix for issue 2788, that we bind nil to a []byte if the
  2004  // value in the column is sql null
  2005  func TestNullByteSlice(t *testing.T) {
  2006  	testDatabase(t, testNullByteSlice)
  2007  }
  2008  func testNullByteSlice(t *testing.T, db *DB) {
  2009  	exec(t, db, "CREATE|t|id=int32,name=nullstring")
  2010  	exec(t, db, "INSERT|t|id=10,name=?", nil)
  2011  
  2012  	var name []byte
  2013  
  2014  	err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
  2015  	if err != nil {
  2016  		t.Fatal(err)
  2017  	}
  2018  	if name != nil {
  2019  		t.Fatalf("name []byte should be nil for null column value, got: %#v", name)
  2020  	}
  2021  
  2022  	exec(t, db, "INSERT|t|id=11,name=?", "bob")
  2023  	err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name)
  2024  	if err != nil {
  2025  		t.Fatal(err)
  2026  	}
  2027  	if string(name) != "bob" {
  2028  		t.Fatalf("name []byte should be bob, got: %q", string(name))
  2029  	}
  2030  }
  2031  
  2032  func TestPointerParamsAndScans(t *testing.T) {
  2033  	testDatabase(t, testPointerParamsAndScans)
  2034  }
  2035  func testPointerParamsAndScans(t *testing.T, db *DB) {
  2036  	exec(t, db, "CREATE|t|id=int32,name=nullstring")
  2037  
  2038  	bob := "bob"
  2039  	var name *string
  2040  
  2041  	name = &bob
  2042  	exec(t, db, "INSERT|t|id=10,name=?", name)
  2043  	name = nil
  2044  	exec(t, db, "INSERT|t|id=20,name=?", name)
  2045  
  2046  	err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
  2047  	if err != nil {
  2048  		t.Fatalf("querying id 10: %v", err)
  2049  	}
  2050  	if name == nil {
  2051  		t.Errorf("id 10's name = nil; want bob")
  2052  	} else if *name != "bob" {
  2053  		t.Errorf("id 10's name = %q; want bob", *name)
  2054  	}
  2055  
  2056  	err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name)
  2057  	if err != nil {
  2058  		t.Fatalf("querying id 20: %v", err)
  2059  	}
  2060  	if name != nil {
  2061  		t.Errorf("id 20 = %q; want nil", *name)
  2062  	}
  2063  }
  2064  
  2065  func TestQueryRowClosingStmt(t *testing.T) {
  2066  	testDatabase(t, testQueryRowClosingStmt)
  2067  }
  2068  func testQueryRowClosingStmt(t *testing.T, db *DB) {
  2069  	populate(t, db, "people")
  2070  	var name string
  2071  	var age int
  2072  	err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
  2073  	if err != nil {
  2074  		t.Fatal(err)
  2075  	}
  2076  	if len(db.freeConn) != 1 {
  2077  		t.Fatalf("expected 1 free conn")
  2078  	}
  2079  	fakeConn := getFakeConn(db.freeConn[0].ci)
  2080  	if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
  2081  		t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
  2082  	}
  2083  }
  2084  
  2085  var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
  2086  
  2087  func init() {
  2088  	rowsCloseHook = func() func(*Rows, *error) {
  2089  		fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
  2090  		return fn
  2091  	}
  2092  }
  2093  
  2094  func setRowsCloseHook(fn func(*Rows, *error)) {
  2095  	if fn == nil {
  2096  		// Can't change an atomic.Value back to nil, so set it to this
  2097  		// no-op func instead.
  2098  		fn = func(*Rows, *error) {}
  2099  	}
  2100  	atomicRowsCloseHook.Store(fn)
  2101  }
  2102  
  2103  // Test issue 6651
  2104  func TestIssue6651(t *testing.T) {
  2105  	testDatabase(t, testIssue6651)
  2106  }
  2107  func testIssue6651(t *testing.T, db *DB) {
  2108  	populate(t, db, "people")
  2109  
  2110  	var v string
  2111  
  2112  	want := "error in rows.Next"
  2113  	rowsCursorNextHook = func(dest []driver.Value) error {
  2114  		return errors.New(want)
  2115  	}
  2116  	defer func() { rowsCursorNextHook = nil }()
  2117  
  2118  	err := db.QueryRow("SELECT|people|name|").Scan(&v)
  2119  	if err == nil || err.Error() != want {
  2120  		t.Errorf("error = %q; want %q", err, want)
  2121  	}
  2122  	rowsCursorNextHook = nil
  2123  
  2124  	want = "error in rows.Close"
  2125  	setRowsCloseHook(func(rows *Rows, err *error) {
  2126  		*err = errors.New(want)
  2127  	})
  2128  	defer setRowsCloseHook(nil)
  2129  	err = db.QueryRow("SELECT|people|name|").Scan(&v)
  2130  	if err == nil || err.Error() != want {
  2131  		t.Errorf("error = %q; want %q", err, want)
  2132  	}
  2133  }
  2134  
// nullTestRow is one row of a nullTestSpec table consumed by nullTestRun.
type nullTestRow struct {
	nullParam    any // bound to the nullable column (nullf) on insert
	notNullParam any // bound to the non-nullable column (notnullf) on insert
	scanNullVal  any // expected nullf value when scanned back; unused for the last row
}

// nullTestSpec describes a nullTestRun scenario for one nullable column type.
type nullTestSpec struct {
	nullType    string         // fake-driver type of the nullable column, e.g. "nullstring"
	notNullType string         // fake-driver type of the non-nullable column, e.g. "string"
	rows        [6]nullTestRow // rows 0-4 are inserted and read back; row 5 must fail to insert
}
  2146  
  2147  func TestNullStringParam(t *testing.T) {
  2148  	spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
  2149  		{NullString{"aqua", true}, "", NullString{"aqua", true}},
  2150  		{NullString{"brown", false}, "", NullString{"", false}},
  2151  		{"chartreuse", "", NullString{"chartreuse", true}},
  2152  		{NullString{"darkred", true}, "", NullString{"darkred", true}},
  2153  		{NullString{"eel", false}, "", NullString{"", false}},
  2154  		{"foo", NullString{"black", false}, nil},
  2155  	}}
  2156  	testDatabase(t, func(t *testing.T, db *DB) {
  2157  		nullTestRun(t, spec, db)
  2158  	})
  2159  }
  2160  
  2161  func TestGenericNullStringParam(t *testing.T) {
  2162  	spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
  2163  		{Null[string]{"aqua", true}, "", Null[string]{"aqua", true}},
  2164  		{Null[string]{"brown", false}, "", Null[string]{"", false}},
  2165  		{"chartreuse", "", Null[string]{"chartreuse", true}},
  2166  		{Null[string]{"darkred", true}, "", Null[string]{"darkred", true}},
  2167  		{Null[string]{"eel", false}, "", Null[string]{"", false}},
  2168  		{"foo", Null[string]{"black", false}, nil},
  2169  	}}
  2170  	testDatabase(t, func(t *testing.T, db *DB) {
  2171  		nullTestRun(t, spec, db)
  2172  	})
  2173  }
  2174  
  2175  func TestNullInt64Param(t *testing.T) {
  2176  	spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
  2177  		{NullInt64{31, true}, 1, NullInt64{31, true}},
  2178  		{NullInt64{-22, false}, 1, NullInt64{0, false}},
  2179  		{22, 1, NullInt64{22, true}},
  2180  		{NullInt64{33, true}, 1, NullInt64{33, true}},
  2181  		{NullInt64{222, false}, 1, NullInt64{0, false}},
  2182  		{0, NullInt64{31, false}, nil},
  2183  	}}
  2184  	testDatabase(t, func(t *testing.T, db *DB) {
  2185  		nullTestRun(t, spec, db)
  2186  	})
  2187  }
  2188  
  2189  func TestNullInt32Param(t *testing.T) {
  2190  	spec := nullTestSpec{"nullint32", "int32", [6]nullTestRow{
  2191  		{NullInt32{31, true}, 1, NullInt32{31, true}},
  2192  		{NullInt32{-22, false}, 1, NullInt32{0, false}},
  2193  		{22, 1, NullInt32{22, true}},
  2194  		{NullInt32{33, true}, 1, NullInt32{33, true}},
  2195  		{NullInt32{222, false}, 1, NullInt32{0, false}},
  2196  		{0, NullInt32{31, false}, nil},
  2197  	}}
  2198  	testDatabase(t, func(t *testing.T, db *DB) {
  2199  		nullTestRun(t, spec, db)
  2200  	})
  2201  }
  2202  
  2203  func TestNullInt16Param(t *testing.T) {
  2204  	spec := nullTestSpec{"nullint16", "int16", [6]nullTestRow{
  2205  		{NullInt16{31, true}, 1, NullInt16{31, true}},
  2206  		{NullInt16{-22, false}, 1, NullInt16{0, false}},
  2207  		{22, 1, NullInt16{22, true}},
  2208  		{NullInt16{33, true}, 1, NullInt16{33, true}},
  2209  		{NullInt16{222, false}, 1, NullInt16{0, false}},
  2210  		{0, NullInt16{31, false}, nil},
  2211  	}}
  2212  	testDatabase(t, func(t *testing.T, db *DB) {
  2213  		nullTestRun(t, spec, db)
  2214  	})
  2215  }
  2216  
  2217  func TestNullByteParam(t *testing.T) {
  2218  	spec := nullTestSpec{"nullbyte", "byte", [6]nullTestRow{
  2219  		{NullByte{31, true}, 1, NullByte{31, true}},
  2220  		{NullByte{0, false}, 1, NullByte{0, false}},
  2221  		{22, 1, NullByte{22, true}},
  2222  		{NullByte{33, true}, 1, NullByte{33, true}},
  2223  		{NullByte{222, false}, 1, NullByte{0, false}},
  2224  		{0, NullByte{31, false}, nil},
  2225  	}}
  2226  	testDatabase(t, func(t *testing.T, db *DB) {
  2227  		nullTestRun(t, spec, db)
  2228  	})
  2229  }
  2230  
  2231  func TestNullFloat64Param(t *testing.T) {
  2232  	spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
  2233  		{NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
  2234  		{NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
  2235  		{-22.9, 1, NullFloat64{-22.9, true}},
  2236  		{NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
  2237  		{NullFloat64{222, false}, 1, NullFloat64{0, false}},
  2238  		{10, NullFloat64{31.2, false}, nil},
  2239  	}}
  2240  	testDatabase(t, func(t *testing.T, db *DB) {
  2241  		nullTestRun(t, spec, db)
  2242  	})
  2243  }
  2244  
  2245  func TestNullBoolParam(t *testing.T) {
  2246  	spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
  2247  		{NullBool{false, true}, true, NullBool{false, true}},
  2248  		{NullBool{true, false}, false, NullBool{false, false}},
  2249  		{true, true, NullBool{true, true}},
  2250  		{NullBool{true, true}, false, NullBool{true, true}},
  2251  		{NullBool{true, false}, true, NullBool{false, false}},
  2252  		{true, NullBool{true, false}, nil},
  2253  	}}
  2254  	testDatabase(t, func(t *testing.T, db *DB) {
  2255  		nullTestRun(t, spec, db)
  2256  	})
  2257  }
  2258  
  2259  func TestNullTimeParam(t *testing.T) {
  2260  	t0 := time.Time{}
  2261  	t1 := time.Date(2000, 1, 1, 8, 9, 10, 11, time.UTC)
  2262  	t2 := time.Date(2010, 1, 1, 8, 9, 10, 11, time.UTC)
  2263  	spec := nullTestSpec{"nulldatetime", "datetime", [6]nullTestRow{
  2264  		{NullTime{t1, true}, t2, NullTime{t1, true}},
  2265  		{NullTime{t1, false}, t2, NullTime{t0, false}},
  2266  		{t1, t2, NullTime{t1, true}},
  2267  		{NullTime{t1, true}, t2, NullTime{t1, true}},
  2268  		{NullTime{t1, false}, t2, NullTime{t0, false}},
  2269  		{t2, NullTime{t1, false}, nil},
  2270  	}}
  2271  	testDatabase(t, func(t *testing.T, db *DB) {
  2272  		nullTestRun(t, spec, db)
  2273  	})
  2274  }
  2275  
  2276  func TestNullUUIDParam(t *testing.T) {
  2277  	u0 := uuid.UUID{}
  2278  	u1 := uuid.MustParse("46cd2740-6081-4289-a659-03b61ebb92f7")
  2279  	u2 := uuid.MustParse("46cd2740-6081-4289-a659-03b61ebb92f7")
  2280  	spec := nullTestSpec{"nulluuid", "uuid", [6]nullTestRow{
  2281  		{Null[uuid.UUID]{u1, true}, u2, Null[uuid.UUID]{u1, true}},
  2282  		{Null[uuid.UUID]{u1, false}, u2, Null[uuid.UUID]{u0, false}},
  2283  		{u1, u2, Null[uuid.UUID]{u1, true}},
  2284  		{Null[uuid.UUID]{u1, true}, u2, Null[uuid.UUID]{u1, true}},
  2285  		{Null[uuid.UUID]{u1, false}, u2, Null[uuid.UUID]{u0, false}},
  2286  		{u2, Null[uuid.UUID]{u1, false}, nil},
  2287  	}}
  2288  	testDatabase(t, func(t *testing.T, db *DB) {
  2289  		nullTestRun(t, spec, db)
  2290  	})
  2291  }
  2292  
  2293  func nullTestRun(t *testing.T, spec nullTestSpec, db *DB) {
  2294  	exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType))
  2295  
  2296  	// Inserts with db.Exec:
  2297  	exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam)
  2298  	exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam)
  2299  
  2300  	// Inserts with a prepared statement:
  2301  	stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?")
  2302  	if err != nil {
  2303  		t.Fatalf("prepare: %v", err)
  2304  	}
  2305  	defer stmt.Close()
  2306  	if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
  2307  		t.Errorf("exec insert chris: %v", err)
  2308  	}
  2309  	if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil {
  2310  		t.Errorf("exec insert dave: %v", err)
  2311  	}
  2312  	if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil {
  2313  		t.Errorf("exec insert eleanor: %v", err)
  2314  	}
  2315  
  2316  	// Can't put null val into non-null col
  2317  	row5 := spec.rows[5]
  2318  	if _, err := stmt.Exec(6, "bob", row5.nullParam, row5.notNullParam); err == nil {
  2319  		t.Errorf("expected error inserting nil val with prepared statement Exec: NULL=%#v, NOT-NULL=%#v", row5.nullParam, row5.notNullParam)
  2320  	}
  2321  
  2322  	_, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil)
  2323  	if err == nil {
  2324  		// TODO: this test fails, but it's just because
  2325  		// fakeConn implements the optional Execer interface,
  2326  		// so arguably this is the correct behavior. But
  2327  		// maybe I should flesh out the fakeConn.Exec
  2328  		// implementation so this properly fails.
  2329  		// t.Errorf("expected error inserting nil name with Exec")
  2330  	}
  2331  
  2332  	paramtype := reflect.TypeOf(spec.rows[0].nullParam)
  2333  	bindVal := reflect.New(paramtype).Interface()
  2334  
  2335  	for i := 0; i < 5; i++ {
  2336  		id := i + 1
  2337  		if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil {
  2338  			t.Errorf("id=%d Scan: %v", id, err)
  2339  		}
  2340  		bindValDeref := reflect.ValueOf(bindVal).Elem().Interface()
  2341  		if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) {
  2342  			t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal)
  2343  		}
  2344  	}
  2345  }
  2346  
  2347  // golang.org/issue/4859
  2348  func TestQueryRowNilScanDest(t *testing.T) {
  2349  	testDatabase(t, testQueryRowNilScanDest)
  2350  }
  2351  func testQueryRowNilScanDest(t *testing.T, db *DB) {
  2352  	populate(t, db, "people")
  2353  	var name *string // nil pointer
  2354  	err := db.QueryRow("SELECT|people|name|").Scan(name)
  2355  	want := `sql: Scan error on column index 0, name "name": destination pointer is nil`
  2356  	if err == nil || err.Error() != want {
  2357  		t.Errorf("error = %q; want %q", err.Error(), want)
  2358  	}
  2359  }
  2360  
  2361  func TestIssue4902(t *testing.T) {
  2362  	testDatabase(t, testIssue4902)
  2363  }
  2364  func testIssue4902(t *testing.T, db *DB) {
  2365  	populate(t, db, "people")
  2366  
  2367  	driver := db.Driver().(*fakeDriver)
  2368  	opens0 := driver.openCount
  2369  
  2370  	var stmt *Stmt
  2371  	var err error
  2372  	for i := 0; i < 10; i++ {
  2373  		stmt, err = db.Prepare("SELECT|people|name|")
  2374  		if err != nil {
  2375  			t.Fatal(err)
  2376  		}
  2377  		err = stmt.Close()
  2378  		if err != nil {
  2379  			t.Fatal(err)
  2380  		}
  2381  	}
  2382  
  2383  	opens := driver.openCount - opens0
  2384  	if opens > 1 {
  2385  		t.Errorf("opens = %d; want <= 1", opens)
  2386  		t.Logf("db = %#v", db)
  2387  		t.Logf("driver = %#v", driver)
  2388  		t.Logf("stmt = %#v", stmt)
  2389  	}
  2390  }
  2391  
  2392  // Issue 3857
  2393  // This used to deadlock.
  2394  func TestSimultaneousQueries(t *testing.T) {
  2395  	testDatabase(t, testSimultaneousQueries)
  2396  }
  2397  func testSimultaneousQueries(t *testing.T, db *DB) {
  2398  	populate(t, db, "people")
  2399  
  2400  	tx, err := db.Begin()
  2401  	if err != nil {
  2402  		t.Fatal(err)
  2403  	}
  2404  	defer tx.Rollback()
  2405  
  2406  	r1, err := tx.Query("SELECT|people|name|")
  2407  	if err != nil {
  2408  		t.Fatal(err)
  2409  	}
  2410  	defer r1.Close()
  2411  
  2412  	r2, err := tx.Query("SELECT|people|name|")
  2413  	if err != nil {
  2414  		t.Fatal(err)
  2415  	}
  2416  	defer r2.Close()
  2417  }
  2418  
  2419  func TestMaxIdleConns(t *testing.T) {
  2420  	testDatabase(t, testMaxIdleConns)
  2421  }
  2422  func testMaxIdleConns(t *testing.T, db *DB) {
  2423  	populate(t, db, "people")
  2424  
  2425  	tx, err := db.Begin()
  2426  	if err != nil {
  2427  		t.Fatal(err)
  2428  	}
  2429  	tx.Commit()
  2430  	if got := len(db.freeConn); got != 1 {
  2431  		t.Errorf("freeConns = %d; want 1", got)
  2432  	}
  2433  
  2434  	db.SetMaxIdleConns(0)
  2435  
  2436  	if got := len(db.freeConn); got != 0 {
  2437  		t.Errorf("freeConns after set to zero = %d; want 0", got)
  2438  	}
  2439  
  2440  	tx, err = db.Begin()
  2441  	if err != nil {
  2442  		t.Fatal(err)
  2443  	}
  2444  	tx.Commit()
  2445  	if got := len(db.freeConn); got != 0 {
  2446  		t.Errorf("freeConns = %d; want 0", got)
  2447  	}
  2448  }
  2449  
  2450  func TestMaxOpenConns(t *testing.T) {
  2451  	testDatabase(t, testMaxOpenConns)
  2452  }
  2453  func testMaxOpenConns(t *testing.T, db *DB) {
  2454  	if testing.Short() {
  2455  		t.Skip("skipping in short mode")
  2456  	}
  2457  	defer setHookpostCloseConn(nil)
  2458  	setHookpostCloseConn(func(_ *fakeConn, err error) {
  2459  		if err != nil {
  2460  			t.Errorf("Error closing fakeConn: %v", err)
  2461  		}
  2462  	})
  2463  
  2464  	populate(t, db, "magicquery")
  2465  
  2466  	driver := db.Driver().(*fakeDriver)
  2467  
  2468  	// Force the number of open connections to 0 so we can get an accurate
  2469  	// count for the test
  2470  	db.clearAllConns(t)
  2471  
  2472  	driver.mu.Lock()
  2473  	opens0 := driver.openCount
  2474  	closes0 := driver.closeCount
  2475  	driver.mu.Unlock()
  2476  
  2477  	db.SetMaxIdleConns(10)
  2478  	db.SetMaxOpenConns(10)
  2479  
  2480  	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
  2481  	if err != nil {
  2482  		t.Fatal(err)
  2483  	}
  2484  
  2485  	// Start 50 parallel slow queries.
  2486  	const (
  2487  		nquery      = 50
  2488  		sleepMillis = 25
  2489  		nbatch      = 2
  2490  	)
  2491  	var wg sync.WaitGroup
  2492  	for batch := 0; batch < nbatch; batch++ {
  2493  		for i := 0; i < nquery; i++ {
  2494  			wg.Add(1)
  2495  			go func() {
  2496  				defer wg.Done()
  2497  				var op string
  2498  				if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
  2499  					t.Error(err)
  2500  				}
  2501  			}()
  2502  		}
  2503  		// Wait for the batch of queries above to finish before starting the next round.
  2504  		wg.Wait()
  2505  	}
  2506  
  2507  	if g, w := db.numFreeConns(), 10; g != w {
  2508  		t.Errorf("free conns = %d; want %d", g, w)
  2509  	}
  2510  
  2511  	if n := db.numDeps(); n > 20 {
  2512  		t.Errorf("number of dependencies = %d; expected <= 20", n)
  2513  		db.dumpDeps(t)
  2514  	}
  2515  
  2516  	driver.mu.Lock()
  2517  	opens := driver.openCount - opens0
  2518  	closes := driver.closeCount - closes0
  2519  	driver.mu.Unlock()
  2520  
  2521  	if opens > 10 {
  2522  		t.Logf("open calls = %d", opens)
  2523  		t.Logf("close calls = %d", closes)
  2524  		t.Errorf("db connections opened = %d; want <= 10", opens)
  2525  		db.dumpDeps(t)
  2526  	}
  2527  
  2528  	if err := stmt.Close(); err != nil {
  2529  		t.Fatal(err)
  2530  	}
  2531  
  2532  	if g, w := db.numFreeConns(), 10; g != w {
  2533  		t.Errorf("free conns = %d; want %d", g, w)
  2534  	}
  2535  
  2536  	if n := db.numDeps(); n > 10 {
  2537  		t.Errorf("number of dependencies = %d; expected <= 10", n)
  2538  		db.dumpDeps(t)
  2539  	}
  2540  
  2541  	db.SetMaxOpenConns(5)
  2542  
  2543  	if g, w := db.numFreeConns(), 5; g != w {
  2544  		t.Errorf("free conns = %d; want %d", g, w)
  2545  	}
  2546  
  2547  	if n := db.numDeps(); n > 5 {
  2548  		t.Errorf("number of dependencies = %d; expected 0", n)
  2549  		db.dumpDeps(t)
  2550  	}
  2551  
  2552  	db.SetMaxOpenConns(0)
  2553  
  2554  	if g, w := db.numFreeConns(), 5; g != w {
  2555  		t.Errorf("free conns = %d; want %d", g, w)
  2556  	}
  2557  
  2558  	if n := db.numDeps(); n > 5 {
  2559  		t.Errorf("number of dependencies = %d; expected 0", n)
  2560  		db.dumpDeps(t)
  2561  	}
  2562  
  2563  	db.clearAllConns(t)
  2564  }
  2565  
  2566  // Issue 9453: tests that SetMaxOpenConns can be lowered at runtime
  2567  // and affects the subsequent release of connections.
  2568  func TestMaxOpenConnsOnBusy(t *testing.T) {
  2569  	testDatabase(t, testMaxOpenConnsOnBusy)
  2570  }
  2571  func testMaxOpenConnsOnBusy(t *testing.T, db *DB) {
  2572  	defer setHookpostCloseConn(nil)
  2573  	setHookpostCloseConn(func(_ *fakeConn, err error) {
  2574  		if err != nil {
  2575  			t.Errorf("Error closing fakeConn: %v", err)
  2576  		}
  2577  	})
  2578  
  2579  	populate(t, db, "magicquery")
  2580  
  2581  	db.SetMaxOpenConns(3)
  2582  
  2583  	ctx := context.Background()
  2584  
  2585  	conn0, err := db.conn(ctx, cachedOrNewConn)
  2586  	if err != nil {
  2587  		t.Fatalf("db open conn fail: %v", err)
  2588  	}
  2589  
  2590  	conn1, err := db.conn(ctx, cachedOrNewConn)
  2591  	if err != nil {
  2592  		t.Fatalf("db open conn fail: %v", err)
  2593  	}
  2594  
  2595  	conn2, err := db.conn(ctx, cachedOrNewConn)
  2596  	if err != nil {
  2597  		t.Fatalf("db open conn fail: %v", err)
  2598  	}
  2599  
  2600  	if g, w := db.numOpen, 3; g != w {
  2601  		t.Errorf("free conns = %d; want %d", g, w)
  2602  	}
  2603  
  2604  	db.SetMaxOpenConns(2)
  2605  	if g, w := db.numOpen, 3; g != w {
  2606  		t.Errorf("free conns = %d; want %d", g, w)
  2607  	}
  2608  
  2609  	conn0.releaseConn(nil)
  2610  	conn1.releaseConn(nil)
  2611  	if g, w := db.numOpen, 2; g != w {
  2612  		t.Errorf("free conns = %d; want %d", g, w)
  2613  	}
  2614  
  2615  	conn2.releaseConn(nil)
  2616  	if g, w := db.numOpen, 2; g != w {
  2617  		t.Errorf("free conns = %d; want %d", g, w)
  2618  	}
  2619  }
  2620  
// Issue 10886: tests that all connection attempts return when more than
// DB.maxOpen connections are in flight and the first DB.maxOpen fail.
//
// The test body runs under synctest.Test, so the timer and ticker below run
// on the bubble's fake clock.
func TestPendingConnsAfterErr(t *testing.T) {
	synctest.Test(t, testPendingConnsAfterErr)
}
// testPendingConnsAfterErr starts tryOpen concurrent Exec calls against a DB
// limited to maxOpen open connections. The installed open hook waits for all
// callers to start and then fails every open with errOffline. The test
// requires every caller to receive errOffline before the timeout, then waits
// (with the same timer) for db.numOpen to drop to zero.
func testPendingConnsAfterErr(t *testing.T) {
	const (
		maxOpen = 2
		// More attempts than maxOpen, so some requests must wait for a slot.
		tryOpen = maxOpen*2 + 2
	)

	// No queries will be run.
	db, err := Open("test", fakeDBName)
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer closeDB(t, db)
	// Debugging aid: on exit, log every entry recorded in db.lastPut.
	defer func() {
		for k, v := range db.lastPut {
			t.Logf("%p: %v", k, v)
		}
	}()

	db.SetMaxOpenConns(maxOpen)
	// NOTE(review): presumably set so no connection is kept idle and
	// numOpen can return to 0 — confirm.
	db.SetMaxIdleConns(0)

	errOffline := errors.New("db offline")

	// Deferred after closeDB, so the hook is cleared before the DB is closed.
	defer func() { setHookOpenErr(nil) }()

	// Buffered to tryOpen so no worker ever blocks sending its result.
	errs := make(chan error, tryOpen)

	var opening sync.WaitGroup
	opening.Add(tryOpen)

	// Every driver open blocks until all tryOpen workers have started, then
	// fails with errOffline.
	setHookOpenErr(func() error {
		// Wait for all connections to enqueue.
		opening.Wait()
		return errOffline
	})

	for i := 0; i < tryOpen; i++ {
		go func() {
			opening.Done() // signal one connection is in flight
			_, err := db.Exec("will never run")
			errs <- err
		}()
	}

	opening.Wait() // wait for all workers to begin running

	// A single timer bounds both the error-collection loop and the
	// numOpen polling loop below.
	const timeout = 5 * time.Second
	to := time.NewTimer(timeout)
	defer to.Stop()

	// check that all connections fail without deadlock
	for i := 0; i < tryOpen; i++ {
		select {
		case err := <-errs:
			if got, want := err, errOffline; got != want {
				t.Errorf("unexpected err: got %v, want %v", got, want)
			}
		case <-to.C:
			t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
		}
	}

	// Wait a reasonable time for the database to close all connections.
	tick := time.NewTicker(3 * time.Millisecond)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			db.mu.Lock()
			if db.numOpen == 0 {
				db.mu.Unlock()
				return
			}
			db.mu.Unlock()
		case <-to.C:
			// Closing the database will check for numOpen and fail the test.
			return
		}
	}
}
  2706  
  2707  func TestSingleOpenConn(t *testing.T) {
  2708  	testDatabase(t, testSingleOpenConn)
  2709  }
  2710  func testSingleOpenConn(t *testing.T, db *DB) {
  2711  	populate(t, db, "people")
  2712  
  2713  	db.SetMaxOpenConns(1)
  2714  
  2715  	rows, err := db.Query("SELECT|people|name|")
  2716  	if err != nil {
  2717  		t.Fatal(err)
  2718  	}
  2719  	if err = rows.Close(); err != nil {
  2720  		t.Fatal(err)
  2721  	}
  2722  	// shouldn't deadlock
  2723  	rows, err = db.Query("SELECT|people|name|")
  2724  	if err != nil {
  2725  		t.Fatal(err)
  2726  	}
  2727  	if err = rows.Close(); err != nil {
  2728  		t.Fatal(err)
  2729  	}
  2730  }
  2731  
  2732  func TestStats(t *testing.T) {
  2733  	testDatabase(t, testStats)
  2734  }
  2735  func testStats(t *testing.T, db *DB) {
  2736  	populate(t, db, "people")
  2737  	stats := db.Stats()
  2738  	if got := stats.OpenConnections; got != 1 {
  2739  		t.Errorf("stats.OpenConnections = %d; want 1", got)
  2740  	}
  2741  
  2742  	tx, err := db.Begin()
  2743  	if err != nil {
  2744  		t.Fatal(err)
  2745  	}
  2746  	tx.Commit()
  2747  
  2748  	closeDB(t, db)
  2749  	stats = db.Stats()
  2750  	if got := stats.OpenConnections; got != 0 {
  2751  		t.Errorf("stats.OpenConnections = %d; want 0", got)
  2752  	}
  2753  }
  2754  
// TestConnMaxLifetime checks the effect of lowering SetConnMaxLifetime after
// idle connections exist. The oldest connection should be closed and
// replaced on the next use, and the close counted in
// Stats().MaxLifetimeClosed.
func TestConnMaxLifetime(t *testing.T) {
	testDatabase(t, testConnMaxLifetime)
}
func testConnMaxLifetime(t *testing.T, db *DB) {
	populate(t, db, "magicquery")

	// NOTE: this local shadows the database/sql/driver package name for the
	// rest of the function.
	driver := db.Driver().(*fakeDriver)

	// Force the number of open connections to 0 so we can get an accurate
	// count for the test
	db.clearAllConns(t)

	// Baseline driver counters; only deltas from here are asserted.
	driver.mu.Lock()
	opens0 := driver.openCount
	closes0 := driver.closeCount
	driver.mu.Unlock()

	db.SetMaxIdleConns(10)
	db.SetMaxOpenConns(10)

	// Hold two transactions open at once so two distinct connections are
	// created, one second apart. testDatabase runs f in a synctest bubble,
	// so synctest.Sleep presumably advances the bubble's fake clock rather
	// than wall time.
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}

	synctest.Sleep(1 * time.Second)
	tx2, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}

	// Both connections return to the idle pool.
	tx.Commit()
	tx2.Commit()

	driver.mu.Lock()
	opens := driver.openCount - opens0
	closes := driver.closeCount - closes0
	driver.mu.Unlock()

	if opens != 2 {
		t.Errorf("opens = %d; want 2", opens)
	}
	if closes != 0 {
		t.Errorf("closes = %d; want 0", closes)
	}
	if g, w := db.numFreeConns(), 2; g != w {
		t.Errorf("free conns = %d; want %d", g, w)
	}

	// Expire first conn
	// After 10 more seconds the first connection is ~11s old and the second
	// ~10s old. With a 10s lifetime the assertions below expect exactly one
	// of them (the first) to be treated as expired.
	synctest.Sleep(10 * time.Second)
	db.SetConnMaxLifetime(10 * time.Second)

	// Using both idle connections again forces the expired one to be
	// replaced: one extra open and one close are expected.
	tx, err = db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	tx2, err = db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	tx.Commit()
	tx2.Commit()

	// Give connectionCleaner chance to run.
	synctest.Wait()
	driver.mu.Lock()
	opens = driver.openCount - opens0
	closes = driver.closeCount - closes0
	driver.mu.Unlock()

	if opens != 3 {
		t.Errorf("opens = %d; want 3", opens)
	}
	if closes != 1 {
		t.Errorf("closes = %d; want 1", closes)
	}

	if s := db.Stats(); s.MaxLifetimeClosed != 1 {
		t.Errorf("MaxLifetimeClosed = %d; want 1 %#v", s.MaxLifetimeClosed, s)
	}
}
  2837  
// golang.org/issue/5323
//
// TestStmtCloseDeps runs 100 slow queries (2 batches of 50) through one
// prepared Stmt without setting any pool limits. It checks that the idle
// pool, the dependency graph, the net number of open driver connections and
// the Stmt's per-connection statement list stay bounded. It also checks that
// closing the Stmt releases its dependencies.
func TestStmtCloseDeps(t *testing.T) {
	testDatabase(t, testStmtCloseDeps)
}
func testStmtCloseDeps(t *testing.T, db *DB) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	// Fail the test if any fake connection reports an error while closing.
	defer setHookpostCloseConn(nil)
	setHookpostCloseConn(func(_ *fakeConn, err error) {
		if err != nil {
			t.Errorf("Error closing fakeConn: %v", err)
		}
	})

	populate(t, db, "magicquery")

	// NOTE: this local shadows the database/sql/driver package name for the
	// rest of the function.
	driver := db.Driver().(*fakeDriver)

	// Baseline driver counters. openDelta0 is the number of driver
	// connections alive (opened minus closed) before the queries run.
	driver.mu.Lock()
	opens0 := driver.openCount
	closes0 := driver.closeCount
	driver.mu.Unlock()
	openDelta0 := opens0 - closes0

	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
	if err != nil {
		t.Fatal(err)
	}

	// Run nbatch rounds of nquery parallel slow queries (100 in total).
	const (
		nquery      = 50
		sleepMillis = 25
		nbatch      = 2
	)
	var wg sync.WaitGroup
	for batch := 0; batch < nbatch; batch++ {
		for i := 0; i < nquery; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				var op string
				if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
					t.Error(err)
				}
			}()
		}
		// Wait for the batch of queries above to finish before starting the next round.
		wg.Wait()
	}

	// No limits were set, yet exactly 2 idle connections are expected.
	// NOTE(review): presumably the package's default MaxIdleConns — confirm.
	if g, w := db.numFreeConns(), 2; g != w {
		t.Errorf("free conns = %d; want %d", g, w)
	}

	if n := db.numDeps(); n > 4 {
		t.Errorf("number of dependencies = %d; expected <= 4", n)
		db.dumpDeps(t)
	}

	// openDelta is how many more driver connections are alive now than at
	// the baseline.
	driver.mu.Lock()
	opens := driver.openCount - opens0
	closes := driver.closeCount - closes0
	openDelta := (driver.openCount - driver.closeCount) - openDelta0
	driver.mu.Unlock()

	if openDelta > 2 {
		t.Logf("open calls = %d", opens)
		t.Logf("close calls = %d", closes)
		t.Logf("open delta = %d", openDelta)
		t.Errorf("db connections opened = %d; want <= 2", openDelta)
		db.dumpDeps(t)
	}

	// Let goroutines in the synctest bubble settle before inspecting
	// stmt.css. NOTE(review): presumably so connection cleanup has finished
	// pruning closed connections — confirm.
	synctest.Wait()
	if len(stmt.css) > nquery {
		t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
	}

	if err := stmt.Close(); err != nil {
		t.Fatal(err)
	}

	if g, w := db.numFreeConns(), 2; g != w {
		t.Errorf("free conns = %d; want %d", g, w)
	}

	if n := db.numDeps(); n > 2 {
		t.Errorf("number of dependencies = %d; expected <= 2", n)
		db.dumpDeps(t)
	}

	db.clearAllConns(t)
}
  2933  
  2934  // golang.org/issue/5046
  2935  func TestCloseConnBeforeStmts(t *testing.T) {
  2936  	testDatabase(t, testCloseConnBeforeStmts)
  2937  }
  2938  func testCloseConnBeforeStmts(t *testing.T, db *DB) {
  2939  	populate(t, db, "people")
  2940  
  2941  	defer setHookpostCloseConn(nil)
  2942  	setHookpostCloseConn(func(_ *fakeConn, err error) {
  2943  		if err != nil {
  2944  			t.Errorf("Error closing fakeConn: %v; from %s", err, stack())
  2945  			db.dumpDeps(t)
  2946  			t.Errorf("DB = %#v", db)
  2947  		}
  2948  	})
  2949  
  2950  	stmt, err := db.Prepare("SELECT|people|name|")
  2951  	if err != nil {
  2952  		t.Fatal(err)
  2953  	}
  2954  
  2955  	if len(db.freeConn) != 1 {
  2956  		t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
  2957  	}
  2958  	dc := db.freeConn[0]
  2959  	if dc.closed {
  2960  		t.Errorf("conn shouldn't be closed")
  2961  	}
  2962  
  2963  	if n := len(dc.openStmt); n != 1 {
  2964  		t.Errorf("driverConn num openStmt = %d; want 1", n)
  2965  	}
  2966  	err = db.Close()
  2967  	if err != nil {
  2968  		t.Errorf("db Close = %v", err)
  2969  	}
  2970  	if !dc.closed {
  2971  		t.Errorf("after db.Close, driverConn should be closed")
  2972  	}
  2973  	if n := len(dc.openStmt); n != 0 {
  2974  		t.Errorf("driverConn num openStmt = %d; want 0", n)
  2975  	}
  2976  
  2977  	err = stmt.Close()
  2978  	if err != nil {
  2979  		t.Errorf("Stmt close = %v", err)
  2980  	}
  2981  
  2982  	if !dc.closed {
  2983  		t.Errorf("conn should be closed")
  2984  	}
  2985  	if dc.ci != nil {
  2986  		t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
  2987  	}
  2988  }
  2989  
  2990  // golang.org/issue/5283: don't release the Rows' connection in Close
  2991  // before calling Stmt.Close.
  2992  func TestRowsCloseOrder(t *testing.T) {
  2993  	testDatabase(t, testRowsCloseOrder)
  2994  }
  2995  func testRowsCloseOrder(t *testing.T, db *DB) {
  2996  	populate(t, db, "people")
  2997  
  2998  	db.SetMaxIdleConns(0)
  2999  	setStrictFakeConnClose(t)
  3000  	defer setStrictFakeConnClose(nil)
  3001  
  3002  	rows, err := db.Query("SELECT|people|age,name|")
  3003  	if err != nil {
  3004  		t.Fatal(err)
  3005  	}
  3006  	err = rows.Close()
  3007  	if err != nil {
  3008  		t.Fatal(err)
  3009  	}
  3010  }
  3011  
  3012  func TestRowsImplicitClose(t *testing.T) {
  3013  	testDatabase(t, testRowsImplicitClose)
  3014  }
  3015  func testRowsImplicitClose(t *testing.T, db *DB) {
  3016  	populate(t, db, "people")
  3017  
  3018  	rows, err := db.Query("SELECT|people|age,name|")
  3019  	if err != nil {
  3020  		t.Fatal(err)
  3021  	}
  3022  
  3023  	want, fail := 2, errors.New("fail")
  3024  	r := getRowsCursor(rows)
  3025  	r.errPos, r.err = want, fail
  3026  
  3027  	got := 0
  3028  	for rows.Next() {
  3029  		got++
  3030  	}
  3031  	if got != want {
  3032  		t.Errorf("got %d rows, want %d", got, want)
  3033  	}
  3034  	if err := rows.Err(); err != fail {
  3035  		t.Errorf("got error %v, want %v", err, fail)
  3036  	}
  3037  	if !r.closed {
  3038  		t.Errorf("r.closed is false, want true")
  3039  	}
  3040  }
  3041  
  3042  func TestRowsCloseError(t *testing.T) {
  3043  	testDatabase(t, testRowsCloseError)
  3044  }
  3045  func testRowsCloseError(t *testing.T, db *DB) {
  3046  	populate(t, db, "people")
  3047  	rows, err := db.Query("SELECT|people|age,name|")
  3048  	if err != nil {
  3049  		t.Fatalf("Query: %v", err)
  3050  	}
  3051  	type row struct {
  3052  		age  int
  3053  		name string
  3054  	}
  3055  	got := []row{}
  3056  
  3057  	rc := getRowsCursor(rows)
  3058  	rc.closeErr = errors.New("rowsCursor: failed to close")
  3059  
  3060  	for rows.Next() {
  3061  		var r row
  3062  		err = rows.Scan(&r.age, &r.name)
  3063  		if err != nil {
  3064  			t.Fatalf("Scan: %v", err)
  3065  		}
  3066  		got = append(got, r)
  3067  	}
  3068  	err = rows.Err()
  3069  	if err != rc.closeErr {
  3070  		t.Fatalf("unexpected err: got %v, want %v", err, rc.closeErr)
  3071  	}
  3072  }
  3073  
  3074  func TestStmtCloseOrder(t *testing.T) {
  3075  	testDatabase(t, testStmtCloseOrder)
  3076  }
  3077  func testStmtCloseOrder(t *testing.T, db *DB) {
  3078  	populate(t, db, "people")
  3079  
  3080  	db.SetMaxIdleConns(0)
  3081  	setStrictFakeConnClose(t)
  3082  	defer setStrictFakeConnClose(nil)
  3083  
  3084  	_, err := db.Query("SELECT|non_existent|name|")
  3085  	if err == nil {
  3086  		t.Fatal("Querying non-existent table should fail")
  3087  	}
  3088  }
  3089  
// TestManyErrBadConn tests cases where there are more than maxBadConnRetries
// bad connections in the pool (issue 8834). For each DB entry point below it
// builds a fresh DB whose pooled connections are all marked bad and checks
// that the operation still succeeds.
func TestManyErrBadConn(t *testing.T) {
	synctest.Test(t, testManyErrBadConn)
}

// testManyErrBadConn is the body of TestManyErrBadConn; it runs inside a
// synctest bubble.
func testManyErrBadConn(t *testing.T) {
	// manyErrBadConnSetup returns a new "people" DB whose pool holds exactly
	// maxBadConnRetries+1 free connections, each with stickyBad set on its
	// fake driver conn (presumably making every later use of that conn fail
	// with driver.ErrBadConn — confirm in fakeConn). The optional first funcs
	// run before the pool is filled; the Stmt cases use this to prepare a
	// statement that is later executed against the bad connections. Each
	// caller below defers closeDB on the returned DB.
	manyErrBadConnSetup := func(first ...func(db *DB)) *DB {
		db := newTestDB(t, "people")

		for _, f := range first {
			f(db)
		}

		nconn := maxBadConnRetries + 1
		db.SetMaxIdleConns(nconn)
		db.SetMaxOpenConns(nconn)
		// open enough connections: holding nconn Rows open at once forces
		// nconn distinct connections, and the deferred Closes hand all of
		// them back to the free pool when this closure returns.
		func() {
			for i := 0; i < nconn; i++ {
				rows, err := db.Query("SELECT|people|age,name|")
				if err != nil {
					t.Fatal(err)
				}
				defer rows.Close()
			}
		}()

		// Verify the pool and mark every free connection bad while
		// holding db.mu.
		db.mu.Lock()
		defer db.mu.Unlock()
		if db.numOpen != nconn {
			t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
		} else if len(db.freeConn) != nconn {
			t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
		}
		for _, conn := range db.freeConn {
			conn.Lock()
			getFakeConn(conn.ci).stickyBad = true
			conn.Unlock()
		}
		return db
	}

	// Query
	db := manyErrBadConnSetup()
	defer closeDB(t, db)
	rows, err := db.Query("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}

	// Exec
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	_, err = db.Exec("INSERT|people|name=Julia,age=19")
	if err != nil {
		t.Fatal(err)
	}

	// Begin
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	if err = tx.Rollback(); err != nil {
		t.Fatal(err)
	}

	// Prepare
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	stmt, err := db.Prepare("SELECT|people|age,name|")
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Exec
	// The statement is prepared in the setup hook, before the pool's
	// connections are marked bad.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("INSERT|people|name=Julia,age=19")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	_, err = stmt.Exec()
	if err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Stmt.Query
	// As above, the statement is prepared before the connections go bad.
	db = manyErrBadConnSetup(func(db *DB) {
		stmt, err = db.Prepare("SELECT|people|age,name|")
		if err != nil {
			t.Fatal(err)
		}
	})
	defer closeDB(t, db)
	rows, err = stmt.Query()
	if err != nil {
		t.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		t.Fatal(err)
	}
	if err = stmt.Close(); err != nil {
		t.Fatal(err)
	}

	// Conn
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	ctx := t.Context()
	conn, err := db.Conn(ctx)
	if err != nil {
		t.Fatal(err)
	}
	// skipDirtySession is set on the fake conn before Close (presumably so
	// the fake driver does not treat the session as dirty — confirm in
	// fakeConn).
	getFakeConn(conn.dc.ci).skipDirtySession = true
	err = conn.Close()
	if err != nil {
		t.Fatal(err)
	}

	// Ping
	db = manyErrBadConnSetup()
	defer closeDB(t, db)
	err = db.PingContext(ctx)
	if err != nil {
		t.Fatal(err)
	}
}
  3230  
// TestTxCannotCommitAfterRollback ensures that a Tx cannot commit after a
// rollback (Issue 34775). The fake "tx_status" table reports "autocommit"
// outside a transaction and "transaction" inside one. A query that is issued
// while a context-triggered rollback is in progress must still observe
// "transaction".
func TestTxCannotCommitAfterRollback(t *testing.T) {
	testDatabase(t, testTxCannotCommitAfterRollback)
}
func testTxCannotCommitAfterRollback(t *testing.T, db *DB) {
	populate(t, db, "tx_status")

	// First check query reporting is correct.
	var txStatus string
	err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "autocommit"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	// Ignore dirty session for this test.
	// A failing test should trigger the dirty session flag as well,
	// but that isn't exactly what this should test for.
	tx.txi.(*fakeTx).c.skipDirtySession = true

	defer tx.Rollback()

	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}

	// 1. Begin a transaction.
	// 2. (A) Start a query, (B) begin Tx rollback through a ctx cancel.
	// 3. Check if 2.A has committed in Tx (pass) or outside of Tx (fail).
	//
	// Hook sequence for the final query:
	//   - hookTxGrabConn (presumably invoked when the query grabs the Tx's
	//     connection) cancels ctx, starting the context-driven rollback, and
	//     then blocks until sendQuery is closed.
	//   - rollbackHook (presumably invoked from the rollback path) closes
	//     sendQuery, which releases the blocked query.
	// All three package-level hooks are reset when the test returns.
	sendQuery := make(chan struct{})
	// The Tx status is returned through the row results, ensure
	// that the rows results are not canceled.
	bypassRowsAwaitDone = true
	hookTxGrabConn = func() {
		cancel()
		<-sendQuery
	}
	rollbackHook = func() {
		close(sendQuery)
	}
	defer func() {
		hookTxGrabConn = nil
		rollbackHook = nil
		bypassRowsAwaitDone = false
	}()

	err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
	if err != nil {
		// A failure here would be expected if skipDirtySession was not set to true above.
		t.Fatal(err)
	}
	if g, w := txStatus, "transaction"; g != w {
		t.Fatalf("tx_status=%q, wanted %q", g, w)
	}
}
  3300  
  3301  // Issue 40985 transaction statement deadlock while context cancel.
  3302  func TestTxStmtDeadlock(t *testing.T) {
  3303  	testDatabase(t, testTxStmtDeadlock)
  3304  }
  3305  func testTxStmtDeadlock(t *testing.T, db *DB) {
  3306  	populate(t, db, "people")
  3307  
  3308  	ctx, cancel := context.WithCancel(context.Background())
  3309  	defer cancel()
  3310  	tx, err := db.BeginTx(ctx, nil)
  3311  	if err != nil {
  3312  		t.Fatal(err)
  3313  	}
  3314  
  3315  	stmt, err := tx.Prepare("SELECT|people|name,age|age=?")
  3316  	if err != nil {
  3317  		t.Fatal(err)
  3318  	}
  3319  	cancel()
  3320  	// Run number of stmt queries to reproduce deadlock from context cancel
  3321  	for i := 0; i < 1e3; i++ {
  3322  		// Encounter any close related errors (e.g. ErrTxDone, stmt is closed)
  3323  		// is expected due to context cancel.
  3324  		_, err = stmt.Query(1)
  3325  		if err != nil {
  3326  			break
  3327  		}
  3328  	}
  3329  	_ = tx.Rollback()
  3330  }
  3331  
  3332  func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
  3333  	t.Run(name, func(t *testing.T) {
  3334  		synctest.Test(t, f)
  3335  	})
  3336  }
  3337  
  3338  // Issue32530 encounters an issue where a connection may
  3339  // expire right after it comes out of a used connection pool
  3340  // even when a new connection is requested.
  3341  func TestConnExpiresFreshOutOfPool(t *testing.T) {
  3342  	execCases := []struct {
  3343  		expired  bool
  3344  		badReset bool
  3345  	}{
  3346  		{false, false},
  3347  		{true, false},
  3348  		{false, true},
  3349  	}
  3350  
  3351  	for _, ec := range execCases {
  3352  		name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset)
  3353  		synctestSubtest(t, name, func(t *testing.T) {
  3354  			ctx := t.Context()
  3355  
  3356  			db := newTestDB(t, "magicquery")
  3357  
  3358  			db.SetMaxOpenConns(1)
  3359  
  3360  			db.clearAllConns(t)
  3361  
  3362  			db.SetMaxIdleConns(1)
  3363  			db.SetConnMaxLifetime(10 * time.Second)
  3364  
  3365  			conn, err := db.conn(ctx, alwaysNewConn)
  3366  			if err != nil {
  3367  				t.Fatal(err)
  3368  			}
  3369  
  3370  			afterPutConn := make(chan struct{})
  3371  
  3372  			go func() {
  3373  				defer close(afterPutConn)
  3374  
  3375  				conn, err := db.conn(ctx, alwaysNewConn)
  3376  				if err == nil {
  3377  					db.putConn(conn, err, false)
  3378  				} else {
  3379  					t.Errorf("db.conn: %v", err)
  3380  				}
  3381  			}()
  3382  			synctest.Wait()
  3383  
  3384  			if t.Failed() {
  3385  				return
  3386  			}
  3387  
  3388  			synctest.Sleep(11 * time.Second)
  3389  
  3390  			getFakeConn(conn.ci).stickyBad = ec.badReset
  3391  
  3392  			db.putConn(conn, err, true)
  3393  
  3394  			<-afterPutConn
  3395  		})
  3396  	}
  3397  }
  3398  
  3399  // TestIssue20575 ensures the Rows from query does not block
  3400  // closing a transaction. Ensure Rows is closed while closing a transaction.
  3401  func TestIssue20575(t *testing.T) {
  3402  	testDatabase(t, testIssue20575)
  3403  }
  3404  func testIssue20575(t *testing.T, db *DB) {
  3405  	populate(t, db, "people")
  3406  
  3407  	tx, err := db.Begin()
  3408  	if err != nil {
  3409  		t.Fatal(err)
  3410  	}
  3411  	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
  3412  	defer cancel()
  3413  	_, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
  3414  	if err != nil {
  3415  		t.Fatal(err)
  3416  	}
  3417  	// Do not close Rows from QueryContext.
  3418  	err = tx.Rollback()
  3419  	if err != nil {
  3420  		t.Fatal(err)
  3421  	}
  3422  	select {
  3423  	default:
  3424  	case <-ctx.Done():
  3425  		t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
  3426  	}
  3427  }
  3428  
  3429  // TestIssue20622 tests closing the transaction before rows is closed, requires
  3430  // the race detector to fail.
  3431  func TestIssue20622(t *testing.T) {
  3432  	testDatabase(t, testIssue20622)
  3433  }
  3434  func testIssue20622(t *testing.T, db *DB) {
  3435  	populate(t, db, "people")
  3436  
  3437  	ctx, cancel := context.WithCancel(context.Background())
  3438  	defer cancel()
  3439  
  3440  	tx, err := db.BeginTx(ctx, nil)
  3441  	if err != nil {
  3442  		t.Fatal(err)
  3443  	}
  3444  
  3445  	rows, err := tx.Query("SELECT|people|age,name|")
  3446  	if err != nil {
  3447  		t.Fatal(err)
  3448  	}
  3449  
  3450  	count := 0
  3451  	for rows.Next() {
  3452  		count++
  3453  		var age int
  3454  		var name string
  3455  		if err := rows.Scan(&age, &name); err != nil {
  3456  			t.Fatal("scan failed", err)
  3457  		}
  3458  
  3459  		if count == 1 {
  3460  			cancel()
  3461  		}
  3462  		time.Sleep(100 * time.Millisecond)
  3463  	}
  3464  	rows.Close()
  3465  	tx.Commit()
  3466  }
  3467  
  3468  // golang.org/issue/5718
  3469  func TestErrBadConnReconnect(t *testing.T) {
  3470  	testDatabase(t, testErrBadConnReconnect)
  3471  }
  3472  func testErrBadConnReconnect(t *testing.T, db *DB) {
  3473  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  3474  
  3475  	simulateBadConn := func(name string, hook *func() bool, op func() error) {
  3476  		broken, retried := false, false
  3477  		numOpen := db.numOpen
  3478  
  3479  		// simulate a broken connection on the first try
  3480  		*hook = func() bool {
  3481  			if !broken {
  3482  				broken = true
  3483  				return true
  3484  			}
  3485  			retried = true
  3486  			return false
  3487  		}
  3488  
  3489  		if err := op(); err != nil {
  3490  			t.Errorf(name+": %v", err)
  3491  			return
  3492  		}
  3493  
  3494  		if !broken || !retried {
  3495  			t.Error(name + ": Failed to simulate broken connection")
  3496  		}
  3497  		*hook = nil
  3498  
  3499  		if numOpen != db.numOpen {
  3500  			t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
  3501  			numOpen = db.numOpen
  3502  		}
  3503  	}
  3504  
  3505  	// db.Exec
  3506  	dbExec := func() error {
  3507  		_, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
  3508  		return err
  3509  	}
  3510  	simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
  3511  	simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
  3512  
  3513  	// db.Query
  3514  	dbQuery := func() error {
  3515  		rows, err := db.Query("SELECT|t1|age,name|")
  3516  		if err == nil {
  3517  			err = rows.Close()
  3518  		}
  3519  		return err
  3520  	}
  3521  	simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
  3522  	simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
  3523  
  3524  	// db.Prepare
  3525  	simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
  3526  		stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
  3527  		if err != nil {
  3528  			return err
  3529  		}
  3530  		stmt.Close()
  3531  		return nil
  3532  	})
  3533  
  3534  	// Provide a way to force a re-prepare of a statement on next execution
  3535  	forcePrepare := func(stmt *Stmt) {
  3536  		stmt.css = nil
  3537  	}
  3538  
  3539  	// stmt.Exec
  3540  	stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
  3541  	if err != nil {
  3542  		t.Fatalf("prepare: %v", err)
  3543  	}
  3544  	defer stmt1.Close()
  3545  	// make sure we must prepare the stmt first
  3546  	forcePrepare(stmt1)
  3547  
  3548  	stmtExec := func() error {
  3549  		_, err := stmt1.Exec("Gopher", 3, false)
  3550  		return err
  3551  	}
  3552  	simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
  3553  	simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
  3554  
  3555  	// stmt.Query
  3556  	stmt2, err := db.Prepare("SELECT|t1|age,name|")
  3557  	if err != nil {
  3558  		t.Fatalf("prepare: %v", err)
  3559  	}
  3560  	defer stmt2.Close()
  3561  	// make sure we must prepare the stmt first
  3562  	forcePrepare(stmt2)
  3563  
  3564  	stmtQuery := func() error {
  3565  		rows, err := stmt2.Query()
  3566  		if err == nil {
  3567  			err = rows.Close()
  3568  		}
  3569  		return err
  3570  	}
  3571  	simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
  3572  	simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
  3573  }
  3574  
  3575  // golang.org/issue/11264
  3576  func TestTxEndBadConn(t *testing.T) {
  3577  	testDatabase(t, testTxEndBadConn)
  3578  }
  3579  func testTxEndBadConn(t *testing.T, db *DB) {
  3580  	db.SetMaxIdleConns(0)
  3581  	exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
  3582  	db.SetMaxIdleConns(1)
  3583  
  3584  	simulateBadConn := func(name string, hook *func() bool, op func() error) {
  3585  		broken := false
  3586  		numOpen := db.numOpen
  3587  
  3588  		*hook = func() bool {
  3589  			if !broken {
  3590  				broken = true
  3591  			}
  3592  			return broken
  3593  		}
  3594  
  3595  		if err := op(); !errors.Is(err, driver.ErrBadConn) {
  3596  			t.Errorf(name+": %v", err)
  3597  			return
  3598  		}
  3599  
  3600  		if !broken {
  3601  			t.Error(name + ": Failed to simulate broken connection")
  3602  		}
  3603  		*hook = nil
  3604  
  3605  		if numOpen != db.numOpen {
  3606  			t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
  3607  		}
  3608  	}
  3609  
  3610  	// db.Exec
  3611  	dbExec := func(endTx func(tx *Tx) error) func() error {
  3612  		return func() error {
  3613  			tx, err := db.Begin()
  3614  			if err != nil {
  3615  				return err
  3616  			}
  3617  			_, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
  3618  			if err != nil {
  3619  				return err
  3620  			}
  3621  			return endTx(tx)
  3622  		}
  3623  	}
  3624  	simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit))
  3625  	simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback))
  3626  
  3627  	// db.Query
  3628  	dbQuery := func(endTx func(tx *Tx) error) func() error {
  3629  		return func() error {
  3630  			tx, err := db.Begin()
  3631  			if err != nil {
  3632  				return err
  3633  			}
  3634  			rows, err := tx.Query("SELECT|t1|age,name|")
  3635  			if err == nil {
  3636  				err = rows.Close()
  3637  			} else {
  3638  				return err
  3639  			}
  3640  			return endTx(tx)
  3641  		}
  3642  	}
  3643  	simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit))
  3644  	simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback))
  3645  }
  3646  
  3647  type concurrentTest interface {
  3648  	init(t testing.TB, db *DB)
  3649  	finish(t testing.TB)
  3650  	test(t testing.TB) error
  3651  }
  3652  
  3653  type concurrentDBQueryTest struct {
  3654  	db *DB
  3655  }
  3656  
  3657  func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
  3658  	c.db = db
  3659  }
  3660  
  3661  func (c *concurrentDBQueryTest) finish(t testing.TB) {
  3662  	c.db = nil
  3663  }
  3664  
  3665  func (c *concurrentDBQueryTest) test(t testing.TB) error {
  3666  	rows, err := c.db.Query("SELECT|people|name|")
  3667  	if err != nil {
  3668  		t.Error(err)
  3669  		return err
  3670  	}
  3671  	var name string
  3672  	for rows.Next() {
  3673  		rows.Scan(&name)
  3674  	}
  3675  	rows.Close()
  3676  	return nil
  3677  }
  3678  
  3679  type concurrentDBExecTest struct {
  3680  	db *DB
  3681  }
  3682  
  3683  func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
  3684  	c.db = db
  3685  }
  3686  
  3687  func (c *concurrentDBExecTest) finish(t testing.TB) {
  3688  	c.db = nil
  3689  }
  3690  
  3691  func (c *concurrentDBExecTest) test(t testing.TB) error {
  3692  	_, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
  3693  	if err != nil {
  3694  		t.Error(err)
  3695  		return err
  3696  	}
  3697  	return nil
  3698  }
  3699  
  3700  type concurrentStmtQueryTest struct {
  3701  	db   *DB
  3702  	stmt *Stmt
  3703  }
  3704  
  3705  func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
  3706  	c.db = db
  3707  	var err error
  3708  	c.stmt, err = db.Prepare("SELECT|people|name|")
  3709  	if err != nil {
  3710  		t.Fatal(err)
  3711  	}
  3712  }
  3713  
  3714  func (c *concurrentStmtQueryTest) finish(t testing.TB) {
  3715  	if c.stmt != nil {
  3716  		c.stmt.Close()
  3717  		c.stmt = nil
  3718  	}
  3719  	c.db = nil
  3720  }
  3721  
  3722  func (c *concurrentStmtQueryTest) test(t testing.TB) error {
  3723  	rows, err := c.stmt.Query()
  3724  	if err != nil {
  3725  		t.Errorf("error on query:  %v", err)
  3726  		return err
  3727  	}
  3728  
  3729  	var name string
  3730  	for rows.Next() {
  3731  		rows.Scan(&name)
  3732  	}
  3733  	rows.Close()
  3734  	return nil
  3735  }
  3736  
  3737  type concurrentStmtExecTest struct {
  3738  	db   *DB
  3739  	stmt *Stmt
  3740  }
  3741  
  3742  func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
  3743  	c.db = db
  3744  	var err error
  3745  	c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
  3746  	if err != nil {
  3747  		t.Fatal(err)
  3748  	}
  3749  }
  3750  
  3751  func (c *concurrentStmtExecTest) finish(t testing.TB) {
  3752  	if c.stmt != nil {
  3753  		c.stmt.Close()
  3754  		c.stmt = nil
  3755  	}
  3756  	c.db = nil
  3757  }
  3758  
  3759  func (c *concurrentStmtExecTest) test(t testing.TB) error {
  3760  	_, err := c.stmt.Exec(3, chrisBirthday)
  3761  	if err != nil {
  3762  		t.Errorf("error on exec:  %v", err)
  3763  		return err
  3764  	}
  3765  	return nil
  3766  }
  3767  
  3768  type concurrentTxQueryTest struct {
  3769  	db *DB
  3770  	tx *Tx
  3771  }
  3772  
  3773  func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
  3774  	c.db = db
  3775  	var err error
  3776  	c.tx, err = c.db.Begin()
  3777  	if err != nil {
  3778  		t.Fatal(err)
  3779  	}
  3780  }
  3781  
  3782  func (c *concurrentTxQueryTest) finish(t testing.TB) {
  3783  	if c.tx != nil {
  3784  		c.tx.Rollback()
  3785  		c.tx = nil
  3786  	}
  3787  	c.db = nil
  3788  }
  3789  
  3790  func (c *concurrentTxQueryTest) test(t testing.TB) error {
  3791  	rows, err := c.db.Query("SELECT|people|name|")
  3792  	if err != nil {
  3793  		t.Error(err)
  3794  		return err
  3795  	}
  3796  	var name string
  3797  	for rows.Next() {
  3798  		rows.Scan(&name)
  3799  	}
  3800  	rows.Close()
  3801  	return nil
  3802  }
  3803  
  3804  type concurrentTxExecTest struct {
  3805  	db *DB
  3806  	tx *Tx
  3807  }
  3808  
  3809  func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
  3810  	c.db = db
  3811  	var err error
  3812  	c.tx, err = c.db.Begin()
  3813  	if err != nil {
  3814  		t.Fatal(err)
  3815  	}
  3816  }
  3817  
  3818  func (c *concurrentTxExecTest) finish(t testing.TB) {
  3819  	if c.tx != nil {
  3820  		c.tx.Rollback()
  3821  		c.tx = nil
  3822  	}
  3823  	c.db = nil
  3824  }
  3825  
  3826  func (c *concurrentTxExecTest) test(t testing.TB) error {
  3827  	_, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
  3828  	if err != nil {
  3829  		t.Error(err)
  3830  		return err
  3831  	}
  3832  	return nil
  3833  }
  3834  
  3835  type concurrentTxStmtQueryTest struct {
  3836  	db   *DB
  3837  	tx   *Tx
  3838  	stmt *Stmt
  3839  }
  3840  
  3841  func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
  3842  	c.db = db
  3843  	var err error
  3844  	c.tx, err = c.db.Begin()
  3845  	if err != nil {
  3846  		t.Fatal(err)
  3847  	}
  3848  	c.stmt, err = c.tx.Prepare("SELECT|people|name|")
  3849  	if err != nil {
  3850  		t.Fatal(err)
  3851  	}
  3852  }
  3853  
  3854  func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
  3855  	if c.stmt != nil {
  3856  		c.stmt.Close()
  3857  		c.stmt = nil
  3858  	}
  3859  	if c.tx != nil {
  3860  		c.tx.Rollback()
  3861  		c.tx = nil
  3862  	}
  3863  	c.db = nil
  3864  }
  3865  
  3866  func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
  3867  	rows, err := c.stmt.Query()
  3868  	if err != nil {
  3869  		t.Errorf("error on query:  %v", err)
  3870  		return err
  3871  	}
  3872  
  3873  	var name string
  3874  	for rows.Next() {
  3875  		rows.Scan(&name)
  3876  	}
  3877  	rows.Close()
  3878  	return nil
  3879  }
  3880  
  3881  type concurrentTxStmtExecTest struct {
  3882  	db   *DB
  3883  	tx   *Tx
  3884  	stmt *Stmt
  3885  }
  3886  
  3887  func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
  3888  	c.db = db
  3889  	var err error
  3890  	c.tx, err = c.db.Begin()
  3891  	if err != nil {
  3892  		t.Fatal(err)
  3893  	}
  3894  	c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
  3895  	if err != nil {
  3896  		t.Fatal(err)
  3897  	}
  3898  }
  3899  
  3900  func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
  3901  	if c.stmt != nil {
  3902  		c.stmt.Close()
  3903  		c.stmt = nil
  3904  	}
  3905  	if c.tx != nil {
  3906  		c.tx.Rollback()
  3907  		c.tx = nil
  3908  	}
  3909  	c.db = nil
  3910  }
  3911  
  3912  func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
  3913  	_, err := c.stmt.Exec(3, chrisBirthday)
  3914  	if err != nil {
  3915  		t.Errorf("error on exec:  %v", err)
  3916  		return err
  3917  	}
  3918  	return nil
  3919  }
  3920  
  3921  type concurrentRandomTest struct {
  3922  	tests []concurrentTest
  3923  }
  3924  
  3925  func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
  3926  	c.tests = []concurrentTest{
  3927  		new(concurrentDBQueryTest),
  3928  		new(concurrentDBExecTest),
  3929  		new(concurrentStmtQueryTest),
  3930  		new(concurrentStmtExecTest),
  3931  		new(concurrentTxQueryTest),
  3932  		new(concurrentTxExecTest),
  3933  		new(concurrentTxStmtQueryTest),
  3934  		new(concurrentTxStmtExecTest),
  3935  	}
  3936  	for _, ct := range c.tests {
  3937  		ct.init(t, db)
  3938  	}
  3939  }
  3940  
  3941  func (c *concurrentRandomTest) finish(t testing.TB) {
  3942  	for _, ct := range c.tests {
  3943  		ct.finish(t)
  3944  	}
  3945  }
  3946  
  3947  func (c *concurrentRandomTest) test(t testing.TB) error {
  3948  	ct := c.tests[rand.Intn(len(c.tests))]
  3949  	return ct.test(t)
  3950  }
  3951  
  3952  func doConcurrentTest(t testing.TB, ct concurrentTest) {
  3953  	maxProcs, numReqs := 1, 500
  3954  	if testing.Short() {
  3955  		maxProcs, numReqs = 4, 50
  3956  	}
  3957  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
  3958  
  3959  	db := newTestDB(t, "people")
  3960  	defer closeDB(t, db)
  3961  
  3962  	ct.init(t, db)
  3963  	defer ct.finish(t)
  3964  
  3965  	var wg sync.WaitGroup
  3966  	wg.Add(numReqs)
  3967  
  3968  	reqs := make(chan bool)
  3969  	defer close(reqs)
  3970  
  3971  	for i := 0; i < maxProcs*2; i++ {
  3972  		go func() {
  3973  			for range reqs {
  3974  				err := ct.test(t)
  3975  				if err != nil {
  3976  					wg.Done()
  3977  					continue
  3978  				}
  3979  				wg.Done()
  3980  			}
  3981  		}()
  3982  	}
  3983  
  3984  	for i := 0; i < numReqs; i++ {
  3985  		reqs <- true
  3986  	}
  3987  
  3988  	wg.Wait()
  3989  }
  3990  
  3991  func TestIssue6081(t *testing.T) {
  3992  	testDatabase(t, testIssue6081)
  3993  }
  3994  func testIssue6081(t *testing.T, db *DB) {
  3995  	populate(t, db, "people")
  3996  
  3997  	drv := db.Driver().(*fakeDriver)
  3998  	drv.mu.Lock()
  3999  	opens0 := drv.openCount
  4000  	closes0 := drv.closeCount
  4001  	drv.mu.Unlock()
  4002  
  4003  	stmt, err := db.Prepare("SELECT|people|name|")
  4004  	if err != nil {
  4005  		t.Fatal(err)
  4006  	}
  4007  	setRowsCloseHook(func(rows *Rows, err *error) {
  4008  		*err = driver.ErrBadConn
  4009  	})
  4010  	defer setRowsCloseHook(nil)
  4011  	for i := 0; i < 10; i++ {
  4012  		rows, err := stmt.Query()
  4013  		if err != nil {
  4014  			t.Fatal(err)
  4015  		}
  4016  		rows.Close()
  4017  	}
  4018  	if n := len(stmt.css); n > 1 {
  4019  		t.Errorf("len(css slice) = %d; want <= 1", n)
  4020  	}
  4021  	stmt.Close()
  4022  	if n := len(stmt.css); n != 0 {
  4023  		t.Errorf("len(css slice) after Close = %d; want 0", n)
  4024  	}
  4025  
  4026  	drv.mu.Lock()
  4027  	opens := drv.openCount - opens0
  4028  	closes := drv.closeCount - closes0
  4029  	drv.mu.Unlock()
  4030  	if opens < 9 {
  4031  		t.Errorf("opens = %d; want >= 9", opens)
  4032  	}
  4033  	if closes < 9 {
  4034  		t.Errorf("closes = %d; want >= 9", closes)
  4035  	}
  4036  }
  4037  
// TestIssue18429 attempts to stress rolling back the transaction from a
// context cancel while simultaneously calling Tx.Rollback. Rolling back from a
// context happens concurrently so tx.rollback and tx.Commit must guard against
// double entry.
//
// In the test, a context is canceled while the query is in process so
// the internal rollback will run concurrently with the explicitly called
// Tx.Rollback.
//
// The addition of calling rows.Next also tests
// Issue 21117.
func TestIssue18429(t *testing.T) {
	testDatabase(t, testIssue18429)
}

// testIssue18429 starts 100 goroutines, with at most 20 running at once. Each
// one begins a transaction under a context with a random 0-29ms timeout, runs a
// delayed query, drains and closes its rows, and then calls Tx.Rollback.
// Nothing is asserted directly. Failures surface as panics or race-detector
// reports.
func testIssue18429(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx := context.Background()
	// sem is a counting semaphore that bounds the number of in-flight goroutines.
	sem := make(chan bool, 20)
	var wg sync.WaitGroup

	// Exclusive upper bound, in milliseconds, for both the query delay and
	// the context timeout chosen in each goroutine.
	const milliWait = 30

	for i := 0; i < 100; i++ {
		sem <- true
		wg.Add(1)
		go func() {
			defer func() {
				<-sem
				wg.Done()
			}()
			// The query delay and the timeout are drawn independently, so the
			// timeout may fire before, during, or after the query's delay.
			qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()

			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
			defer cancel()

			// BeginTx may fail, for example if the timeout has already expired.
			// That is an acceptable outcome for this stress test.
			tx, err := db.BeginTx(ctx, nil)
			if err != nil {
				return
			}
			// This is expected to give a cancel error most, but not all the time.
			// Test failure will happen with a panic or other race condition being
			// reported.
			// NOTE(review): the "WAIT|<duration>|" prefix presumably makes the
			// fake driver sleep before running the query. Confirm in the fake driver.
			rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
			if rows != nil {
				var name string
				// Call Next to test Issue 21117 and check for races.
				for rows.Next() {
					// Scan the buffer so it is read and checked for races.
					rows.Scan(&name)
				}
				rows.Close()
			}
			// This call will race with the context cancel rollback to complete
			// if the rollback itself isn't guarded.
			tx.Rollback()
		}()
	}
	wg.Wait()
}
  4098  
  4099  // TestIssue20160 attempts to test a short context life on a stmt Query.
  4100  func TestIssue20160(t *testing.T) {
  4101  	testDatabase(t, testIssue20160)
  4102  }
  4103  func testIssue20160(t *testing.T, db *DB) {
  4104  	populate(t, db, "people")
  4105  
  4106  	ctx := context.Background()
  4107  	sem := make(chan bool, 20)
  4108  	var wg sync.WaitGroup
  4109  
  4110  	const milliWait = 30
  4111  
  4112  	stmt, err := db.PrepareContext(ctx, "SELECT|people|name|")
  4113  	if err != nil {
  4114  		t.Fatal(err)
  4115  	}
  4116  	defer stmt.Close()
  4117  
  4118  	for i := 0; i < 100; i++ {
  4119  		sem <- true
  4120  		wg.Add(1)
  4121  		go func() {
  4122  			defer func() {
  4123  				<-sem
  4124  				wg.Done()
  4125  			}()
  4126  			ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
  4127  			defer cancel()
  4128  
  4129  			// This is expected to give a cancel error most, but not all the time.
  4130  			// Test failure will happen with a panic or other race condition being
  4131  			// reported.
  4132  			rows, _ := stmt.QueryContext(ctx)
  4133  			if rows != nil {
  4134  				rows.Close()
  4135  			}
  4136  		}()
  4137  	}
  4138  	wg.Wait()
  4139  }
  4140  
// TestIssue18719 closes the context right before use. The sql.driverConn
// will nil out the ci on close in a lock, but if another goroutine uses it
// right after, it will panic on the nil ref.
//
// See https://golang.org/cl/35550 .
func TestIssue18719(t *testing.T) {
	testDatabase(t, testIssue18719, requireFeature("BeginTx"))
}

// testIssue18719 cancels a transaction's context from inside hookTxGrabConn,
// which runs after QueryContext has grabbed the connection. It checks two
// things: the transaction is already done when the hook finishes, and the
// query itself still returns without error.
func testIssue18719(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}

	hookTxGrabConn = func() {
		cancel()
		// Wait until every other goroutine in the synctest bubble is durably
		// blocked, so the rollback triggered by the cancel has had a chance
		// to run before tx.isDone is inspected.
		synctest.Wait()
		if !tx.isDone() {
			t.Errorf("tx is not done")
		}
	}
	// Reset the package-level hook so it cannot affect other tests.
	defer func() { hookTxGrabConn = nil }()

	// This call will grab the connection and cancel the context
	// after it has done so. Code after must deal with the canceled state.
	_, err = tx.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatalf("expected error %v but got %v", nil, err)
	}

	// Rows may be ignored because it will be closed when the context is canceled.

	// Do not explicitly rollback. The rollback will happen from the
	// canceled context.

	// The hook already canceled ctx. This call is a harmless repeat.
	cancel()
}
  4183  
  4184  func TestIssue20647(t *testing.T) {
  4185  	testDatabase(t, testIssue20647)
  4186  }
  4187  func testIssue20647(t *testing.T, db *DB) {
  4188  	populate(t, db, "people")
  4189  
  4190  	ctx := t.Context()
  4191  
  4192  	conn, err := db.Conn(ctx)
  4193  	if err != nil {
  4194  		t.Fatal(err)
  4195  	}
  4196  	getFakeConn(conn.dc.ci).skipDirtySession = true
  4197  	defer conn.Close()
  4198  
  4199  	stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
  4200  	if err != nil {
  4201  		t.Fatal(err)
  4202  	}
  4203  	defer stmt.Close()
  4204  
  4205  	rows1, err := stmt.QueryContext(ctx)
  4206  	if err != nil {
  4207  		t.Fatal("rows1", err)
  4208  	}
  4209  	defer rows1.Close()
  4210  
  4211  	rows2, err := stmt.QueryContext(ctx)
  4212  	if err != nil {
  4213  		t.Fatal("rows2", err)
  4214  	}
  4215  	defer rows2.Close()
  4216  
  4217  	if rows1.dc != rows2.dc {
  4218  		t.Fatal("stmt prepared on Conn does not use same connection")
  4219  	}
  4220  }
  4221  
  4222  func TestConcurrency(t *testing.T) {
  4223  	list := []struct {
  4224  		name string
  4225  		ct   concurrentTest
  4226  	}{
  4227  		{"Query", new(concurrentDBQueryTest)},
  4228  		{"Exec", new(concurrentDBExecTest)},
  4229  		{"StmtQuery", new(concurrentStmtQueryTest)},
  4230  		{"StmtExec", new(concurrentStmtExecTest)},
  4231  		{"TxQuery", new(concurrentTxQueryTest)},
  4232  		{"TxExec", new(concurrentTxExecTest)},
  4233  		{"TxStmtQuery", new(concurrentTxStmtQueryTest)},
  4234  		{"TxStmtExec", new(concurrentTxStmtExecTest)},
  4235  		{"Random", new(concurrentRandomTest)},
  4236  	}
  4237  	for _, item := range list {
  4238  		synctestSubtest(t, item.name, func(t *testing.T) {
  4239  			doConcurrentTest(t, item.ct)
  4240  		})
  4241  	}
  4242  }
  4243  
  4244  func TestConnectionLeak(t *testing.T) {
  4245  	testDatabase(t, testConnectionLeak)
  4246  }
  4247  func testConnectionLeak(t *testing.T, db *DB) {
  4248  	populate(t, db, "people")
  4249  	// Start by opening defaultMaxIdleConns
  4250  	rows := make([]*Rows, defaultMaxIdleConns)
  4251  	// We need to SetMaxOpenConns > MaxIdleConns, so the DB can open
  4252  	// a new connection and we can fill the idle queue with the released
  4253  	// connections.
  4254  	db.SetMaxOpenConns(len(rows) + 1)
  4255  	for ii := range rows {
  4256  		r, err := db.Query("SELECT|people|name|")
  4257  		if err != nil {
  4258  			t.Fatal(err)
  4259  		}
  4260  		r.Next()
  4261  		if err := r.Err(); err != nil {
  4262  			t.Fatal(err)
  4263  		}
  4264  		rows[ii] = r
  4265  	}
  4266  	// Now we have defaultMaxIdleConns busy connections. Open
  4267  	// a new one, but wait until the busy connections are released
  4268  	// before returning control to DB.
  4269  	drv := db.Driver().(*fakeDriver)
  4270  	drv.waitCh = make(chan struct{}, 1)
  4271  	drv.waitingCh = make(chan struct{}, 1)
  4272  	var wg sync.WaitGroup
  4273  	wg.Add(1)
  4274  	go func() {
  4275  		r, err := db.Query("SELECT|people|name|")
  4276  		if err != nil {
  4277  			t.Error(err)
  4278  			return
  4279  		}
  4280  		r.Close()
  4281  		wg.Done()
  4282  	}()
  4283  	// Wait until the goroutine we've just created has started waiting.
  4284  	<-drv.waitingCh
  4285  	// Now close the busy connections. This provides a connection for
  4286  	// the blocked goroutine and then fills up the idle queue.
  4287  	for _, v := range rows {
  4288  		v.Close()
  4289  	}
  4290  	// At this point we give the new connection to DB. This connection is
  4291  	// now useless, since the idle queue is full and there are no pending
  4292  	// requests. DB should deal with this situation without leaking the
  4293  	// connection.
  4294  	drv.waitCh <- struct{}{}
  4295  	wg.Wait()
  4296  }
  4297  
  4298  func TestStatsMaxIdleClosedZero(t *testing.T) {
  4299  	testDatabase(t, testStatsMaxIdleClosedZero)
  4300  }
  4301  func testStatsMaxIdleClosedZero(t *testing.T, db *DB) {
  4302  	populate(t, db, "people")
  4303  
  4304  	db.SetMaxOpenConns(1)
  4305  	db.SetMaxIdleConns(1)
  4306  	db.SetConnMaxLifetime(0)
  4307  
  4308  	preMaxIdleClosed := db.Stats().MaxIdleClosed
  4309  
  4310  	for i := 0; i < 10; i++ {
  4311  		rows, err := db.Query("SELECT|people|name|")
  4312  		if err != nil {
  4313  			t.Fatal(err)
  4314  		}
  4315  		rows.Close()
  4316  	}
  4317  
  4318  	st := db.Stats()
  4319  	maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
  4320  	t.Logf("MaxIdleClosed: %d", maxIdleClosed)
  4321  	if maxIdleClosed != 0 {
  4322  		t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
  4323  	}
  4324  }
  4325  
  4326  func TestStatsMaxIdleClosedTen(t *testing.T) {
  4327  	testDatabase(t, testStatsMaxIdleClosedTen)
  4328  }
  4329  func testStatsMaxIdleClosedTen(t *testing.T, db *DB) {
  4330  	populate(t, db, "people")
  4331  
  4332  	db.SetMaxOpenConns(1)
  4333  	db.SetMaxIdleConns(0)
  4334  	db.SetConnMaxLifetime(0)
  4335  
  4336  	preMaxIdleClosed := db.Stats().MaxIdleClosed
  4337  
  4338  	for i := 0; i < 10; i++ {
  4339  		rows, err := db.Query("SELECT|people|name|")
  4340  		if err != nil {
  4341  			t.Fatal(err)
  4342  		}
  4343  		rows.Close()
  4344  	}
  4345  
  4346  	st := db.Stats()
  4347  	maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
  4348  	t.Logf("MaxIdleClosed: %d", maxIdleClosed)
  4349  	if maxIdleClosed != 10 {
  4350  		t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
  4351  	}
  4352  }
  4353  
  4354  // testUseConns uses count concurrent connections with 1 nanosecond apart.
  4355  func testUseConns(t *testing.T, count int, db *DB) {
  4356  	conns := make([]*Conn, count)
  4357  	ctx := context.Background()
  4358  	for i := range conns {
  4359  		time.Sleep(1 * time.Nanosecond)
  4360  		c, err := db.Conn(ctx)
  4361  		if err != nil {
  4362  			t.Error(err)
  4363  		}
  4364  		conns[i] = c
  4365  	}
  4366  
  4367  	for i := len(conns) - 1; i >= 0; i-- {
  4368  		time.Sleep(1 * time.Nanosecond)
  4369  		if err := conns[i].Close(); err != nil {
  4370  			t.Error(err)
  4371  		}
  4372  	}
  4373  }
  4374  
// TestMaxIdleTime checks how connectionCleanerRunLocked handles
// ConnMaxIdleTime, and in one case ConnMaxLifetime. Each case:
//   - busies usedConns connections,
//   - sleeps until timeOffset after the subtest started,
//   - reuses reusedConns of them,
//   - sleeps secondTimeOffset more,
//   - runs a single cleaner pass.
//
// It then checks the returned next-check delay, that freeConn is ordered by
// returnedAt, how many connections the pass chose to close, and the change in
// DBStats.MaxIdleTimeClosed.
//
// NOTE(review): the exact nanosecond arithmetic in the expectations relies on
// synctestSubtest running each case under synctest's deterministic clock, and
// on testUseConns sleeping 1ns around each Conn/Close. Confirm before changing
// either.
func TestMaxIdleTime(t *testing.T) {
	usedConns := 5
	reusedConns := 2
	// The table entries below set these fields positionally, in this order.
	list := []struct {
		wantMaxIdleTime   time.Duration
		wantMaxLifetime   time.Duration
		wantNextCheck     time.Duration
		wantIdleClosed    int64
		wantMaxIdleClosed int64
		timeOffset        time.Duration
		secondTimeOffset  time.Duration
	}{
		// Idle time only: the usedConns-reusedConns connections that were not
		// reused are closed.
		{
			time.Millisecond,
			0,
			time.Millisecond - time.Nanosecond,
			int64(usedConns - reusedConns),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			0,
		},
		{
			// Want to close some connections via max idle time and one by max lifetime.
			time.Millisecond,
			// nowFunc() - MaxLifetime should be 1 * time.Nanosecond in connectionCleanerRunLocked.
			// This guarantees that first opened connection is to be closed.
			// Thus it is timeOffset + secondTimeOffset + 3 (+2 for Close while reusing conns and +1 for Conn).
			10*time.Millisecond + 100*time.Nanosecond + 3*time.Nanosecond,
			time.Nanosecond,
			// Closed all not reused connections and extra one by max lifetime.
			int64(usedConns - reusedConns + 1),
			int64(usedConns - reusedConns),
			10 * time.Millisecond,
			// Add second offset because otherwise connections are expired via max lifetime in Close.
			100 * time.Nanosecond,
		},
		// Idle time far above the elapsed time: nothing is closed.
		{
			time.Hour,
			0,
			time.Second,
			0,
			0,
			10 * time.Millisecond,
			0},
	}
	for _, item := range list {
		synctestSubtest(t, fmt.Sprintf("%v", item.wantMaxIdleTime), func(t *testing.T) {
			startTime := time.Now()
			db := newTestDB(t, "people")
			defer closeDB(t, db)

			db.SetMaxOpenConns(usedConns)
			db.SetMaxIdleConns(usedConns)
			db.SetConnMaxIdleTime(item.wantMaxIdleTime)
			db.SetConnMaxLifetime(item.wantMaxLifetime)

			preMaxIdleClosed := db.Stats().MaxIdleTimeClosed

			// Busy usedConns.
			testUseConns(t, usedConns, db)

			time.Sleep(time.Until(startTime.Add(item.timeOffset)))

			// Reuse connections which should never be considered idle
			// and exercises the sorting for issue 39471.
			testUseConns(t, reusedConns, db)

			time.Sleep(item.secondTimeOffset)

			// The ...Locked cleaner and the freeConn inspection below are done
			// with db.mu held.
			db.mu.Lock()
			nc, closing := db.connectionCleanerRunLocked(time.Second)
			if nc != item.wantNextCheck {
				t.Errorf("got %v; want %v next check duration", nc, item.wantNextCheck)
			}

			// Validate freeConn order.
			var last time.Time
			for _, c := range db.freeConn {
				if last.After(c.returnedAt) {
					t.Error("freeConn is not ordered by returnedAt")
					break
				}
				last = c.returnedAt
			}

			// Close the connections selected by the pass only after releasing db.mu.
			db.mu.Unlock()
			for _, c := range closing {
				c.Close()
			}
			if g, w := int64(len(closing)), item.wantIdleClosed; g != w {
				t.Errorf("got: %d; want %d closed conns", g, w)
			}

			st := db.Stats()
			maxIdleClosed := st.MaxIdleTimeClosed - preMaxIdleClosed
			if g, w := maxIdleClosed, item.wantMaxIdleClosed; g != w {
				t.Errorf("got: %d; want %d max idle closed conns", g, w)
			}
		})
	}
}
  4476  
  4477  type nvcDriver struct {
  4478  	fakeDriver
  4479  	skipNamedValueCheck bool
  4480  }
  4481  
  4482  func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
  4483  	c, err := d.fakeDriver.Open(dsn)
  4484  	fc := getFakeConn(c)
  4485  	fc.db.allowAny = true
  4486  	return &nvcConn{fc, d.skipNamedValueCheck}, err
  4487  }
  4488  
  4489  type nvcConn struct {
  4490  	*fakeConn
  4491  	skipNamedValueCheck bool
  4492  }
  4493  
  4494  type decimalInt struct {
  4495  	value int
  4496  }
  4497  
  4498  type doNotInclude struct{}
  4499  
  4500  var _ driver.NamedValueChecker = &nvcConn{}
  4501  
  4502  func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
  4503  	if c.skipNamedValueCheck {
  4504  		return driver.ErrSkip
  4505  	}
  4506  	switch v := nv.Value.(type) {
  4507  	default:
  4508  		return driver.ErrSkip
  4509  	case Out:
  4510  		switch ov := v.Dest.(type) {
  4511  		default:
  4512  			return errors.New("unknown NameValueCheck OUTPUT type")
  4513  		case *string:
  4514  			*ov = "from-server"
  4515  			nv.Value = "OUT:*string"
  4516  		}
  4517  		return nil
  4518  	case decimalInt, []int64:
  4519  		return nil
  4520  	case doNotInclude:
  4521  		return driver.ErrRemoveArgument
  4522  	}
  4523  }
  4524  
  4525  func TestNamedValueChecker(t *testing.T) {
  4526  	synctest.Test(t, testNamedValueChecker)
  4527  }
  4528  func testNamedValueChecker(t *testing.T) {
  4529  	Register("NamedValueCheck", &nvcDriver{})
  4530  	db, err := Open("NamedValueCheck", "")
  4531  	if err != nil {
  4532  		t.Fatal(err)
  4533  	}
  4534  	defer db.Close()
  4535  
  4536  	ctx := t.Context()
  4537  	_, err = db.ExecContext(ctx, "WIPE")
  4538  	if err != nil {
  4539  		t.Fatal("exec wipe", err)
  4540  	}
  4541  
  4542  	_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
  4543  	if err != nil {
  4544  		t.Fatal("exec create", err)
  4545  	}
  4546  
  4547  	o1 := ""
  4548  	_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimalInt{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
  4549  	if err != nil {
  4550  		t.Fatal("exec insert", err)
  4551  	}
  4552  	var (
  4553  		str1 string
  4554  		dec1 decimalInt
  4555  		arr1 []int64
  4556  	)
  4557  	err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
  4558  	if err != nil {
  4559  		t.Fatal("select", err)
  4560  	}
  4561  
  4562  	list := []struct{ got, want any }{
  4563  		{o1, "from-server"},
  4564  		{dec1, decimalInt{123}},
  4565  		{str1, "hello"},
  4566  		{arr1, []int64{42, 128, 707}},
  4567  	}
  4568  
  4569  	for index, item := range list {
  4570  		if !reflect.DeepEqual(item.got, item.want) {
  4571  			t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
  4572  		}
  4573  	}
  4574  }
  4575  
  4576  func TestNamedValueCheckerSkip(t *testing.T) {
  4577  	synctest.Test(t, testNamedValueCheckerSkip)
  4578  }
  4579  func testNamedValueCheckerSkip(t *testing.T) {
  4580  	Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
  4581  	db, err := Open("NamedValueCheckSkip", "")
  4582  	if err != nil {
  4583  		t.Fatal(err)
  4584  	}
  4585  	defer db.Close()
  4586  
  4587  	ctx := t.Context()
  4588  	_, err = db.ExecContext(ctx, "WIPE")
  4589  	if err != nil {
  4590  		t.Fatal("exec wipe", err)
  4591  	}
  4592  
  4593  	_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
  4594  	if err != nil {
  4595  		t.Fatal("exec create", err)
  4596  	}
  4597  
  4598  	_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimalInt{123}))
  4599  	if err == nil {
  4600  		t.Fatalf("expected error with bad argument, got %v", err)
  4601  	}
  4602  }
  4603  
  4604  func TestOpenConnector(t *testing.T) {
  4605  	synctest.Test(t, testOpenConnector)
  4606  }
  4607  func testOpenConnector(t *testing.T) {
  4608  	Register("testctx", &fakeDriverCtx{})
  4609  	db, err := Open("testctx", "people")
  4610  	if err != nil {
  4611  		t.Fatal(err)
  4612  	}
  4613  	defer db.Close()
  4614  
  4615  	c, ok := db.connector.(*fakeConnector)
  4616  	if !ok {
  4617  		t.Fatal("not using *fakeConnector")
  4618  	}
  4619  
  4620  	if err := db.Close(); err != nil {
  4621  		t.Fatal(err)
  4622  	}
  4623  
  4624  	if !c.closed {
  4625  		t.Fatal("connector is not closed")
  4626  	}
  4627  }
  4628  
  4629  type ctxOnlyDriver struct {
  4630  	fakeDriver
  4631  }
  4632  
  4633  func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
  4634  	conn, err := d.fakeDriver.Open(dsn)
  4635  	if err != nil {
  4636  		return nil, err
  4637  	}
  4638  	return &ctxOnlyConn{fc: getFakeConn(conn)}, nil
  4639  }
  4640  
  4641  var (
  4642  	_ driver.Conn           = &ctxOnlyConn{}
  4643  	_ driver.QueryerContext = &ctxOnlyConn{}
  4644  	_ driver.ExecerContext  = &ctxOnlyConn{}
  4645  )
  4646  
  4647  type ctxOnlyConn struct {
  4648  	fc *fakeConn
  4649  
  4650  	queryCtxCalled bool
  4651  	execCtxCalled  bool
  4652  }
  4653  
  4654  func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
  4655  	return c.fc.Begin()
  4656  }
  4657  
  4658  func (c *ctxOnlyConn) Close() error {
  4659  	return c.fc.Close()
  4660  }
  4661  
  4662  // Prepare is still part of the Conn interface, so while it isn't used
  4663  // must be defined for compatibility.
  4664  func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) {
  4665  	panic("not used")
  4666  }
  4667  
  4668  func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
  4669  	return c.fc.PrepareContext(ctx, q)
  4670  }
  4671  
  4672  func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
  4673  	c.queryCtxCalled = true
  4674  	return c.fc.QueryContext(ctx, q, args)
  4675  }
  4676  
  4677  func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
  4678  	c.execCtxCalled = true
  4679  	return c.fc.ExecContext(ctx, q, args)
  4680  }
  4681  
  4682  // TestQueryExecContextOnly ensures drivers only need to implement QueryContext
  4683  // and ExecContext methods.
  4684  func TestQueryExecContextOnly(t *testing.T) {
  4685  	synctest.Test(t, testQueryExecContextOnly)
  4686  }
  4687  func testQueryExecContextOnly(t *testing.T) {
  4688  	// Ensure connection does not implement non-context interfaces.
  4689  	var connType driver.Conn = &ctxOnlyConn{}
  4690  	if _, ok := connType.(driver.Execer); ok {
  4691  		t.Fatalf("%T must not implement driver.Execer", connType)
  4692  	}
  4693  	if _, ok := connType.(driver.Queryer); ok {
  4694  		t.Fatalf("%T must not implement driver.Queryer", connType)
  4695  	}
  4696  
  4697  	Register("ContextOnly", &ctxOnlyDriver{})
  4698  	db, err := Open("ContextOnly", "")
  4699  	if err != nil {
  4700  		t.Fatal(err)
  4701  	}
  4702  	defer db.Close()
  4703  
  4704  	ctx := t.Context()
  4705  
  4706  	conn, err := db.Conn(ctx)
  4707  	if err != nil {
  4708  		t.Fatal("db.Conn", err)
  4709  	}
  4710  	defer conn.Close()
  4711  	coc := conn.dc.ci.(*ctxOnlyConn)
  4712  	coc.fc.skipDirtySession = true
  4713  
  4714  	_, err = conn.ExecContext(ctx, "WIPE")
  4715  	if err != nil {
  4716  		t.Fatal("exec wipe", err)
  4717  	}
  4718  
  4719  	_, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
  4720  	if err != nil {
  4721  		t.Fatal("exec create", err)
  4722  	}
  4723  	expectedValue := "value1"
  4724  	_, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
  4725  	if err != nil {
  4726  		t.Fatal("exec insert", err)
  4727  	}
  4728  	rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
  4729  	if err != nil {
  4730  		t.Fatal("query select", err)
  4731  	}
  4732  	v1 := ""
  4733  	for rows.Next() {
  4734  		err = rows.Scan(&v1)
  4735  		if err != nil {
  4736  			t.Fatal("rows scan", err)
  4737  		}
  4738  	}
  4739  	rows.Close()
  4740  
  4741  	if v1 != expectedValue {
  4742  		t.Fatalf("expected %q, got %q", expectedValue, v1)
  4743  	}
  4744  
  4745  	if !coc.execCtxCalled {
  4746  		t.Error("ExecContext not called")
  4747  	}
  4748  	if !coc.queryCtxCalled {
  4749  		t.Error("QueryContext not called")
  4750  	}
  4751  }
  4752  
  4753  type alwaysErrScanner struct{}
  4754  
  4755  var errTestScanWrap = errors.New("errTestScanWrap")
  4756  
  4757  func (alwaysErrScanner) Scan(any) error {
  4758  	return errTestScanWrap
  4759  }
  4760  
  4761  // Issue 38099: Ensure that Rows.Scan properly wraps underlying errors.
  4762  func TestRowsScanProperlyWrapsErrors(t *testing.T) {
  4763  	testDatabase(t, testRowsScanProperlyWrapsErrors)
  4764  }
  4765  func testRowsScanProperlyWrapsErrors(t *testing.T, db *DB) {
  4766  	populate(t, db, "people")
  4767  
  4768  	rows, err := db.Query("SELECT|people|age|")
  4769  	if err != nil {
  4770  		t.Fatalf("Query: %v", err)
  4771  	}
  4772  
  4773  	var res alwaysErrScanner
  4774  
  4775  	for rows.Next() {
  4776  		err = rows.Scan(&res)
  4777  		if err == nil {
  4778  			t.Fatal("expecting back an error")
  4779  		}
  4780  		if !errors.Is(err, errTestScanWrap) {
  4781  			t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errTestScanWrap)
  4782  		}
  4783  		// Ensure that error substring matching still correctly works.
  4784  		if !strings.Contains(err.Error(), errTestScanWrap.Error()) {
  4785  			t.Fatalf("Error %v does not contain %v", err, errTestScanWrap)
  4786  		}
  4787  	}
  4788  }
  4789  
  4790  type alwaysErrValuer struct{}
  4791  
  4792  // errEmpty is returned when an empty value is found
  4793  var errEmpty = errors.New("empty value")
  4794  
  4795  func (v alwaysErrValuer) Value() (driver.Value, error) {
  4796  	return nil, errEmpty
  4797  }
  4798  
  4799  // Issue 64707: Ensure that Stmt.Exec and Stmt.Query properly wraps underlying errors.
  4800  func TestDriverArgsWrapsErrors(t *testing.T) {
  4801  
  4802  	synctestSubtest(t, "exec", func(t *testing.T) {
  4803  		db := newTestDB(t, "people")
  4804  		defer closeDB(t, db)
  4805  		_, err := db.Exec("INSERT|keys|dec1=?", alwaysErrValuer{})
  4806  		if err == nil {
  4807  			t.Fatal("expecting back an error")
  4808  		}
  4809  		if !errors.Is(err, errEmpty) {
  4810  			t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errEmpty)
  4811  		}
  4812  		// Ensure that error substring matching still correctly works.
  4813  		if !strings.Contains(err.Error(), errEmpty.Error()) {
  4814  			t.Fatalf("Error %v does not contain %v", err, errEmpty)
  4815  		}
  4816  	})
  4817  
  4818  	synctestSubtest(t, "query", func(t *testing.T) {
  4819  		db := newTestDB(t, "people")
  4820  		defer closeDB(t, db)
  4821  		_, err := db.Query("INSERT|keys|dec1=?", alwaysErrValuer{})
  4822  		if err == nil {
  4823  			t.Fatal("expecting back an error")
  4824  		}
  4825  		if !errors.Is(err, errEmpty) {
  4826  			t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errEmpty)
  4827  		}
  4828  		// Ensure that error substring matching still correctly works.
  4829  		if !strings.Contains(err.Error(), errEmpty.Error()) {
  4830  			t.Fatalf("Error %v does not contain %v", err, errEmpty)
  4831  		}
  4832  	})
  4833  }
  4834  
// TestContextCancelDuringRawBytesScan runs testContextCancelDuringRawBytesScan
// in each cancellation mode:
//   - "nocancel": the context is never canceled explicitly.
//   - "top": cancel after the third Next succeeds but before its Scan.
//   - "bottom": cancel after scanning the second row, before the next Next.
//   - "go": cancel from a new goroutine after scanning the second row.
func TestContextCancelDuringRawBytesScan(t *testing.T) {
	for _, mode := range []string{"nocancel", "top", "bottom", "go"} {
		synctestSubtest(t, mode, func(t *testing.T) {
			testContextCancelDuringRawBytesScan(t, mode)
		})
	}
}

// From go.dev/issue/60304
//
// testContextCancelDuringRawBytesScan scans people.name into RawBytes while
// the query context is canceled at the point selected by mode. It checks
// whether r.closemuScanHold is set after each Scan. It also reads the RawBytes
// contents, so any concurrent close that touches that memory would be visible
// (e.g. to the race detector).
func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
	db := newTestDB(t, "people")
	defer closeDB(t, db)

	// cancel used to call close asynchronously.
	// This test checks that it waits so as not to interfere with RawBytes.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r, err := db.QueryContext(ctx, "SELECT|people|name|")
	if err != nil {
		t.Fatal(err)
	}
	numRows := 0
	// sink accumulates bytes read from each RawBytes so the raw memory is
	// actually read.
	var sink byte
	for r.Next() {
		if mode == "top" && numRows == 2 {
			// cancel between Next and Scan is observed by Scan as err = context.Canceled.
			cancel()
			synctest.Wait()
		}
		numRows++
		var s RawBytes
		err = r.Scan(&s)
		// A Scan that observed the cancellation must not leave closemu held.
		if numRows == 3 && err == context.Canceled {
			if r.closemuScanHold {
				t.Errorf("expected closemu NOT to be held")
			}
			break
		}
		// A successful RawBytes Scan is expected to hold closemu until the
		// next Next/Close.
		if !r.closemuScanHold {
			t.Errorf("expected closemu to be held")
		}
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("read %q", s)
		if mode == "bottom" && numRows == 2 {
			// cancel before Next should be observed by Next, exiting the loop.
			cancel()
			synctest.Wait()
		}
		if mode == "go" && numRows == 2 {
			// cancel at any future time, to catch other cases
			go cancel()
		}
		for _, b := range s { // some operation reading from the raw memory
			sink += b
		}
		// r.Columns must not deadlock acquiring closemu.RLock.
		_, _ = r.Columns()
	}
	if r.closemuScanHold {
		t.Errorf("closemu held; should not be")
	}

	// There are 3 rows. We canceled after reading 2 so we expect either
	// 2 or 3 depending on how the awaitDone goroutine schedules.
	switch numRows {
	case 0, 1:
		t.Errorf("got %d rows; want 2+", numRows)
	case 2:
		if err := r.Err(); err != context.Canceled {
			t.Errorf("unexpected error: %v (%T)", err, err)
		}
	default:
		// Made it to the end. This is rare, but fine. Permit it.
	}

	if err := r.Close(); err != nil {
		t.Fatal(err)
	}
}
  4917  
  4918  func TestContextCancelBetweenNextAndErr(t *testing.T) {
  4919  	testDatabase(t, testContextCancelBetweenNextAndErr)
  4920  }
  4921  func testContextCancelBetweenNextAndErr(t *testing.T, db *DB) {
  4922  	populate(t, db, "people")
  4923  	ctx, cancel := context.WithCancel(context.Background())
  4924  	defer cancel()
  4925  
  4926  	r, err := db.QueryContext(ctx, "SELECT|people|name|")
  4927  	if err != nil {
  4928  		t.Fatal(err)
  4929  	}
  4930  	for r.Next() {
  4931  	}
  4932  	cancel() // wake up the awaitDone goroutine
  4933  	synctest.Wait()
  4934  	if err := r.Err(); err != nil {
  4935  		t.Fatal(err)
  4936  	}
  4937  }
  4938  
  4939  type testScanner struct {
  4940  	scanf func(src any) error
  4941  }
  4942  
  4943  func (ts testScanner) Scan(src any) error { return ts.scanf(src) }
  4944  
// TestContextCancelDuringScan cancels a query's context while a Scan into a
// custom Scanner is still in progress. It checks that the source value handed
// to the Scanner is unchanged by the cancellation.
func TestContextCancelDuringScan(t *testing.T) {
	testDatabase(t, testContextCancelDuringScan)
}
func testContextCancelDuringScan(t *testing.T, db *DB) {
	populate(t, db, "people")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The scanner reports its src on scanStart, then blocks until it
	// receives from scanEnd. This parks the Scan call mid-flight.
	scanStart := make(chan any)
	scanEnd := make(chan error)
	scanner := &testScanner{
		scanf: func(src any) error {
			scanStart <- src
			return <-scanEnd
		},
	}

	// Start a query, and pause it mid-scan.
	want := []byte("Alice")
	r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
	if err != nil {
		t.Fatal(err)
	}
	if !r.Next() {
		t.Fatalf("r.Next() = false, want true")
	}
	go func() {
		r.Scan(scanner)
	}()
	got := <-scanStart
	// On return, closing scanEnd makes the scanner receive nil, which lets the
	// Scan goroutine finish.
	defer close(scanEnd)
	// gotBytes is the src slice handed to the scanner, not a copy, so any
	// later mutation of that memory would show up in the check below.
	gotBytes, ok := got.([]byte)
	if !ok {
		t.Fatalf("r.Scan returned %T, want []byte", got)
	}
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
	}

	// Cancel the query.
	cancel()
	// Wait until all other goroutines in the bubble are durably blocked, so
	// any cancellation-triggered cleanup has had its chance to run.
	synctest.Wait()

	// Cancelling the query should not have changed the result.
	if !bytes.Equal(gotBytes, want) {
		t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
	}
}
  4994  
  4995  func TestNilErrorAfterClose(t *testing.T) {
  4996  	testDatabase(t, testNilErrorAfterClose)
  4997  }
  4998  func testNilErrorAfterClose(t *testing.T, db *DB) {
  4999  	populate(t, db, "people")
  5000  
  5001  	// This WithCancel is important; Rows contains an optimization to avoid
  5002  	// spawning a goroutine when the query/transaction context cannot be
  5003  	// canceled, but this test tests a bug which is caused by said goroutine.
  5004  	ctx, cancel := context.WithCancel(context.Background())
  5005  	defer cancel()
  5006  
  5007  	r, err := db.QueryContext(ctx, "SELECT|people|name|")
  5008  	if err != nil {
  5009  		t.Fatal(err)
  5010  	}
  5011  
  5012  	if err := r.Close(); err != nil {
  5013  		t.Fatal(err)
  5014  	}
  5015  
  5016  	synctest.Wait()
  5017  	if err := r.Err(); err != nil {
  5018  		t.Fatal(err)
  5019  	}
  5020  }
  5021  
  5022  // Issue #65201.
  5023  //
  5024  // If a RawBytes is reused across multiple queries,
  5025  // subsequent queries shouldn't overwrite driver-owned memory from previous queries.
  5026  func TestRawBytesReuse(t *testing.T) {
  5027  	testDatabase(t, testRawBytesReuse)
  5028  }
  5029  func testRawBytesReuse(t *testing.T, db *DB) {
  5030  	populate(t, db, "people")
  5031  
  5032  	var raw RawBytes
  5033  
  5034  	// The RawBytes in this query aliases driver-owned memory.
  5035  	rows, err := db.Query("SELECT|people|name|")
  5036  	if err != nil {
  5037  		t.Fatal(err)
  5038  	}
  5039  	rows.Next()
  5040  	rows.Scan(&raw) // now raw is pointing to driver-owned memory
  5041  	name1 := string(raw)
  5042  	rows.Close()
  5043  
  5044  	// The RawBytes in this query does not alias driver-owned memory.
  5045  	rows, err = db.Query("SELECT|people|age|")
  5046  	if err != nil {
  5047  		t.Fatal(err)
  5048  	}
  5049  	rows.Next()
  5050  	rows.Scan(&raw) // this must not write to the driver-owned memory in raw
  5051  	rows.Close()
  5052  
  5053  	// Repeat the first query. Nothing should have changed.
  5054  	rows, err = db.Query("SELECT|people|name|")
  5055  	if err != nil {
  5056  		t.Fatal(err)
  5057  	}
  5058  	rows.Next()
  5059  	rows.Scan(&raw) // raw points to driver-owned memory again
  5060  	name2 := string(raw)
  5061  	rows.Close()
  5062  	if name1 != name2 {
  5063  		t.Fatalf("Scan read name %q, want %q", name2, name1)
  5064  	}
  5065  }
  5066  
  5067  // badConn implements a bad driver.Conn, for TestBadDriver.
  5068  // The Exec method panics.
  5069  type badConn struct{}
  5070  
  5071  func (bc badConn) Prepare(query string) (driver.Stmt, error) {
  5072  	return nil, errors.New("badConn Prepare")
  5073  }
  5074  
  5075  func (bc badConn) Close() error {
  5076  	return nil
  5077  }
  5078  
  5079  func (bc badConn) Begin() (driver.Tx, error) {
  5080  	return nil, errors.New("badConn Begin")
  5081  }
  5082  
  5083  func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  5084  	panic("badConn.Exec")
  5085  }
  5086  
  5087  // badDriver is a driver.Driver that uses badConn.
  5088  type badDriver struct{}
  5089  
  5090  func (bd badDriver) Open(name string) (driver.Conn, error) {
  5091  	return badConn{}, nil
  5092  }
  5093  
  5094  // Issue 15901.
  5095  func TestBadDriver(t *testing.T) {
  5096  	synctest.Test(t, testBadDriver)
  5097  }
  5098  func testBadDriver(t *testing.T) {
  5099  	Register("bad", badDriver{})
  5100  	db, err := Open("bad", "ignored")
  5101  	if err != nil {
  5102  		t.Fatal(err)
  5103  	}
  5104  	defer func() {
  5105  		if r := recover(); r == nil {
  5106  			t.Error("expected panic")
  5107  		} else {
  5108  			if want := "badConn.Exec"; r.(string) != want {
  5109  				t.Errorf("panic was %v, expected %v", r, want)
  5110  			}
  5111  		}
  5112  	}()
  5113  	defer db.Close()
  5114  	db.Exec("ignored")
  5115  }
  5116  
  5117  type pingDriver struct {
  5118  	fails bool
  5119  }
  5120  
  5121  type pingConn struct {
  5122  	badConn
  5123  	driver *pingDriver
  5124  }
  5125  
  5126  var pingError = errors.New("Ping failed")
  5127  
  5128  func (pc pingConn) Ping(ctx context.Context) error {
  5129  	if pc.driver.fails {
  5130  		return pingError
  5131  	}
  5132  	return nil
  5133  }
  5134  
  5135  var _ driver.Pinger = pingConn{}
  5136  
  5137  func (pd *pingDriver) Open(name string) (driver.Conn, error) {
  5138  	return pingConn{driver: pd}, nil
  5139  }
  5140  
  5141  func TestPing(t *testing.T) {
  5142  	synctest.Test(t, testPing)
  5143  }
  5144  func testPing(t *testing.T) {
  5145  	driver := &pingDriver{}
  5146  	Register("ping", driver)
  5147  
  5148  	db, err := Open("ping", "ignored")
  5149  	if err != nil {
  5150  		t.Fatal(err)
  5151  	}
  5152  	defer db.Close()
  5153  
  5154  	if err := db.Ping(); err != nil {
  5155  		t.Errorf("err was %#v, expected nil", err)
  5156  		return
  5157  	}
  5158  
  5159  	driver.fails = true
  5160  	if err := db.Ping(); err != pingError {
  5161  		t.Errorf("err was %#v, expected pingError", err)
  5162  	}
  5163  }
  5164  
  5165  // Issue 18101.
  5166  func TestTypedString(t *testing.T) {
  5167  	testDatabase(t, testTypedString)
  5168  }
  5169  func testTypedString(t *testing.T, db *DB) {
  5170  	populate(t, db, "people")
  5171  
  5172  	type Str string
  5173  	var scanned Str
  5174  
  5175  	err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned)
  5176  	if err != nil {
  5177  		t.Fatal(err)
  5178  	}
  5179  	expected := Str("Alice")
  5180  	if scanned != expected {
  5181  		t.Errorf("expected %+v, got %+v", expected, scanned)
  5182  	}
  5183  }
  5184  
  5185  func BenchmarkConcurrentDBExec(b *testing.B) {
  5186  	b.ReportAllocs()
  5187  	ct := new(concurrentDBExecTest)
  5188  	for i := 0; i < b.N; i++ {
  5189  		doConcurrentTest(b, ct)
  5190  	}
  5191  }
  5192  
  5193  func BenchmarkConcurrentStmtQuery(b *testing.B) {
  5194  	b.ReportAllocs()
  5195  	ct := new(concurrentStmtQueryTest)
  5196  	for i := 0; i < b.N; i++ {
  5197  		doConcurrentTest(b, ct)
  5198  	}
  5199  }
  5200  
  5201  func BenchmarkConcurrentStmtExec(b *testing.B) {
  5202  	b.ReportAllocs()
  5203  	ct := new(concurrentStmtExecTest)
  5204  	for i := 0; i < b.N; i++ {
  5205  		doConcurrentTest(b, ct)
  5206  	}
  5207  }
  5208  
  5209  func BenchmarkConcurrentTxQuery(b *testing.B) {
  5210  	b.ReportAllocs()
  5211  	ct := new(concurrentTxQueryTest)
  5212  	for i := 0; i < b.N; i++ {
  5213  		doConcurrentTest(b, ct)
  5214  	}
  5215  }
  5216  
  5217  func BenchmarkConcurrentTxExec(b *testing.B) {
  5218  	b.ReportAllocs()
  5219  	ct := new(concurrentTxExecTest)
  5220  	for i := 0; i < b.N; i++ {
  5221  		doConcurrentTest(b, ct)
  5222  	}
  5223  }
  5224  
  5225  func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
  5226  	b.ReportAllocs()
  5227  	ct := new(concurrentTxStmtQueryTest)
  5228  	for i := 0; i < b.N; i++ {
  5229  		doConcurrentTest(b, ct)
  5230  	}
  5231  }
  5232  
  5233  func BenchmarkConcurrentTxStmtExec(b *testing.B) {
  5234  	b.ReportAllocs()
  5235  	ct := new(concurrentTxStmtExecTest)
  5236  	for i := 0; i < b.N; i++ {
  5237  		doConcurrentTest(b, ct)
  5238  	}
  5239  }
  5240  
  5241  func BenchmarkConcurrentRandom(b *testing.B) {
  5242  	b.ReportAllocs()
  5243  	ct := new(concurrentRandomTest)
  5244  	for i := 0; i < b.N; i++ {
  5245  		doConcurrentTest(b, ct)
  5246  	}
  5247  }
  5248  
  5249  func BenchmarkManyConcurrentQueries(b *testing.B) {
  5250  	b.ReportAllocs()
  5251  	// To see lock contention in Go 1.4, 16~ cores and 128~ goroutines are required.
  5252  	const parallelism = 16
  5253  
  5254  	db := newTestDB(b, "magicquery")
  5255  	defer closeDB(b, db)
  5256  	db.SetMaxIdleConns(runtime.GOMAXPROCS(0) * parallelism)
  5257  
  5258  	stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
  5259  	if err != nil {
  5260  		b.Fatal(err)
  5261  	}
  5262  	defer stmt.Close()
  5263  
  5264  	b.SetParallelism(parallelism)
  5265  	b.RunParallel(func(pb *testing.PB) {
  5266  		for pb.Next() {
  5267  			rows, err := stmt.Query("sleep", 1)
  5268  			if err != nil {
  5269  				b.Error(err)
  5270  				return
  5271  			}
  5272  			rows.Close()
  5273  		}
  5274  	})
  5275  }
  5276  
  5277  func TestGrabConnAllocs(t *testing.T) {
  5278  	synctest.Test(t, testGrabConnAllocs)
  5279  }
  5280  func testGrabConnAllocs(t *testing.T) {
  5281  	testenv.SkipIfOptimizationOff(t)
  5282  	if race.Enabled {
  5283  		t.Skip("skipping allocation test when using race detector")
  5284  	}
  5285  	c := new(Conn)
  5286  	ctx := context.Background()
  5287  	n := int(testing.AllocsPerRun(1000, func() {
  5288  		_, release, err := c.grabConn(ctx)
  5289  		if err != nil {
  5290  			t.Fatal(err)
  5291  		}
  5292  		release(nil)
  5293  	}))
  5294  	if n > 0 {
  5295  		t.Fatalf("Conn.grabConn allocated %v objects; want 0", n)
  5296  	}
  5297  }
  5298  
  5299  func BenchmarkGrabConn(b *testing.B) {
  5300  	b.ReportAllocs()
  5301  	c := new(Conn)
  5302  	ctx := context.Background()
  5303  	for i := 0; i < b.N; i++ {
  5304  		_, release, err := c.grabConn(ctx)
  5305  		if err != nil {
  5306  			b.Fatal(err)
  5307  		}
  5308  		release(nil)
  5309  	}
  5310  }
  5311  
  5312  func TestConnRequestSet(t *testing.T) {
  5313  	var s connRequestSet
  5314  	wantLen := func(want int) {
  5315  		t.Helper()
  5316  		if got := s.Len(); got != want {
  5317  			t.Errorf("Len = %d; want %d", got, want)
  5318  		}
  5319  		if want == 0 && !t.Failed() {
  5320  			if _, ok := s.TakeRandom(); ok {
  5321  				t.Fatalf("TakeRandom returned result when empty")
  5322  			}
  5323  		}
  5324  	}
  5325  	reset := func() { s = connRequestSet{} }
  5326  
  5327  	t.Run("add-delete", func(t *testing.T) {
  5328  		reset()
  5329  		wantLen(0)
  5330  		dh := s.Add(nil)
  5331  		wantLen(1)
  5332  		if !s.Delete(dh) {
  5333  			t.Fatal("failed to delete")
  5334  		}
  5335  		wantLen(0)
  5336  		if s.Delete(dh) {
  5337  			t.Error("delete worked twice")
  5338  		}
  5339  		wantLen(0)
  5340  	})
  5341  	t.Run("take-before-delete", func(t *testing.T) {
  5342  		reset()
  5343  		ch1 := make(chan connRequest)
  5344  		dh := s.Add(ch1)
  5345  		wantLen(1)
  5346  		if got, ok := s.TakeRandom(); !ok || got != ch1 {
  5347  			t.Fatalf("wrong take; ok=%v", ok)
  5348  		}
  5349  		wantLen(0)
  5350  		if s.Delete(dh) {
  5351  			t.Error("unexpected delete after take")
  5352  		}
  5353  	})
  5354  	t.Run("get-take-many", func(t *testing.T) {
  5355  		reset()
  5356  		m := map[chan connRequest]bool{}
  5357  		const N = 100
  5358  		var inOrder, backOut []chan connRequest
  5359  		for range N {
  5360  			c := make(chan connRequest)
  5361  			m[c] = true
  5362  			s.Add(c)
  5363  			inOrder = append(inOrder, c)
  5364  		}
  5365  		if s.Len() != N {
  5366  			t.Fatalf("Len = %v; want %v", s.Len(), N)
  5367  		}
  5368  		for s.Len() > 0 {
  5369  			c, ok := s.TakeRandom()
  5370  			if !ok {
  5371  				t.Fatal("failed to take when non-empty")
  5372  			}
  5373  			if !m[c] {
  5374  				t.Fatal("returned item not in remaining set")
  5375  			}
  5376  			delete(m, c)
  5377  			backOut = append(backOut, c)
  5378  		}
  5379  		if len(m) > 0 {
  5380  			t.Error("items remain in expected map")
  5381  		}
  5382  		if slices.Equal(inOrder, backOut) { // N! chance of flaking; N=100 is fine
  5383  			t.Error("wasn't random")
  5384  		}
  5385  	})
  5386  	t.Run("close-delete", func(t *testing.T) {
  5387  		reset()
  5388  		ch := make(chan connRequest)
  5389  		dh := s.Add(ch)
  5390  		wantLen(1)
  5391  		s.CloseAndRemoveAll()
  5392  		wantLen(0)
  5393  		if s.Delete(dh) {
  5394  			t.Error("unexpected delete after CloseAndRemoveAll")
  5395  		}
  5396  	})
  5397  }
  5398  
  5399  func BenchmarkConnRequestSet(b *testing.B) {
  5400  	var s connRequestSet
  5401  	for range b.N {
  5402  		for range 16 {
  5403  			s.Add(nil)
  5404  		}
  5405  		for range 8 {
  5406  			if _, ok := s.TakeRandom(); !ok {
  5407  				b.Fatal("want ok")
  5408  			}
  5409  		}
  5410  		for range 8 {
  5411  			s.Add(nil)
  5412  		}
  5413  		for range 16 {
  5414  			if _, ok := s.TakeRandom(); !ok {
  5415  				b.Fatal("want ok")
  5416  			}
  5417  		}
  5418  		if _, ok := s.TakeRandom(); ok {
  5419  			b.Fatal("unexpected ok")
  5420  		}
  5421  	}
  5422  }
  5423  
  5424  func TestIssue69837(t *testing.T) {
  5425  	u := Null[uint]{V: 1, Valid: true}
  5426  	val, err := driver.DefaultParameterConverter.ConvertValue(u)
  5427  	if err != nil {
  5428  		t.Errorf("ConvertValue() error = %v, want nil", err)
  5429  	}
  5430  
  5431  	if v, ok := val.(int64); !ok {
  5432  		t.Errorf("val.(type): got %T, expected int64", val)
  5433  	} else if v != 1 {
  5434  		t.Errorf("val: got %d, expected 1", v)
  5435  	}
  5436  }
  5437  
  5438  type issue69728Type struct {
  5439  	ID   int
  5440  	Name string
  5441  }
  5442  
  5443  func (t issue69728Type) Value() (driver.Value, error) {
  5444  	return []byte(fmt.Sprintf("%d, %s", t.ID, t.Name)), nil
  5445  }
  5446  
  5447  func TestIssue69728(t *testing.T) {
  5448  	forValue := Null[issue69728Type]{
  5449  		Valid: true,
  5450  		V: issue69728Type{
  5451  			ID:   42,
  5452  			Name: "foobar",
  5453  		},
  5454  	}
  5455  
  5456  	v1, err := forValue.Value()
  5457  	if err != nil {
  5458  		t.Errorf("forValue.Value() error = %v, want nil", err)
  5459  	}
  5460  
  5461  	v2, err := forValue.V.Value()
  5462  	if err != nil {
  5463  		t.Errorf("forValue.V.Value() error = %v, want nil", err)
  5464  	}
  5465  
  5466  	if !reflect.DeepEqual(v1, v2) {
  5467  		t.Errorf("not equal; v1 = %v, v2 = %v", v1, v2)
  5468  	}
  5469  }
  5470  
  5471  func TestColumnConverterWithUnknownInputCount(t *testing.T) {
  5472  	db := OpenDB(&unknownInputsConnector{})
  5473  	stmt, err := db.Prepare("SELECT ?")
  5474  	if err != nil {
  5475  		t.Fatal(err)
  5476  	}
  5477  	_, err = stmt.Exec(1)
  5478  	if err != nil {
  5479  		t.Fatal(err)
  5480  	}
  5481  }
  5482  
  5483  type unknownInputsConnector struct{}
  5484  
  5485  func (unknownInputsConnector) Connect(context.Context) (driver.Conn, error) {
  5486  	return unknownInputsConn{}, nil
  5487  }
  5488  
  5489  func (unknownInputsConnector) Driver() driver.Driver { return nil }
  5490  
  5491  type unknownInputsConn struct{}
  5492  
  5493  func (unknownInputsConn) Prepare(string) (driver.Stmt, error) { return unknownInputsStmt{}, nil }
  5494  func (unknownInputsConn) Close() error                        { return nil }
  5495  func (unknownInputsConn) Begin() (driver.Tx, error)           { return nil, nil }
  5496  
  5497  type unknownInputsStmt struct{}
  5498  
  5499  func (unknownInputsStmt) Close() error  { return nil }
  5500  func (unknownInputsStmt) NumInput() int { return -1 }
  5501  func (unknownInputsStmt) Exec(args []driver.Value) (driver.Result, error) {
  5502  	if _, ok := args[0].(string); !ok {
  5503  		return nil, fmt.Errorf("Expected string, got %T", args[0])
  5504  	}
  5505  	return nil, nil
  5506  }
  5507  func (unknownInputsStmt) Query([]driver.Value) (driver.Rows, error) { return nil, nil }
  5508  func (unknownInputsStmt) ColumnConverter(idx int) driver.ValueConverter {
  5509  	return unknownInputsValueConverter{}
  5510  }
  5511  
  5512  type unknownInputsValueConverter struct{}
  5513  
  5514  func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
  5515  	return "string", nil
  5516  }
  5517  
  5518  func TestNullTypeScanErrorConsistency(t *testing.T) {
  5519  	// Issue #45662: Null* types should have Valid=false when Scan returns an error.
  5520  	// Previously, Valid was set to true before convertAssign was called,
  5521  	// so if conversion failed, Valid would still be true despite the error.
  5522  
  5523  	tests := []struct {
  5524  		name    string
  5525  		scanner Scanner
  5526  		input   any
  5527  		wantErr bool
  5528  	}{
  5529  		{
  5530  			name:    "NullInt32 with invalid input",
  5531  			scanner: &NullInt32{},
  5532  			input:   []byte("not_a_number"),
  5533  			wantErr: true,
  5534  		},
  5535  		{
  5536  			name:    "NullInt64 with invalid input",
  5537  			scanner: &NullInt64{},
  5538  			input:   []byte("not_a_number"),
  5539  			wantErr: true,
  5540  		},
  5541  		{
  5542  			name:    "NullFloat64 with invalid input",
  5543  			scanner: &NullFloat64{},
  5544  			input:   []byte("not_a_float"),
  5545  			wantErr: true,
  5546  		},
  5547  		{
  5548  			name:    "NullBool with invalid input",
  5549  			scanner: &NullBool{},
  5550  			input:   []byte("not_a_bool"),
  5551  			wantErr: true,
  5552  		},
  5553  		// Valid cases should still work
  5554  		{
  5555  			name:    "NullInt32 with valid input",
  5556  			scanner: &NullInt32{},
  5557  			input:   int64(42),
  5558  			wantErr: false,
  5559  		},
  5560  		{
  5561  			name:    "NullInt64 with valid input",
  5562  			scanner: &NullInt64{},
  5563  			input:   int64(42),
  5564  			wantErr: false,
  5565  		},
  5566  		{
  5567  			name:    "NullFloat64 with valid input",
  5568  			scanner: &NullFloat64{},
  5569  			input:   float64(3.14),
  5570  			wantErr: false,
  5571  		},
  5572  		{
  5573  			name:    "NullBool with valid input",
  5574  			scanner: &NullBool{},
  5575  			input:   true,
  5576  			wantErr: false,
  5577  		},
  5578  		{
  5579  			name:    "NullString with valid input",
  5580  			scanner: &NullString{},
  5581  			input:   "hello",
  5582  			wantErr: false,
  5583  		},
  5584  	}
  5585  
  5586  	for _, tt := range tests {
  5587  		t.Run(tt.name, func(t *testing.T) {
  5588  			err := tt.scanner.Scan(tt.input)
  5589  
  5590  			// Check that error matches expectation
  5591  			if (err != nil) != tt.wantErr {
  5592  				t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
  5593  			}
  5594  
  5595  			// The key invariant: Valid should be the opposite of whether we got an error
  5596  			// (assuming non-nil input)
  5597  			var valid bool
  5598  			switch s := tt.scanner.(type) {
  5599  			case *NullInt32:
  5600  				valid = s.Valid
  5601  			case *NullInt64:
  5602  				valid = s.Valid
  5603  			case *NullFloat64:
  5604  				valid = s.Valid
  5605  			case *NullBool:
  5606  				valid = s.Valid
  5607  			case *NullString:
  5608  				valid = s.Valid
  5609  			case *NullTime:
  5610  				valid = s.Valid
  5611  			}
  5612  
  5613  			if err != nil && valid {
  5614  				t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil")
  5615  			}
  5616  			if err == nil && !valid {
  5617  				t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil")
  5618  			}
  5619  		})
  5620  	}
  5621  }
  5622  
  5623  // TestNullTypeScanNil verifies that scanning nil sets Valid=false without error.
  5624  func TestNullTypeScanNil(t *testing.T) {
  5625  	tests := []struct {
  5626  		name    string
  5627  		scanner Scanner
  5628  	}{
  5629  		{"NullString", &NullString{String: "preset", Valid: true}},
  5630  		{"NullInt64", &NullInt64{Int64: 42, Valid: true}},
  5631  		{"NullInt32", &NullInt32{Int32: 42, Valid: true}},
  5632  		{"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}},
  5633  		{"NullBool", &NullBool{Bool: true, Valid: true}},
  5634  		{"NullTime", &NullTime{Time: time.Now(), Valid: true}},
  5635  	}
  5636  
  5637  	for _, tt := range tests {
  5638  		t.Run(tt.name, func(t *testing.T) {
  5639  			err := tt.scanner.Scan(nil)
  5640  			if err != nil {
  5641  				t.Errorf("Scan(nil) error = %v; want nil", err)
  5642  			}
  5643  
  5644  			var valid bool
  5645  			switch s := tt.scanner.(type) {
  5646  			case *NullString:
  5647  				valid = s.Valid
  5648  			case *NullInt64:
  5649  				valid = s.Valid
  5650  			case *NullInt32:
  5651  				valid = s.Valid
  5652  			case *NullFloat64:
  5653  				valid = s.Valid
  5654  			case *NullBool:
  5655  				valid = s.Valid
  5656  			case *NullTime:
  5657  				valid = s.Valid
  5658  			}
  5659  
  5660  			if valid {
  5661  				t.Errorf("Scan(nil) left Valid=true; want Valid=false")
  5662  			}
  5663  		})
  5664  	}
  5665  }
  5666  
// testStringType is a custom scan destination used by TestQueryRowsScanner.
// rowsColumnScannerRows.ScanColumn fills its s field directly from string
// or []byte column values.
type testStringType struct {
	s string
}
  5670  
  5671  func TestQueryRowsScanner(t *testing.T) {
  5672  	testDatabase(t, testQueryRowsScanner, requireFeature("ScanColumn"))
  5673  }
  5674  func testQueryRowsScanner(t *testing.T, db *DB) {
  5675  	populate(t, db, "people")
  5676  	rows, err := db.Query("SELECT|people|age,name|")
  5677  	if err != nil {
  5678  		t.Fatalf("Query: %v", err)
  5679  	}
  5680  	defer rows.Close()
  5681  	type row struct {
  5682  		age  int
  5683  		name testStringType
  5684  	}
  5685  	got := []row{}
  5686  	for rows.Next() {
  5687  		var r row
  5688  		err = rows.Scan(&r.age, &r.name)
  5689  		if err != nil {
  5690  			t.Fatalf("Scan: %v", err)
  5691  		}
  5692  		got = append(got, r)
  5693  	}
  5694  	err = rows.Err()
  5695  	if err != nil {
  5696  		t.Fatalf("Err: %v", err)
  5697  	}
  5698  	want := []row{
  5699  		{age: 1, name: testStringType{"Alice"}},
  5700  		{age: 2, name: testStringType{"Bob"}},
  5701  		{age: 3, name: testStringType{"Chris"}},
  5702  	}
  5703  	if !slices.Equal(got, want) {
  5704  		t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
  5705  	}
  5706  }
  5707  
  5708  type rowsColumnScannerConnector struct {
  5709  	fakeConnector
  5710  }
  5711  
  5712  func (c *rowsColumnScannerConnector) Connect(ctx context.Context) (driver.Conn, error) {
  5713  	conn, err := c.fakeConnector.Connect(ctx)
  5714  	fc := getFakeConn(conn)
  5715  	return &rowsColumnScannerConn{fc}, err
  5716  }
  5717  
  5718  // rowsColumnScannerConn is a Conn with rows that implement RowsColumnScanner.
  5719  type rowsColumnScannerConn struct {
  5720  	*fakeConn
  5721  }
  5722  
  5723  func (s *rowsColumnScannerConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  5724  	stmt, err := s.fakeConn.PrepareContext(ctx, query)
  5725  	if err != nil {
  5726  		return nil, err
  5727  	}
  5728  	return &rowsColumnScannerStmt{stmt.(*fakeStmt)}, nil
  5729  }
  5730  
  5731  type rowsColumnScannerStmt struct {
  5732  	*fakeStmt
  5733  }
  5734  
  5735  func (s *rowsColumnScannerStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  5736  	rows, err := s.fakeStmt.QueryContext(ctx, args)
  5737  	if err != nil {
  5738  		return nil, err
  5739  	}
  5740  	return &rowsColumnScannerRows{rowsCursor: rows.(*rowsCursor)}, nil
  5741  }
  5742  
  5743  type rowsColumnScannerRows struct {
  5744  	*rowsCursor
  5745  	row []driver.Value
  5746  }
  5747  
  5748  func (c *rowsColumnScannerRows) NextRow() error {
  5749  	if c.row == nil {
  5750  		c.row = make([]driver.Value, len(c.rowsCursor.Columns()))
  5751  	}
  5752  	return c.rowsCursor.Next(c.row)
  5753  }
  5754  
  5755  func (c *rowsColumnScannerRows) NextResultSet() error {
  5756  	c.row = nil
  5757  	return c.rowsCursor.NextResultSet()
  5758  }
  5759  
  5760  func (c *rowsColumnScannerRows) ScanColumn(ctx driver.ScanContext, index int, dest any) error {
  5761  	if index < 0 || index >= len(c.row) {
  5762  		return fmt.Errorf("index %v out of range", index)
  5763  	}
  5764  	switch d := dest.(type) {
  5765  	case *testStringType:
  5766  		switch s := c.row[index].(type) {
  5767  		case string:
  5768  			d.s = s
  5769  			return nil
  5770  		case []byte:
  5771  			d.s = string(s)
  5772  			return nil
  5773  		}
  5774  	}
  5775  	return ConvertAssign(ctx, dest, c.row[index])
  5776  }
  5777  

View as plain text