1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sql
17
18 import (
19 "context"
20 "database/sql/driver"
21 "errors"
22 "fmt"
23 "io"
24 "maps"
25 "math/rand/v2"
26 "reflect"
27 "runtime"
28 "slices"
29 "strconv"
30 "sync"
31 "sync/atomic"
32 "time"
33 _ "unsafe"
34 )
35
36 var driversMu sync.RWMutex
37
38
39
40
41
42
43
44
45
46
47
48 var drivers = make(map[string]driver.Driver)
49
50
51
52
53 func Register(name string, driver driver.Driver) {
54 driversMu.Lock()
55 defer driversMu.Unlock()
56 if driver == nil {
57 panic("sql: Register driver is nil")
58 }
59 if _, dup := drivers[name]; dup {
60 panic("sql: Register called twice for driver " + name)
61 }
62 drivers[name] = driver
63 }
64
65 func unregisterAllDrivers() {
66 driversMu.Lock()
67 defer driversMu.Unlock()
68
69 drivers = make(map[string]driver.Driver)
70 }
71
72
73 func Drivers() []string {
74 driversMu.RLock()
75 defer driversMu.RUnlock()
76 return slices.Sorted(maps.Keys(drivers))
77 }
78
79
80
81
82
83
84
// NamedArg pairs an argument value with the name of the parameter it binds.
type NamedArg struct {
	// _NamedFieldsRequired is unexported, so code outside this package
	// cannot build a NamedArg with a positional composite literal and must
	// name the fields instead.
	_NamedFieldsRequired struct{}

	// Name is the name of the parameter placeholder.
	Name string

	// Value is the value bound to the parameter.
	Value any
}

// Named returns a NamedArg that associates value with the parameter called
// name.
func Named(name string, value any) NamedArg {
	var arg NamedArg
	arg.Name = name
	arg.Value = value
	return arg
}
121
122
// IsolationLevel is the transaction isolation level requested through
// TxOptions.Isolation.
type IsolationLevel int

// Isolation levels. The values are consecutive, starting at
// LevelDefault = 0.
const (
	LevelDefault IsolationLevel = iota
	LevelReadUncommitted
	LevelReadCommitted
	LevelWriteCommitted
	LevelRepeatableRead
	LevelSnapshot
	LevelSerializable
	LevelLinearizable
)

// String returns the human-readable name of the isolation level, or
// "IsolationLevel(N)" for a value outside the defined range.
func (i IsolationLevel) String() string {
	names := [...]string{
		LevelDefault:         "Default",
		LevelReadUncommitted: "Read Uncommitted",
		LevelReadCommitted:   "Read Committed",
		LevelWriteCommitted:  "Write Committed",
		LevelRepeatableRead:  "Repeatable Read",
		LevelSnapshot:        "Snapshot",
		LevelSerializable:    "Serializable",
		LevelLinearizable:    "Linearizable",
	}
	if i >= 0 && int(i) < len(names) {
		return names[i]
	}
	return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}

var _ fmt.Stringer = LevelDefault
165
166
// TxOptions holds the options requested when a transaction is begun
// (BeginTx accepts a *TxOptions; Begin passes nil).
type TxOptions struct {
	// Isolation is the transaction isolation level.
	// The zero value is LevelDefault.
	Isolation IsolationLevel
	// ReadOnly requests a read-only transaction.
	ReadOnly bool
}
173
174
175
176
// RawBytes is a byte slice type distinct from []byte.
// NOTE(review): it is presumably a scan destination that refers to memory
// owned by the driver and is valid only until the next scan. That handling
// lives outside this view, so confirm it before relying on it.
type RawBytes []byte
178
179
180
181
182
183
184
185
186
187
188
189
190
191 type NullString struct {
192 String string
193 Valid bool
194 }
195
196
197 func (ns *NullString) Scan(value any) error {
198 if value == nil {
199 ns.String, ns.Valid = "", false
200 return nil
201 }
202 err := convertAssign(&ns.String, value)
203 ns.Valid = err == nil
204 return err
205 }
206
207
208 func (ns NullString) Value() (driver.Value, error) {
209 if !ns.Valid {
210 return nil, nil
211 }
212 return ns.String, nil
213 }
214
215
216
217
218 type NullInt64 struct {
219 Int64 int64
220 Valid bool
221 }
222
223
224 func (n *NullInt64) Scan(value any) error {
225 if value == nil {
226 n.Int64, n.Valid = 0, false
227 return nil
228 }
229 err := convertAssign(&n.Int64, value)
230 n.Valid = err == nil
231 return err
232 }
233
234
235 func (n NullInt64) Value() (driver.Value, error) {
236 if !n.Valid {
237 return nil, nil
238 }
239 return n.Int64, nil
240 }
241
242
243
244
245 type NullInt32 struct {
246 Int32 int32
247 Valid bool
248 }
249
250
251 func (n *NullInt32) Scan(value any) error {
252 if value == nil {
253 n.Int32, n.Valid = 0, false
254 return nil
255 }
256 err := convertAssign(&n.Int32, value)
257 n.Valid = err == nil
258 return err
259 }
260
261
262 func (n NullInt32) Value() (driver.Value, error) {
263 if !n.Valid {
264 return nil, nil
265 }
266 return int64(n.Int32), nil
267 }
268
269
270
271
272 type NullInt16 struct {
273 Int16 int16
274 Valid bool
275 }
276
277
278 func (n *NullInt16) Scan(value any) error {
279 if value == nil {
280 n.Int16, n.Valid = 0, false
281 return nil
282 }
283 err := convertAssign(&n.Int16, value)
284 n.Valid = err == nil
285 return err
286 }
287
288
289 func (n NullInt16) Value() (driver.Value, error) {
290 if !n.Valid {
291 return nil, nil
292 }
293 return int64(n.Int16), nil
294 }
295
296
297
298
299 type NullByte struct {
300 Byte byte
301 Valid bool
302 }
303
304
305 func (n *NullByte) Scan(value any) error {
306 if value == nil {
307 n.Byte, n.Valid = 0, false
308 return nil
309 }
310 err := convertAssign(&n.Byte, value)
311 n.Valid = err == nil
312 return err
313 }
314
315
316 func (n NullByte) Value() (driver.Value, error) {
317 if !n.Valid {
318 return nil, nil
319 }
320 return int64(n.Byte), nil
321 }
322
323
324
325
326 type NullFloat64 struct {
327 Float64 float64
328 Valid bool
329 }
330
331
332 func (n *NullFloat64) Scan(value any) error {
333 if value == nil {
334 n.Float64, n.Valid = 0, false
335 return nil
336 }
337 err := convertAssign(&n.Float64, value)
338 n.Valid = err == nil
339 return err
340 }
341
342
343 func (n NullFloat64) Value() (driver.Value, error) {
344 if !n.Valid {
345 return nil, nil
346 }
347 return n.Float64, nil
348 }
349
350
351
352
353 type NullBool struct {
354 Bool bool
355 Valid bool
356 }
357
358
359 func (n *NullBool) Scan(value any) error {
360 if value == nil {
361 n.Bool, n.Valid = false, false
362 return nil
363 }
364 err := convertAssign(&n.Bool, value)
365 n.Valid = err == nil
366 return err
367 }
368
369
370 func (n NullBool) Value() (driver.Value, error) {
371 if !n.Valid {
372 return nil, nil
373 }
374 return n.Bool, nil
375 }
376
377
378
379
380 type NullTime struct {
381 Time time.Time
382 Valid bool
383 }
384
385
386 func (n *NullTime) Scan(value any) error {
387 if value == nil {
388 n.Time, n.Valid = time.Time{}, false
389 return nil
390 }
391 err := convertAssign(&n.Time, value)
392 n.Valid = err == nil
393 return err
394 }
395
396
397 func (n NullTime) Value() (driver.Value, error) {
398 if !n.Valid {
399 return nil, nil
400 }
401 return n.Time, nil
402 }
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418 type Null[T any] struct {
419 V T
420 Valid bool
421 }
422
423 func (n *Null[T]) Scan(value any) error {
424 if value == nil {
425 n.V, n.Valid = *new(T), false
426 return nil
427 }
428 err := convertAssign(&n.V, value)
429 n.Valid = err == nil
430 return err
431 }
432
433 func (n Null[T]) Value() (driver.Value, error) {
434 if !n.Valid {
435 return nil, nil
436 }
437 v := any(n.V)
438
439 if valuer, ok := v.(driver.Valuer); ok {
440 val, err := callValuerValue(valuer)
441 if err != nil {
442 return val, err
443 }
444 v = val
445 }
446
447 return driver.DefaultParameterConverter.ConvertValue(v)
448 }
449
450
// Scanner is implemented by types that can assign themselves from a value
// read from the database. The Null* types and Null[T] in this file implement
// it and treat a nil src as SQL NULL.
// NOTE(review): the set of dynamic types src may hold depends on the driver
// and on scanning code outside this view. Confirm before documenting it.
type Scanner interface {
	// Scan assigns src to the receiver and returns an error if it cannot.
	Scan(src any) error
}
472
473
474
475
476
477
478
479
480
// Out wraps a destination for a parameter value produced by a statement.
// NOTE(review): the code that consumes Out (argument conversion) is outside
// this view. Confirm the exact semantics against it.
type Out struct {
	// _NamedFieldsRequired is unexported, so code outside this package must
	// use keyed composite literals.
	_NamedFieldsRequired struct{}

	// Dest receives the output value.
	// NOTE(review): presumably this must be a pointer. Confirm.
	Dest any

	// In marks a parameter that is also an input.
	// NOTE(review): inferred from the field name. Confirm.
	In bool
}
493
494
495
496
// ErrNoRows is the sentinel error reported when a query produced no rows.
// Callers should compare against it with errors.Is.
// NOTE(review): the APIs that return it are outside this view.
var ErrNoRows = errors.New("sql: no rows in result set")
498
499
500
501
502
503
504
505
506
507
508
509
510
// DB is a database handle that represents a pool of zero or more underlying
// driver connections. Connections are opened on demand (by conn and by the
// connectionOpener goroutine), and idle ones are kept in freeConn for reuse.
//
// Fields from mu onward are guarded by mu. waitDuration and numClosed are
// atomics and are accessed without mu (see Stats, prepareDC and finalClose).
type DB struct {
	// waitDuration is the total time (nanoseconds) conn callers spent
	// waiting for a connection.
	waitDuration atomic.Int64

	// connector opens new driver connections.
	connector driver.Connector
	// numClosed counts connections fully closed by finalClose. prepareDC
	// records it in Stmt.lastNumClosed.
	numClosed atomic.Uint64

	mu           sync.Mutex     // protects the following fields
	freeConn     []*driverConn  // idle connections, appended in return order (oldest first)
	connRequests connRequestSet // conn callers waiting for a connection
	numOpen      int            // opened connections plus ones being opened (incremented optimistically)
	// openerCh carries one signal per connection that connectionOpener
	// should open. maybeOpenNewConnections sends on it.
	openerCh          chan struct{}
	closed            bool
	dep               map[finalCloser]depSet
	lastPut           map[*driverConn]string // stack of each conn's last put; only filled when debugGetPut
	maxIdleCount      int                    // zero means defaultMaxIdleConns; negative means 0
	maxOpen           int                    // <= 0 means unlimited
	maxLifetime       time.Duration          // maximum age (since createdAt) at which a connection may be reused
	maxIdleTime       time.Duration          // maximum time a connection may sit idle before the cleaner closes it
	cleanerCh         chan struct{}          // non-nil while connectionCleaner runs; used to wake it
	waitCount         int64                  // Total number of connections waited for.
	maxIdleClosed     int64                  // Total number of connections closed due to idle count.
	maxIdleTimeClosed int64                  // Total number of connections closed due to idle time.
	maxLifetimeClosed int64                  // Total number of connections closed due to max connection lifetime limit.

	stop func() // stop cancels the context connectionOpener runs with.
}
546
547
// connReuseStrategy determines how (*DB).conn obtains a connection.
type connReuseStrategy uint8

const (
	// alwaysNewConn bypasses the idle pool (freeConn). If the open-connection
	// limit is reached, conn still queues and may be handed a previously
	// used connection.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn reuses an idle connection if one is available. If none
	// is, it waits for one (when maxOpen has been reached) or opens a new
	// connection.
	cachedOrNewConn
)
558
559
560
561
562
// driverConn wraps a driver.Conn with a mutex. The embedded Mutex guards ci
// and the fields after it up to openStmt. Calls into ci are made while
// holding it (see the withLock uses, and driverStmt, whose Locker is the
// driverConn).
type driverConn struct {
	db        *DB
	createdAt time.Time // compared against DB.maxLifetime by expired

	sync.Mutex  // guards following
	ci          driver.Conn
	needReset   bool // The connection session should be reset before use if true.
	closed      bool
	finalClosed bool // ci.Close has been called
	openStmt    map[*driverStmt]bool

	// guarded by db.mu
	inUse      bool
	dbmuClosed bool      // same as closed, but guarded by db.mu
	returnedAt time.Time // Time the connection was created or returned.
	onPut      []func()  // code (with db.mu held) run when conn is next returned
}
580
581 func (dc *driverConn) releaseConn(err error) {
582 dc.db.putConn(dc, err, true)
583 }
584
585 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
586 dc.Lock()
587 defer dc.Unlock()
588 delete(dc.openStmt, ds)
589 }
590
591 func (dc *driverConn) expired(timeout time.Duration) bool {
592 if timeout <= 0 {
593 return false
594 }
595 return dc.createdAt.Add(timeout).Before(time.Now())
596 }
597
598
599
600 func (dc *driverConn) resetSession(ctx context.Context) error {
601 dc.Lock()
602 defer dc.Unlock()
603
604 if !dc.needReset {
605 return nil
606 }
607 if cr, ok := dc.ci.(driver.SessionResetter); ok {
608 return cr.ResetSession(ctx)
609 }
610 return nil
611 }
612
613
614
615 func (dc *driverConn) validateConnection(needsReset bool) bool {
616 dc.Lock()
617 defer dc.Unlock()
618
619 if needsReset {
620 dc.needReset = true
621 }
622 if cv, ok := dc.ci.(driver.Validator); ok {
623 return cv.IsValid()
624 }
625 return true
626 }
627
628
629
630 func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
631 si, err := ctxDriverPrepare(ctx, dc.ci, query)
632 if err != nil {
633 return nil, err
634 }
635 ds := &driverStmt{Locker: dc, si: si}
636
637
638 if cg != nil {
639 return ds, nil
640 }
641
642
643
644
645
646 if dc.openStmt == nil {
647 dc.openStmt = make(map[*driverStmt]bool)
648 }
649 dc.openStmt[ds] = true
650 return ds, nil
651 }
652
653
654 func (dc *driverConn) closeDBLocked() func() error {
655 dc.Lock()
656 defer dc.Unlock()
657 if dc.closed {
658 return func() error { return errors.New("sql: duplicate driverConn close") }
659 }
660 dc.closed = true
661 return dc.db.removeDepLocked(dc, dc)
662 }
663
664 func (dc *driverConn) Close() error {
665 dc.Lock()
666 if dc.closed {
667 dc.Unlock()
668 return errors.New("sql: duplicate driverConn close")
669 }
670 dc.closed = true
671 dc.Unlock()
672
673
674 dc.db.mu.Lock()
675 dc.dbmuClosed = true
676 fn := dc.db.removeDepLocked(dc, dc)
677 dc.db.mu.Unlock()
678 return fn()
679 }
680
// finalClose tears dc down once its last dependency has been removed; it is
// the function removeDepLocked returns for a driverConn. It closes every
// statement still recorded in openStmt, then the driver connection. Finally
// it decrements numOpen, gives waiting requests a chance to open a
// replacement, and bumps numClosed. It returns the driver connection's Close
// error.
func (dc *driverConn) finalClose() error {
	var err error

	// Each *driverStmt uses dc as its Locker (see prepareLocked), so copy
	// the set out inside withLock and close the statements afterwards.
	var openStmt []*driverStmt
	withLock(dc, func() {
		openStmt = make([]*driverStmt, 0, len(dc.openStmt))
		for ds := range dc.openStmt {
			openStmt = append(openStmt, ds)
		}
		dc.openStmt = nil
	})
	for _, ds := range openStmt {
		ds.Close()
	}
	withLock(dc, func() {
		dc.finalClosed = true
		err = dc.ci.Close()
		dc.ci = nil
	})

	// The connection no longer counts toward maxOpen; let pending requests
	// trigger a replacement.
	dc.db.mu.Lock()
	dc.db.numOpen--
	dc.db.maybeOpenNewConnections()
	dc.db.mu.Unlock()

	dc.db.numClosed.Add(1)
	return err
}
711
712
713
714
// driverStmt pairs a driver.Stmt with the lock that must be held while the
// statement is used. prepareLocked sets that lock to the *driverConn that
// prepared it.
type driverStmt struct {
	sync.Locker // the *driverConn
	si          driver.Stmt
	closed      bool
	closeErr    error // return value of previous Close call
}

// Close closes the underlying driver statement exactly once. Later calls
// return the error from the first call.
func (ds *driverStmt) Close() error {
	ds.Lock()
	defer ds.Unlock()
	if !ds.closed {
		ds.closed = true
		ds.closeErr = ds.si.Close()
	}
	return ds.closeErr
}
734
735
// depSet is the set of dependencies currently held on one finalCloser.
// Values are always true.
type depSet map[any]bool

// finalCloser is implemented by values whose teardown waits until every
// dependency registered with addDep has been removed.
type finalCloser interface {
	// finalClose is called when the last dependency is removed.
	// removeDepLocked only returns it; callers invoke it after releasing
	// (*DB).mu.
	finalClose() error
}
745
746
747
748 func (db *DB) addDep(x finalCloser, dep any) {
749 db.mu.Lock()
750 defer db.mu.Unlock()
751 db.addDepLocked(x, dep)
752 }
753
754 func (db *DB) addDepLocked(x finalCloser, dep any) {
755 if db.dep == nil {
756 db.dep = make(map[finalCloser]depSet)
757 }
758 xdep := db.dep[x]
759 if xdep == nil {
760 xdep = make(depSet)
761 db.dep[x] = xdep
762 }
763 xdep[dep] = true
764 }
765
766
767
768
769
770 func (db *DB) removeDep(x finalCloser, dep any) error {
771 db.mu.Lock()
772 fn := db.removeDepLocked(x, dep)
773 db.mu.Unlock()
774 return fn()
775 }
776
777 func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
778 xdep, ok := db.dep[x]
779 if !ok {
780 panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
781 }
782
783 l0 := len(xdep)
784 delete(xdep, dep)
785
786 switch len(xdep) {
787 case l0:
788
789 panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
790 case 0:
791
792 delete(db.dep, x)
793 return x.finalClose
794 default:
795
796 return func() error { return nil }
797 }
798 }
799
800
801
802
803
804
// connectionRequestQueueSize is the buffer size of DB.openerCh (see OpenDB).
// NOTE(review): the value is large, presumably so that
// maybeOpenNewConnections, which sends on openerCh while its callers hold
// db.mu, effectively never blocks. Confirm.
var connectionRequestQueueSize = 1000000
806
// dsnConnector adapts a plain driver.Driver to the driver.Connector
// interface by calling Open with a fixed data source name.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

var _ driver.Connector = dsnConnector{}

// Connect opens a new connection using the stored DSN. The context is
// unused because driver.Driver.Open does not take one.
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	drv, dsn := t.driver, t.dsn
	return drv.Open(dsn)
}

// Driver returns the wrapped driver.
func (t dsnConnector) Driver() driver.Driver { return t.driver }
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836 func OpenDB(c driver.Connector) *DB {
837 ctx, cancel := context.WithCancel(context.Background())
838 db := &DB{
839 connector: c,
840 openerCh: make(chan struct{}, connectionRequestQueueSize),
841 lastPut: make(map[*driverConn]string),
842 stop: cancel,
843 }
844
845 go db.connectionOpener(ctx)
846
847 return db
848 }
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867 func Open(driverName, dataSourceName string) (*DB, error) {
868 driversMu.RLock()
869 driveri, ok := drivers[driverName]
870 driversMu.RUnlock()
871 if !ok {
872 return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
873 }
874
875 if driverCtx, ok := driveri.(driver.DriverContext); ok {
876 connector, err := driverCtx.OpenConnector(dataSourceName)
877 if err != nil {
878 return nil, err
879 }
880 return OpenDB(connector), nil
881 }
882
883 return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
884 }
885
886 func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
887 var err error
888 if pinger, ok := dc.ci.(driver.Pinger); ok {
889 withLock(dc, func() {
890 err = pinger.Ping(ctx)
891 })
892 }
893 release(err)
894 return err
895 }
896
897
898
899 func (db *DB) PingContext(ctx context.Context) error {
900 var dc *driverConn
901 var err error
902
903 err = db.retry(func(strategy connReuseStrategy) error {
904 dc, err = db.conn(ctx, strategy)
905 return err
906 })
907
908 if err != nil {
909 return err
910 }
911
912 return db.pingDC(ctx, dc, dc.releaseConn)
913 }
914
915
916
917
918
919
920 func (db *DB) Ping() error {
921 return db.PingContext(context.Background())
922 }
923
924
925
926
927
928
929
// Close closes the database and prevents new conn calls from succeeding
// (they return errDBClosed). It closes idle connections, drops queued
// connection requests and stops the connectionOpener goroutine. If the
// connector implements io.Closer, Close closes it too. In-use connections
// are closed when they are returned: putConnDBLocked rejects them once
// db.closed is set, and putConn then closes them. Close returns the last
// error seen and is a no-op returning nil on an already-closed DB.
// NOTE(review): connRequests.CloseAndRemoveAll is outside this view.
// Presumably it closes each waiter's channel, which conn reports as
// errDBClosed.
func (db *DB) Close() error {
	db.mu.Lock()
	if db.closed { // Make DB.Close idempotent
		db.mu.Unlock()
		return nil
	}
	// Closing cleanerCh wakes connectionCleaner, which then sees db.closed
	// and exits.
	if db.cleanerCh != nil {
		close(db.cleanerCh)
	}
	var err error
	// Detach the idle connections while holding db.mu; their closers run
	// after the lock is released.
	fns := make([]func() error, 0, len(db.freeConn))
	for _, dc := range db.freeConn {
		fns = append(fns, dc.closeDBLocked())
	}
	db.freeConn = nil
	db.closed = true
	db.connRequests.CloseAndRemoveAll()
	db.mu.Unlock()
	for _, fn := range fns {
		err1 := fn()
		if err1 != nil {
			err = err1
		}
	}
	// Cancel the context connectionOpener runs with.
	db.stop()
	if c, ok := db.connector.(io.Closer); ok {
		err1 := c.Close()
		if err1 != nil {
			err = err1
		}
	}
	return err
}
963
964 const defaultMaxIdleConns = 2
965
966 func (db *DB) maxIdleConnsLocked() int {
967 n := db.maxIdleCount
968 switch {
969 case n == 0:
970
971 return defaultMaxIdleConns
972 case n < 0:
973 return 0
974 default:
975 return n
976 }
977 }
978
979 func (db *DB) shortestIdleTimeLocked() time.Duration {
980 if db.maxIdleTime <= 0 {
981 return db.maxLifetime
982 }
983 if db.maxLifetime <= 0 {
984 return db.maxIdleTime
985 }
986 return min(db.maxIdleTime, db.maxLifetime)
987 }
988
989
990
991
992
993
994
995
996
997
998
999 func (db *DB) SetMaxIdleConns(n int) {
1000 db.mu.Lock()
1001 if n > 0 {
1002 db.maxIdleCount = n
1003 } else {
1004
1005 db.maxIdleCount = -1
1006 }
1007
1008 if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
1009 db.maxIdleCount = db.maxOpen
1010 }
1011 var closing []*driverConn
1012 idleCount := len(db.freeConn)
1013 maxIdle := db.maxIdleConnsLocked()
1014 if idleCount > maxIdle {
1015 closing = db.freeConn[maxIdle:]
1016 db.freeConn = db.freeConn[:maxIdle]
1017 }
1018 db.maxIdleClosed += int64(len(closing))
1019 db.mu.Unlock()
1020 for _, c := range closing {
1021 c.Close()
1022 }
1023 }
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033 func (db *DB) SetMaxOpenConns(n int) {
1034 db.mu.Lock()
1035 db.maxOpen = n
1036 if n < 0 {
1037 db.maxOpen = 0
1038 }
1039 syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
1040 db.mu.Unlock()
1041 if syncMaxIdle {
1042 db.SetMaxIdleConns(n)
1043 }
1044 }
1045
1046
1047
1048
1049
1050
1051 func (db *DB) SetConnMaxLifetime(d time.Duration) {
1052 if d < 0 {
1053 d = 0
1054 }
1055 db.mu.Lock()
1056
1057 if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
1058 select {
1059 case db.cleanerCh <- struct{}{}:
1060 default:
1061 }
1062 }
1063 db.maxLifetime = d
1064 db.startCleanerLocked()
1065 db.mu.Unlock()
1066 }
1067
1068
1069
1070
1071
1072
1073 func (db *DB) SetConnMaxIdleTime(d time.Duration) {
1074 if d < 0 {
1075 d = 0
1076 }
1077 db.mu.Lock()
1078 defer db.mu.Unlock()
1079
1080
1081 if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
1082 select {
1083 case db.cleanerCh <- struct{}{}:
1084 default:
1085 }
1086 }
1087 db.maxIdleTime = d
1088 db.startCleanerLocked()
1089 }
1090
1091
1092 func (db *DB) startCleanerLocked() {
1093 if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
1094 db.cleanerCh = make(chan struct{}, 1)
1095 go db.connectionCleaner(db.shortestIdleTimeLocked())
1096 }
1097 }
1098
// connectionCleaner runs in the goroutine started by startCleanerLocked. It
// periodically removes idle connections that exceeded maxIdleTime or
// maxLifetime and closes them outside db.mu. d is the initial wait, and no
// wait is shorter than one second. The goroutine exits, resetting cleanerCh
// to nil, once the DB is closed, no connections are open, or both limits
// are disabled.
func (db *DB) connectionCleaner(d time.Duration) {
	const minInterval = time.Second

	if d < minInterval {
		d = minInterval
	}
	t := time.NewTimer(d)

	for {
		select {
		case <-t.C:
		case <-db.cleanerCh: // a limit was shortened or the DB was closed
		}

		db.mu.Lock()

		d = db.shortestIdleTimeLocked()
		if db.closed || db.numOpen == 0 || d <= 0 {
			db.cleanerCh = nil
			db.mu.Unlock()
			return
		}

		// This := declares a new d scoped to the loop body. The assignment
		// at the top of the next iteration targets the outer d.
		d, closing := db.connectionCleanerRunLocked(d)
		db.mu.Unlock()
		for _, c := range closing {
			c.Close()
		}

		if d < minInterval {
			d = minInterval
		}

		// Stop the timer and drain a pending tick, if any, before Reset.
		if !t.Stop() {
			select {
			case <-t.C:
			default:
			}
		}
		t.Reset(d)
	}
}
1141
1142
1143
1144
// connectionCleanerRunLocked removes from freeConn every connection that has
// been idle longer than maxIdleTime or was created more than maxLifetime
// ago, and updates maxIdleTimeClosed and maxLifetimeClosed. It returns two
// things: the time until the next remaining connection is due to expire,
// capped at d, and the removed connections. The caller closes them after
// releasing db.mu. The caller must hold db.mu.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
	var idleClosing int64
	var closing []*driverConn
	if db.maxIdleTime > 0 {
		// freeConn is appended to in return order, so it is ordered by
		// returnedAt, oldest first. Scan from the newest end. The first
		// connection found idle too long marks the boundary, and every
		// entry before it has been idle at least as long.
		idleSince := time.Now().Add(-db.maxIdleTime)
		last := len(db.freeConn) - 1
		for i := last; i >= 0; i-- {
			c := db.freeConn[i]
			if c.returnedAt.Before(idleSince) {
				i++
				// The full slice expression caps closing's capacity. The
				// appends in the lifetime pass below therefore reallocate
				// instead of writing into freeConn's backing array.
				closing = db.freeConn[:i:i]
				db.freeConn = db.freeConn[i:]
				idleClosing = int64(len(closing))
				db.maxIdleTimeClosed += idleClosing
				break
			}
		}

		if len(db.freeConn) > 0 {
			c := db.freeConn[0]
			if d2 := c.returnedAt.Sub(idleSince); d2 < d {
				// Wake up again when the oldest remaining idle connection
				// reaches maxIdleTime.
				d = d2
			}
		}
	}

	if db.maxLifetime > 0 {
		expiredSince := time.Now().Add(-db.maxLifetime)
		for i := 0; i < len(db.freeConn); i++ {
			c := db.freeConn[i]
			if c.createdAt.Before(expiredSince) {
				closing = append(closing, c)

				last := len(db.freeConn) - 1
				// The order-preserving delete keeps freeConn sorted by
				// returnedAt, which the idle pass above relies on. The
				// vacated tail slot is cleared so it does not pin c.
				copy(db.freeConn[i:], db.freeConn[i+1:])
				db.freeConn[last] = nil
				db.freeConn = db.freeConn[:last]
				i--
			} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
				// Wake up again when this connection reaches maxLifetime.
				d = d2
			}
		}
		// Count only the connections this pass added.
		db.maxLifetimeClosed += int64(len(closing)) - idleClosing
	}

	return d, closing
}
1200
1201
// DBStats is a snapshot of connection-pool statistics returned by Stats.
type DBStats struct {
	MaxOpenConnections int // Maximum number of open connections to the database (0 = unlimited).

	// Pool status
	OpenConnections int // db.numOpen, which also counts connections still being opened.
	InUse           int // OpenConnections minus Idle.
	Idle            int // The number of idle connections.

	// Counters
	WaitCount         int64         // The total number of connections waited for.
	WaitDuration      time.Duration // The total time blocked waiting for a connection.
	MaxIdleClosed     int64         // The total number of connections closed because the idle pool was full.
	MaxIdleTimeClosed int64         // The total number of connections closed due to max idle time.
	MaxLifetimeClosed int64         // The total number of connections closed due to max lifetime.
}
1217
1218
1219 func (db *DB) Stats() DBStats {
1220 wait := db.waitDuration.Load()
1221
1222 db.mu.Lock()
1223 defer db.mu.Unlock()
1224
1225 stats := DBStats{
1226 MaxOpenConnections: db.maxOpen,
1227
1228 Idle: len(db.freeConn),
1229 OpenConnections: db.numOpen,
1230 InUse: db.numOpen - len(db.freeConn),
1231
1232 WaitCount: db.waitCount,
1233 WaitDuration: time.Duration(wait),
1234 MaxIdleClosed: db.maxIdleClosed,
1235 MaxIdleTimeClosed: db.maxIdleTimeClosed,
1236 MaxLifetimeClosed: db.maxLifetimeClosed,
1237 }
1238 return stats
1239 }
1240
1241
1242
1243
1244 func (db *DB) maybeOpenNewConnections() {
1245 numRequests := db.connRequests.Len()
1246 if db.maxOpen > 0 {
1247 numCanOpen := db.maxOpen - db.numOpen
1248 if numRequests > numCanOpen {
1249 numRequests = numCanOpen
1250 }
1251 }
1252 for numRequests > 0 {
1253 db.numOpen++
1254 numRequests--
1255 if db.closed {
1256 return
1257 }
1258 db.openerCh <- struct{}{}
1259 }
1260 }
1261
1262
1263 func (db *DB) connectionOpener(ctx context.Context) {
1264 for {
1265 select {
1266 case <-ctx.Done():
1267 return
1268 case <-db.openerCh:
1269 db.openNewConnection(ctx)
1270 }
1271 }
1272 }
1273
1274
// openNewConnection opens one connection on behalf of connectionOpener. It
// calls Connect without holding db.mu. The new connection is offered to a
// waiting request or the idle pool through putConnDBLocked; if neither takes
// it, it is closed. If Connect fails, the error is passed to a waiting
// request, if there is one.
func (db *DB) openNewConnection(ctx context.Context) {
	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
	// on db.openerCh. This function must execute db.numOpen-- if the
	// connection fails or is closed before returning.
	ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
	if err != nil {
		db.numOpen--
		// Hand the error to one waiting request, if any.
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}
	dc := &driverConn{
		db:         db,
		createdAt:  time.Now(),
		returnedAt: time.Now(),
		ci:         ci,
	}
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		db.numOpen--
		ci.Close()
	}
}
1308
1309
1310
1311
// connRequest is the value delivered to a conn caller that is waiting for a
// connection. It carries either a usable conn (err nil) or, when a
// connection attempt failed, a nil conn and the error (see
// openNewConnection).
type connRequest struct {
	conn *driverConn
	err  error
}

// errDBClosed is returned by conn when the DB is closed, including when it
// is closed while the caller is waiting for a connection.
var errDBClosed = errors.New("sql: database is closed")
1318
1319
// conn returns a cached or newly opened *driverConn marked in use.
//
// With cachedOrNewConn, the most recently returned idle connection is reused
// when one is available. Otherwise, if maxOpen connections are already open,
// the caller queues a request. It then waits until putConnDBLocked hands over
// a connection (or an open error), or until ctx is done. In all other cases
// a new connection is opened with the connector.
//
// A reused connection that has exceeded maxLifetime is closed, and
// driver.ErrBadConn is returned so callers such as retry can try again. The
// same applies to a handed-over connection under cachedOrNewConn, and to a
// connection whose session reset reports driver.ErrBadConn. errDBClosed is
// returned if the DB is closed.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
		// Take the most recently returned connection from the tail of
		// freeConn. The oldest idle ones stay at the front, where
		// connectionCleanerRunLocked trims them.
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// Reset the session if required.
		if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
			conn.Close()
			return nil, err
		}

		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// req is buffered so that putConnDBLocked, which sends while
		// holding db.mu, never blocks on it.
		req := make(chan connRequest, 1)
		delHandle := db.connRequests.Add(req)
		db.waitCount++
		db.mu.Unlock()

		waitStart := time.Now()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			deleted := db.connRequests.Delete(delHandle)
			db.mu.Unlock()

			db.waitDuration.Add(int64(time.Since(waitStart)))

			// If the request could not be deleted, it was presumably already
			// taken by putConnDBLocked (which sends while still holding
			// db.mu) or closed by Close. Receive without blocking and give
			// any delivered connection back to the pool.
			if !deleted {
				select {
				default:
				case ret, ok := <-req:
					if ok && ret.conn != nil {
						db.putConn(ret.conn, ret.err, false)
					}
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			db.waitDuration.Add(int64(time.Since(waitStart)))

			// A closed channel means the DB was closed while waiting.
			if !ok {
				return nil, errDBClosed
			}
			// Only cachedOrNewConn rejects an expired hand-over. With
			// alwaysNewConn the connection is used as-is, and its age is
			// checked again when it is returned (putConn).
			if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			// A nil conn carries the error from a failed open
			// (openNewConnection passes it to putConnDBLocked).
			if ret.conn == nil {
				return nil, ret.err
			}

			// Reset the session if required.
			if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
				ret.conn.Close()
				return nil, err
			}
			return ret.conn, ret.err
		}
	}

	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:         db,
		createdAt:  time.Now(),
		returnedAt: time.Now(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
1455
1456
// putConnHook, when non-nil, is called by putConn with db.mu held, just
// before a healthy connection is offered back to the pool. Nothing in this
// view sets it.
// NOTE(review): presumably a test hook. Confirm.
var putConnHook func(*DB, *driverConn)
1458
1459
1460
1461
1462 func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
1463 db.mu.Lock()
1464 defer db.mu.Unlock()
1465 if c.inUse {
1466 c.onPut = append(c.onPut, func() {
1467 ds.Close()
1468 })
1469 } else {
1470 c.Lock()
1471 fc := c.finalClosed
1472 c.Unlock()
1473 if !fc {
1474 ds.Close()
1475 }
1476 }
1477 }
1478
1479
1480
// debugGetPut enables duplicate-put diagnostics. When true, putConn records
// the stack of each put in db.lastPut and prints the previous stack on a
// duplicate put before panicking.
const debugGetPut = false
1482
1483
1484
// putConn returns dc to the pool after use. err is the error, if any, seen
// while using dc. resetSession requests a session reset before dc is next
// handed out. A bad connection is closed instead of pooled. A connection is
// bad when err wraps driver.ErrBadConn, when the driver's Validator reports
// it invalid, or when it has exceeded maxLifetime. putConn panics if dc was
// not checked out.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	// Unless dc is already known bad, validate it and record whether a
	// session reset is needed.
	if !errors.Is(err, driver.ErrBadConn) {
		if !dc.validateConnection(resetSession) {
			err = driver.ErrBadConn
		}
	}
	db.mu.Lock()
	if !dc.inUse {
		db.mu.Unlock()
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}

	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false
	dc.returnedAt = time.Now()

	// Run work deferred until this connection came back, for example the
	// statement closes queued by noteUnusedDriverStatement.
	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if errors.Is(err, driver.ErrBadConn) {
		// Don't reuse bad connections.
		// Since the conn is considered bad and is being discarded, treat it
		// as closed. Don't decrement the open count here, finalClose will
		// take care of that.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	// Neither a waiting request nor the idle pool accepted the connection.
	if !added {
		dc.Close()
		return
	}
}
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546 func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
1547 if db.closed {
1548 return false
1549 }
1550 if db.maxOpen > 0 && db.numOpen > db.maxOpen {
1551 return false
1552 }
1553 if req, ok := db.connRequests.TakeRandom(); ok {
1554 if err == nil {
1555 dc.inUse = true
1556 }
1557 req <- connRequest{
1558 conn: dc,
1559 err: err,
1560 }
1561 return true
1562 } else if err == nil && !db.closed {
1563 if db.maxIdleConnsLocked() > len(db.freeConn) {
1564 db.freeConn = append(db.freeConn, dc)
1565 db.startCleanerLocked()
1566 return true
1567 }
1568 db.maxIdleClosed++
1569 }
1570 return false
1571 }
1572
1573
1574
1575
1576 const maxBadConnRetries = 2
1577
1578 func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
1579 for i := int64(0); i < maxBadConnRetries; i++ {
1580 err := fn(cachedOrNewConn)
1581
1582 if err == nil || !errors.Is(err, driver.ErrBadConn) {
1583 return err
1584 }
1585 }
1586
1587 return fn(alwaysNewConn)
1588 }
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1599 var stmt *Stmt
1600 var err error
1601
1602 err = db.retry(func(strategy connReuseStrategy) error {
1603 stmt, err = db.prepare(ctx, query, strategy)
1604 return err
1605 })
1606
1607 return stmt, err
1608 }
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618 func (db *DB) Prepare(query string) (*Stmt, error) {
1619 return db.PrepareContext(context.Background(), query)
1620 }
1621
1622 func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
1623
1624
1625
1626
1627
1628
1629 dc, err := db.conn(ctx, strategy)
1630 if err != nil {
1631 return nil, err
1632 }
1633 return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
1634 }
1635
1636
1637
1638
1639 func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
1640 var ds *driverStmt
1641 var err error
1642 defer func() {
1643 release(err)
1644 }()
1645 withLock(dc, func() {
1646 ds, err = dc.prepareLocked(ctx, cg, query)
1647 })
1648 if err != nil {
1649 return nil, err
1650 }
1651 stmt := &Stmt{
1652 db: db,
1653 query: query,
1654 cg: cg,
1655 cgds: ds,
1656 }
1657
1658
1659
1660
1661 if cg == nil {
1662 stmt.css = []connStmt{{dc, ds}}
1663 stmt.lastNumClosed = db.numClosed.Load()
1664 db.addDep(stmt, stmt)
1665 }
1666 return stmt, nil
1667 }
1668
1669
1670
1671 func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
1672 var res Result
1673 var err error
1674
1675 err = db.retry(func(strategy connReuseStrategy) error {
1676 res, err = db.exec(ctx, query, args, strategy)
1677 return err
1678 })
1679
1680 return res, err
1681 }
1682
1683
1684
1685
1686
1687
1688 func (db *DB) Exec(query string, args ...any) (Result, error) {
1689 return db.ExecContext(context.Background(), query, args...)
1690 }
1691
1692 func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
1693 dc, err := db.conn(ctx, strategy)
1694 if err != nil {
1695 return nil, err
1696 }
1697 return db.execDC(ctx, dc, dc.releaseConn, query, args)
1698 }
1699
// execDC executes query with args on dc and passes the final error to
// release, through a defer that reads the named result err. It first tries
// the driver's ExecerContext or Execer fast path. If the driver implements
// neither, or returns driver.ErrSkip, it prepares a one-off statement,
// executes it, and closes it.
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
	defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
	if ok {
		// Fast path: execute directly on the connection without preparing.
		var nvdargs []driver.NamedValue
		var resi driver.Result
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		// driver.ErrSkip means the driver declined the fast path; fall
		// through to the prepared-statement path.
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

	// Slow path: prepare a one-off statement, execute it, and close it.
	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	// Deferred calls run last-in first-out, so ds.Close runs before release.
	defer ds.Close()
	return resultFromStatement(ctx, dc.ci, ds, args...)
}
1738
1739
1740
1741 func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
1742 var rows *Rows
1743 var err error
1744
1745 err = db.retry(func(strategy connReuseStrategy) error {
1746 rows, err = db.query(ctx, query, args, strategy)
1747 return err
1748 })
1749
1750 return rows, err
1751 }
1752
1753
1754
1755
1756
1757
1758 func (db *DB) Query(query string, args ...any) (*Rows, error) {
1759 return db.QueryContext(context.Background(), query, args...)
1760 }
1761
1762 func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
1763 dc, err := db.conn(ctx, strategy)
1764 if err != nil {
1765 return nil, err
1766 }
1767
1768 return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
1769 }
1770
1771
1772
1773
1774
// queryDC executes query with args on dc. It first tries the driver's
// QueryerContext or Queryer fast path. If the driver implements neither, or
// returns driver.ErrSkip, it prepares a one-off statement and queries
// through it.
//
// On failure, releaseConn is called here with the error. On success, the
// returned *Rows holds dc and releaseConn.
// NOTE(review): presumably Rows calls releaseConn when it is closed; that
// code is outside this view. txctx is nil for non-transactional queries
// (see query) and is passed to Rows.initContextClose.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		// Fast path: query directly on the connection without preparing.
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// driver.ErrSkip means the driver declined the fast path; fall
		// through to the prepared-statement path.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// The Rows now hold dc together with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

	// Slow path: prepare a one-off statement and query through it.
	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

	ds := &driverStmt{Locker: dc, si: si}
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// The Rows hold both the connection and the one-off statement
	// (closeStmt). NOTE(review): presumably Rows closes ds when it is
	// closed; confirm against the Rows code.
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}
1838
1839
1840
1841
1842
1843
1844
1845 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
1846 rows, err := db.QueryContext(ctx, query, args...)
1847 return &Row{rows: rows, err: err}
1848 }
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859 func (db *DB) QueryRow(query string, args ...any) *Row {
1860 return db.QueryRowContext(context.Background(), query, args...)
1861 }
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1874 var tx *Tx
1875 var err error
1876
1877 err = db.retry(func(strategy connReuseStrategy) error {
1878 tx, err = db.begin(ctx, opts, strategy)
1879 return err
1880 })
1881
1882 return tx, err
1883 }
1884
1885
1886
1887
1888
1889
1890 func (db *DB) Begin() (*Tx, error) {
1891 return db.BeginTx(context.Background(), nil)
1892 }
1893
1894 func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
1895 dc, err := db.conn(ctx, strategy)
1896 if err != nil {
1897 return nil, err
1898 }
1899 return db.beginDC(ctx, dc, dc.releaseConn, opts)
1900 }
1901
1902
// beginDC starts a driver transaction on the already-acquired connection dc.
//
// If starting the transaction fails, beginDC calls release(err) and returns
// the error. On success the returned Tx holds dc and release (Tx.close calls
// release). A goroutine (tx.awaitDone) is also started; it rolls the
// transaction back once the Tx's context, derived here from ctx, is done.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	keepConnOnRollback := false
	withLock(dc, func() {
		// A connection that implements both SessionResetter and Validator is
		// kept, not discarded, when awaitDone rolls the transaction back.
		_, hasSessionResetter := dc.ci.(driver.SessionResetter)
		_, hasConnectionValidator := dc.ci.(driver.Validator)
		keepConnOnRollback = hasSessionResetter && hasConnectionValidator
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	// tx.ctx is canceled either by Commit/rollback (tx.cancel) or when the
	// caller's ctx ends. Either way, awaitDone wakes up.
	ctx, cancel := context.WithCancel(ctx)
	tx = &Tx{
		db:                 db,
		dc:                 dc,
		releaseConn:        release,
		txi:                txi,
		cancel:             cancel,
		keepConnOnRollback: keepConnOnRollback,
		ctx:                ctx,
	}
	go tx.awaitDone()
	return tx, nil
}
1932
1933
1934 func (db *DB) Driver() driver.Driver {
1935 return db.connector.Driver()
1936 }
1937
1938
1939
1940 var ErrConnDone = errors.New("sql: connection is already closed")
1941
1942
1943
1944
1945
1946
1947
1948
1949 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
1950 var dc *driverConn
1951 var err error
1952
1953 err = db.retry(func(strategy connReuseStrategy) error {
1954 dc, err = db.conn(ctx, strategy)
1955 return err
1956 })
1957
1958 if err != nil {
1959 return nil, err
1960 }
1961
1962 conn := &Conn{
1963 db: db,
1964 dc: dc,
1965 }
1966 return conn, nil
1967 }
1968
// releaseConn is the callback used to hand a connection back once an
// operation finishes. The error argument is the operation's outcome.
// Implementations such as Conn.closemuRUnlockCondReleaseConn inspect it for
// driver.ErrBadConn.
type releaseConn func(error)

// Conn is a single connection reserved from a DB (see DB.Conn). All of its
// methods run on the same underlying driverConn until the Conn is closed.
type Conn struct {
	db *DB

	// closemu is read-locked by grabConn for each operation and
	// write-locked by close before the connection is released.
	closemu closingMutex

	// dc is the reserved driver connection. close sets it to nil.
	dc *driverConn

	// done is set once by close (CompareAndSwap). After that, grabConn
	// refuses new work with ErrConnDone.
	done atomic.Bool

	// releaseConnOnce guards the lazy initialization of releaseConnCache.
	releaseConnOnce sync.Once

	// releaseConnCache holds the c.closemuRUnlockCondReleaseConn method value
	// that grabConn hands out. Presumably it is cached to avoid building a
	// fresh method value on every call.
	releaseConnCache releaseConn
}
2001
2002
2003
// grabConn implements stmtConnGrabber for Conn. Its context argument is
// unused.
//
// It fails with ErrConnDone once the Conn has been closed. Otherwise it
// read-locks closemu and returns the reserved driverConn plus a release
// function. The release function drops that read lock and closes the Conn
// if it is given driver.ErrBadConn.
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
	if c.done.Load() {
		return nil, nil, ErrConnDone
	}
	c.releaseConnOnce.Do(func() {
		c.releaseConnCache = c.closemuRUnlockCondReleaseConn
	})
	c.closemu.RLock()
	return c.dc, c.releaseConnCache, nil
}
2014
2015
2016 func (c *Conn) PingContext(ctx context.Context) error {
2017 dc, release, err := c.grabConn(ctx)
2018 if err != nil {
2019 return err
2020 }
2021 return c.db.pingDC(ctx, dc, release)
2022 }
2023
2024
2025
2026 func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
2027 dc, release, err := c.grabConn(ctx)
2028 if err != nil {
2029 return nil, err
2030 }
2031 return c.db.execDC(ctx, dc, release, query, args)
2032 }
2033
2034
2035
2036 func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
2037 dc, release, err := c.grabConn(ctx)
2038 if err != nil {
2039 return nil, err
2040 }
2041 return c.db.queryDC(ctx, nil, dc, release, query, args)
2042 }
2043
2044
2045
2046
2047
2048
2049
2050 func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2051 rows, err := c.QueryContext(ctx, query, args...)
2052 return &Row{rows: rows, err: err}
2053 }
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2064 dc, release, err := c.grabConn(ctx)
2065 if err != nil {
2066 return nil, err
2067 }
2068 return c.db.prepareDC(ctx, dc, release, c, query)
2069 }
2070
2071
2072
2073
2074
2075
2076 func (c *Conn) Raw(f func(driverConn any) error) (err error) {
2077 var dc *driverConn
2078 var release releaseConn
2079
2080
2081 dc, release, err = c.grabConn(nil)
2082 if err != nil {
2083 return
2084 }
2085 fPanic := true
2086 dc.Mutex.Lock()
2087 defer func() {
2088 dc.Mutex.Unlock()
2089
2090
2091
2092
2093 if fPanic {
2094 err = driver.ErrBadConn
2095 }
2096 release(err)
2097 }()
2098 err = f(dc.ci)
2099 fPanic = false
2100
2101 return
2102 }
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
2115 dc, release, err := c.grabConn(ctx)
2116 if err != nil {
2117 return nil, err
2118 }
2119 return c.db.beginDC(ctx, dc, release, opts)
2120 }
2121
2122
2123
2124 func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
2125 c.closemu.RUnlock()
2126 if errors.Is(err, driver.ErrBadConn) {
2127 c.close(err)
2128 }
2129 }
2130
2131 func (c *Conn) txCtx() context.Context {
2132 return nil
2133 }
2134
// close marks the Conn done exactly once and returns the reserved
// connection to the pool with err. It returns err. Any later call returns
// ErrConnDone.
func (c *Conn) close(err error) error {
	if !c.done.CompareAndSwap(false, true) {
		return ErrConnDone
	}

	// Take the closemu write lock before releasing the driver connection.
	// grabConn holds the read side for each operation, so this is meant to
	// wait for those operations to finish. This assumes closingMutex has
	// RWMutex-like semantics.
	c.closemu.Lock()
	defer c.closemu.Unlock()

	c.dc.releaseConn(err)
	c.dc = nil
	c.db = nil
	return err
}

// Close returns the reserved connection to the pool. Calling Close on an
// already-closed Conn returns ErrConnDone.
func (c *Conn) Close() error {
	return c.close(nil)
}
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
// Tx is a transaction started by DB.BeginTx or Conn.BeginTx. All of its
// operations run on the single driverConn the transaction began on.
type Tx struct {
	db *DB

	// closemu is read-locked by grabConn for each operation. Commit and
	// rollback write-lock it (and immediately unlock it) before finishing the
	// driver transaction.
	closemu closingMutex

	// dc and txi are the connection and the driver transaction. close sets
	// both to nil.
	dc  *driverConn
	txi driver.Tx

	// releaseConn is called by close with the commit/rollback error to hand
	// the connection back.
	releaseConn func(error)

	// done goes from false to true exactly once, in Commit or rollback
	// (CompareAndSwap). After that, grabConn rejects new work with ErrTxDone.
	done atomic.Bool

	// keepConnOnRollback is set by beginDC when the driver connection
	// implements both driver.SessionResetter and driver.Validator. awaitDone
	// then rolls back without discarding the connection.
	keepConnOnRollback bool

	// stmts records the statements created in this transaction
	// (PrepareContext, StmtContext) so that closePrepared can close them.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel cancels ctx. Commit and rollback call it.
	cancel func()

	// ctx is derived from the BeginTx context in beginDC. awaitDone rolls
	// the transaction back when ctx is done.
	ctx context.Context
}
2210
2211
2212
2213 func (tx *Tx) awaitDone() {
2214
2215
2216 <-tx.ctx.Done()
2217
2218
2219
2220
2221
2222
2223
2224 discardConnection := !tx.keepConnOnRollback
2225 tx.rollback(discardConnection)
2226 }
2227
2228 func (tx *Tx) isDone() bool {
2229 return tx.done.Load()
2230 }
2231
2232
2233
2234 var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
2235
2236
2237
2238
2239 func (tx *Tx) close(err error) {
2240 tx.releaseConn(err)
2241 tx.dc = nil
2242 tx.txi = nil
2243 }
2244
2245
2246
// hookTxGrabConn, when non-nil, is called by Tx.grabConn after the
// transaction has been read-locked for an operation.
// NOTE(review): this appears to be a test hook; nothing in this chunk sets it.
var hookTxGrabConn func()

// grabConn implements stmtConnGrabber for Tx.
//
// It fails with ctx.Err() if ctx is already done, and with ErrTxDone if the
// transaction has finished. Otherwise it returns the transaction's
// connection with closemu read-locked, plus a release function that only
// drops that lock.
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
	// Non-blocking check: give up early if ctx is already done.
	select {
	default:
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}

	// Take the read lock before checking isDone. Commit and rollback set
	// done and then take the write lock, so holding the read lock here keeps
	// the driver transaction from being finished while the caller's
	// operation runs. This assumes closingMutex has RWMutex-like semantics.
	tx.closemu.RLock()
	if tx.isDone() {
		tx.closemu.RUnlock()
		return nil, nil, ErrTxDone
	}
	if hookTxGrabConn != nil {
		hookTxGrabConn()
	}
	return tx.dc, tx.closemuRUnlockRelease, nil
}

// txCtx implements stmtConnGrabber by returning the transaction's context.
func (tx *Tx) txCtx() context.Context {
	return tx.ctx
}

// closemuRUnlockRelease is the release function returned by grabConn. It
// ignores the operation's error. The connection itself is only handed back
// when the transaction finishes (Tx.close).
func (tx *Tx) closemuRUnlockRelease(error) {
	tx.closemu.RUnlock()
}
2280
2281
2282 func (tx *Tx) closePrepared() {
2283 tx.stmts.Lock()
2284 defer tx.stmts.Unlock()
2285 for _, stmt := range tx.stmts.v {
2286 stmt.Close()
2287 }
2288 }
2289
2290
// Commit commits the transaction.
//
// It returns ErrTxDone if the transaction was already committed or rolled
// back. It returns tx.ctx.Err() if the transaction's context ended before
// Commit could claim the transaction. Otherwise it returns the driver's
// Commit error. Prepared statements are closed unless that error is
// driver.ErrBadConn, and the connection is handed back via tx.close.
func (tx *Tx) Commit() error {
	// Check the context before claiming the transaction through done.
	select {
	default:
	case <-tx.ctx.Done():
		// The context is also done after Commit/rollback ran (they call
		// tx.cancel). Report that case as ErrTxDone.
		if tx.done.Load() {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	// Cancel tx.ctx first. Rows from this transaction keep the closemu read
	// lock until they are closed, and their awaitDone goroutines close them
	// when tx.ctx is done. Then take and drop the write lock so that every
	// operation still holding the read lock has finished before the driver
	// commit.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
2325
// rollbackHook, when non-nil, is called by rollback right after the
// transaction has been claimed.
// NOTE(review): this appears to be a test hook; nothing in this chunk sets it.
var rollbackHook func()

// rollback aborts the transaction. It returns ErrTxDone if the transaction
// was already committed or rolled back.
//
// When discardConn is true, the connection is handed back with
// driver.ErrBadConn and rollback returns that instead of the driver's
// rollback error. Prepared statements are closed unless the driver's
// rollback error is driver.ErrBadConn.
func (tx *Tx) rollback(discardConn bool) error {
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	if rollbackHook != nil {
		rollbackHook()
	}

	// As in Commit: cancel tx.ctx so that open Rows get closed, then take
	// and drop the write lock so that in-flight operations finish before
	// the driver rollback.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}

// Rollback aborts the transaction and keeps the connection (discardConn is
// false). It returns ErrTxDone if the transaction already finished.
func (tx *Tx) Rollback() error {
	return tx.rollback(false)
}
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2377 dc, release, err := tx.grabConn(ctx)
2378 if err != nil {
2379 return nil, err
2380 }
2381
2382 stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
2383 if err != nil {
2384 return nil, err
2385 }
2386 tx.stmts.Lock()
2387 tx.stmts.v = append(tx.stmts.v, stmt)
2388 tx.stmts.Unlock()
2389 return stmt, nil
2390 }
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401 func (tx *Tx) Prepare(query string) (*Stmt, error) {
2402 return tx.PrepareContext(context.Background(), query)
2403 }
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
// StmtContext returns a statement bound to this transaction's connection
// that runs stmt's query.
//
// Errors are not returned directly. They are recorded as the returned Stmt's
// stickyErr, which connStmt and Close report. This covers a finished
// transaction or a done ctx (from grabConn), a stmt from a different DB, and
// prepare failures.
//
// If stmt is open and not bound to a Tx/Conn, the driver statement it
// already has for this connection is reused; if it has none, one is
// prepared and recorded in stmt.css. In that case the new Stmt keeps stmt as
// its parentStmt. If stmt is closed or already bound, its query is simply
// re-prepared on this connection. Either way, the new Stmt is added to
// tx.stmts.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	// Keep the transaction's read lock for the rest of this call.
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// The statement cannot share a driver statement with this
		// connection (it is closed or already bound to a Tx/Conn), so its
		// query is prepared afresh here.
		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()
		// Reuse the driver statement already prepared on this connection,
		// if there is one.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		if si == nil {
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si:     si,
		},
		parentStmt: parentStmt,
		query:      stmt.query,
	}
	if parentStmt != nil {
		// Register txs as a dependent of parentStmt. Stmt.Close later
		// undoes this with removeDep(s.parentStmt, s).
		tx.db.addDep(parentStmt, txs)
	}
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}

// Stmt is StmtContext with context.Background().
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
	return tx.StmtContext(context.Background(), stmt)
}
2512
2513
2514
2515 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
2516 dc, release, err := tx.grabConn(ctx)
2517 if err != nil {
2518 return nil, err
2519 }
2520 return tx.db.execDC(ctx, dc, release, query, args)
2521 }
2522
2523
2524
2525
2526
2527
2528 func (tx *Tx) Exec(query string, args ...any) (Result, error) {
2529 return tx.ExecContext(context.Background(), query, args...)
2530 }
2531
2532
2533 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
2534 dc, release, err := tx.grabConn(ctx)
2535 if err != nil {
2536 return nil, err
2537 }
2538
2539 return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
2540 }
2541
2542
2543
2544
2545
2546 func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
2547 return tx.QueryContext(context.Background(), query, args...)
2548 }
2549
2550
2551
2552
2553
2554
2555
2556 func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2557 rows, err := tx.QueryContext(ctx, query, args...)
2558 return &Row{rows: rows, err: err}
2559 }
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570 func (tx *Tx) QueryRow(query string, args ...any) *Row {
2571 return tx.QueryRowContext(context.Background(), query, args...)
2572 }
2573
2574
2575 type connStmt struct {
2576 dc *driverConn
2577 ds *driverStmt
2578 }
2579
2580
2581
2582 type stmtConnGrabber interface {
2583
2584
2585 grabConn(context.Context) (*driverConn, releaseConn, error)
2586
2587
2588
2589
2590 txCtx() context.Context
2591 }
2592
2593 var (
2594 _ stmtConnGrabber = &Tx{}
2595 _ stmtConnGrabber = &Conn{}
2596 )
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
// Stmt is a prepared statement.
//
// A Stmt bound to a Tx or Conn (cg != nil) always runs on its owner's
// connection, using cgds. An unbound Stmt is prepared on demand on whichever
// pool connection it runs on, and the per-connection driver statements are
// kept in css (see connStmt).
type Stmt struct {
	// db, query and stickyErr are set when the Stmt is created. If stickyErr
	// is non-nil, connStmt and Close return it instead of doing any work.
	db        *DB
	query     string
	stickyErr error

	// closemu is read-locked by ExecContext and QueryContext for the whole
	// call, and write-locked by Close.
	closemu closingMutex

	// cg is the Tx or Conn the statement is bound to, or nil.
	// cgds is the driver statement used when cg != nil. Close clears it.
	cg   stmtConnGrabber
	cgds *driverStmt

	// parentStmt is set by Tx.StmtContext when this statement reuses a
	// driver statement that belongs to an unbound Stmt. Close then removes
	// the dependency on the parent instead of closing cgds.
	parentStmt *Stmt

	mu     sync.Mutex // guards the fields below
	closed bool       // set by Close

	// css lists the connections the statement has been prepared on, with
	// the driver statement for each. removeClosedStmtLocked prunes entries
	// whose connection is closed, and finalClose releases them all.
	css []connStmt

	// lastNumClosed is the db.numClosed value at the last prune done by
	// removeClosedStmtLocked.
	lastNumClosed uint64
}
2644
2645
2646
2647 func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
2648 s.closemu.RLock()
2649 defer s.closemu.RUnlock()
2650
2651 var res Result
2652 err := s.db.retry(func(strategy connReuseStrategy) error {
2653 dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
2654 if err != nil {
2655 return err
2656 }
2657
2658 res, err = resultFromStatement(ctx, dc.ci, ds, args...)
2659 releaseConn(err)
2660 return err
2661 })
2662
2663 return res, err
2664 }
2665
2666
2667
2668
2669
2670
2671 func (s *Stmt) Exec(args ...any) (Result, error) {
2672 return s.ExecContext(context.Background(), args...)
2673 }
2674
2675 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
2676 ds.Lock()
2677 defer ds.Unlock()
2678
2679 dargs, err := driverArgsConnLocked(ci, ds, args)
2680 if err != nil {
2681 return nil, err
2682 }
2683
2684 resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
2685 if err != nil {
2686 return nil, err
2687 }
2688 return driverResult{ds.Locker, resi}, nil
2689 }
2690
2691
2692
2693
2694
2695 func (s *Stmt) removeClosedStmtLocked() {
2696 t := len(s.css)/2 + 1
2697 if t > 10 {
2698 t = 10
2699 }
2700 dbClosed := s.db.numClosed.Load()
2701 if dbClosed-s.lastNumClosed < uint64(t) {
2702 return
2703 }
2704
2705 s.db.mu.Lock()
2706 for i := 0; i < len(s.css); i++ {
2707 if s.css[i].dc.dbmuClosed {
2708 s.css[i] = s.css[len(s.css)-1]
2709
2710 s.css[len(s.css)-1] = connStmt{}
2711 s.css = s.css[:len(s.css)-1]
2712 i--
2713 }
2714 }
2715 s.db.mu.Unlock()
2716 s.lastNumClosed = dbClosed
2717 }
2718
2719
2720
2721
// connStmt returns a connection and the driver statement for s on it,
// together with the function to call once the operation is done.
//
// It fails immediately with s.stickyErr, or if s has been closed. A Stmt
// bound to a Tx or Conn uses that owner's connection and s.cgds. Otherwise
// a pool connection is obtained with strategy, and the statement already
// prepared on it is reused if present. If not, the statement is prepared now
// and recorded in s.css. If that prepare fails, the connection is released
// with the error.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	// A Stmt bound to a Tx or Conn always runs on that owner's connection.
	if s.cg != nil {
		s.mu.Unlock()
		dc, releaseConn, err = s.cg.grabConn(ctx)
		if err != nil {
			return
		}
		return dc, releaseConn, s.cgds, nil
	}

	s.removeClosedStmtLocked()
	// Drop s.mu before acquiring a pool connection, which may block.
	s.mu.Unlock()

	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	// Reuse the statement if it has already been prepared on this
	// connection.
	s.mu.Lock()
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// Not prepared on this connection yet. Prepare it now; this also
	// records the statement in s.css.
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}
2772
2773
2774
2775 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
2776 si, err := dc.prepareLocked(ctx, s.cg, s.query)
2777 if err != nil {
2778 return nil, err
2779 }
2780 cs := connStmt{dc, si}
2781 s.mu.Lock()
2782 s.css = append(s.css, cs)
2783 s.mu.Unlock()
2784 return cs.ds, nil
2785 }
2786
2787
2788
// QueryContext executes the prepared statement with args and returns the
// resulting rows, with attempts driven by db.retry. closemu's read lock is
// held for the whole call (Close takes the write side).
//
// The returned Rows are registered as a dependent of s (addDep). They remove
// that dependency when they release their connection.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var rowsi driver.Rows
	var rows *Rows

	err := s.db.retry(func(strategy connReuseStrategy) error {
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			return err
		}

		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
		if err == nil {
			// The Rows take over dc. releaseConn is filled in below.
			rows = &Rows{
				dc:    dc,
				rowsi: rowsi,
			}
			// Register the dependency before initContextClose. Once the
			// awaitDone goroutine is running it may close the Rows, and
			// closing calls releaseConn and therefore removeDep.
			s.db.addDep(s, rows)

			// For the same reason, releaseConn must be set before
			// initContextClose.
			rows.releaseConn = func(err error) {
				releaseConn(err)
				s.db.removeDep(s, rows)
			}
			var txctx context.Context
			if s.cg != nil {
				txctx = s.cg.txCtx()
			}
			rows.initContextClose(ctx, txctx)
			return nil
		}

		releaseConn(err)
		return err
	})

	return rows, err
}

// Query is QueryContext with context.Background().
func (s *Stmt) Query(args ...any) (*Rows, error) {
	return s.QueryContext(context.Background(), args...)
}
2844
2845 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
2846 ds.Lock()
2847 defer ds.Unlock()
2848 dargs, err := driverArgsConnLocked(ci, ds, args)
2849 if err != nil {
2850 return nil, err
2851 }
2852 return ctxDriverStmtQuery(ctx, ds.si, dargs)
2853 }
2854
2855
2856
2857
2858
2859
2860
2861 func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
2862 rows, err := s.QueryContext(ctx, args...)
2863 if err != nil {
2864 return &Row{err: err}
2865 }
2866 return &Row{rows: rows}
2867 }
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883 func (s *Stmt) QueryRow(args ...any) *Row {
2884 return s.QueryRowContext(context.Background(), args...)
2885 }
2886
2887
// Close closes the statement under closemu's write lock (ExecContext and
// QueryContext hold the read side). It returns s.stickyErr if that is set,
// and nil if s was already closed.
//
// An unbound statement removes its self-dependency (removeDep(s, s)). A
// bound statement with a parent removes its dependency on the parent. A
// bound statement without a parent closes its own driver statement.
func (s *Stmt) Close() error {
	s.closemu.Lock()
	defer s.closemu.Unlock()

	if s.stickyErr != nil {
		return s.stickyErr
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	txds := s.cgds
	s.cgds = nil

	s.mu.Unlock()

	if s.cg == nil {
		return s.db.removeDep(s, s)
	}

	if s.parentStmt != nil {
		// txds wraps a driver statement that also belongs to parentStmt (it
		// came from, or was added to, the parent's css). So it must not be
		// closed here; only the dependency on the parent is dropped.
		return s.db.removeDep(s.parentStmt, s)
	}
	return txds.Close()
}

// finalClose releases every driver statement recorded in s.css: each is
// reported via db.noteUnusedDriverStatement and removed from its
// connection's open statements. It then clears s.css. It always returns nil.
// NOTE(review): presumably invoked through the DB's dependency tracking
// (removeDep) once nothing depends on s; the caller is not visible here.
func (s *Stmt) finalClose() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.css != nil {
		for _, v := range s.css {
			s.db.noteUnusedDriverStatement(v.dc, v.ds)
			v.dc.removeOpenStmt(v.ds)
		}
		s.css = nil
	}
	return nil
}
2930
2931
2932
// Rows is the result of a query. It reads rows through the driver's
// driver.Rows (rowsi) on connection dc and hands dc back via releaseConn
// when closed.
type Rows struct {
	dc          *driverConn // connection the rows are read from; locked around rowsi calls
	releaseConn func(error) // called by close with the error from rowsi.Close
	rowsi       driver.Rows
	cancel      func()      // set by initContextClose to stop awaitDone; nil if no goroutine was started
	closeStmt   *driverStmt // if non-nil, closed by close along with the Rows

	// contextDone holds the error of the query or transaction context when
	// awaitDone saw that context finish. Next then returns false, and Err
	// reports the error unless hitEOF is set.
	contextDone atomic.Pointer[error]

	// closemu is write-locked by close and read-locked by Next,
	// NextResultSet, Err, Columns, ColumnTypes and Scan. lasterr and closed
	// are accessed under it.
	closemu closingMutex
	lasterr error // last error from the driver or from close; lasterrOrErrLocked treats io.EOF as no error
	closed  bool  // set once by close

	// closemuScanHold is true while Scan keeps closemu read-locked after a
	// successful scan into one or more *RawBytes destinations. This keeps
	// close, which needs the write lock, from running while that data may
	// still be in use. The next Next, NextResultSet or Close releases the
	// lock via closemuRUnlockIfHeldByScan.
	closemuScanHold bool

	// hitEOF is set by Next when it closed the Rows and returned false (end
	// of data or a driver error). Err then ignores contextDone.
	hitEOF bool

	// lastcols holds the current row's values as filled by rowsi.Next.
	// NextResultSet resets it to nil.
	lastcols []driver.Value

	// raw is the reusable buffer for RawBytes scan destinations (see
	// rawbuf/setrawbuf). Scan truncates it to length zero before each scan.
	raw []byte
}
2977
2978
2979
2980 func (rs *Rows) lasterrOrErrLocked(err error) error {
2981 if rs.lasterr != nil && rs.lasterr != io.EOF {
2982 return rs.lasterr
2983 }
2984 return err
2985 }
2986
2987
2988
2989 var bypassRowsAwaitDone = false
2990
// initContextClose arranges for the Rows to be closed automatically when ctx
// or txctx (if non-nil) is done.
//
// Nothing is started if neither context can ever be done (Done returns nil)
// or if bypassRowsAwaitDone is set. Otherwise a cancelable child of ctx is
// created, its cancel function is stored in rs.cancel (close calls it), and
// the awaitDone goroutine is started.
func (rs *Rows) initContextClose(ctx, txctx context.Context) {
	if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
		return
	}
	if bypassRowsAwaitDone {
		return
	}
	closectx, cancel := context.WithCancel(ctx)
	rs.cancel = cancel
	go rs.awaitDone(ctx, txctx, closectx)
}

// awaitDone blocks until ctx, txctx, or closectx is done, then closes the
// Rows with ctx.Err().
//
// If it wakes up on ctx or txctx, that context's error is also stored in
// contextDone, so Next stops and Err reports it. If close already ran, it
// canceled closectx, nothing is recorded, and the final rs.close returns
// without doing anything.
// NOTE(review): closectx is a child of ctx, so it is also done once ctx is,
// and select picks among ready cases at random. When ctx ends, contextDone
// may therefore not be recorded; the Rows are still closed with ctx.Err().
func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
	// A nil channel never becomes ready, so without a txctx that case is
	// effectively disabled.
	var txctxDone <-chan struct{}
	if txctx != nil {
		txctxDone = txctx.Done()
	}
	select {
	case <-ctx.Done():
		err := ctx.Err()
		rs.contextDone.Store(&err)
	case <-txctxDone:
		err := txctx.Err()
		rs.contextDone.Store(&err)
	case <-closectx.Done():
		// close called rs.cancel (or ctx ended). Nothing is stored, so Err
		// is not affected.
	}
	rs.close(ctx.Err())
}
3026
3027
3028
3029
3030
3031
3032
// Next advances to the next row, making it available to Scan, and reports
// whether there is one.
//
// It returns false once awaitDone has recorded a context error
// (contextDone), when the Rows are closed, at the end of the current result
// set, or on a driver error. When nextLocked asks for it (a driver error, or
// end of data with no further result set), the Rows are closed and hitEOF is
// set. Err then reports any error other than io.EOF.
func (rs *Rows) Next() bool {
	// The caller is moving past the previous row, so release the read lock
	// that Scan may have kept for RawBytes destinations.
	rs.closemuRUnlockIfHeldByScan()

	if rs.contextDone.Load() != nil {
		return false
	}

	var doClose, ok bool
	func() {
		rs.closemu.RLock()
		defer rs.closemu.RUnlock()
		doClose, ok = rs.nextLocked()
	}()
	// Close only after the read lock is released, because close needs the
	// write lock.
	if doClose {
		rs.Close()
	}
	if doClose && !ok {
		rs.hitEOF = true
	}
	return ok
}

// nextLocked reads the next row of the current result set into
// rs.lastcols. The caller holds closemu's read lock. ok reports whether a
// row was read.
//
// doClose asks the caller to close the Rows. That happens on any driver
// error other than io.EOF. It also happens on io.EOF when the driver rows do
// not implement driver.RowsNextResultSet or report no further result set.
func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	// Hold the connection lock around the driver calls below.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	if rs.lastcols == nil {
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}

	rs.lasterr = rs.rowsi.Next(rs.lastcols)
	if rs.lasterr != nil {
		// Any error other than io.EOF is a driver failure, so close.
		if rs.lasterr != io.EOF {
			return true, false
		}
		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
		if !ok {
			return true, false
		}
		// End of the current result set. Keep the Rows open if the driver
		// has another result set for NextResultSet to move to.
		if !nextResultSet.HasNextResultSet() {
			doClose = true
		}
		return doClose, false
	}
	return false, true
}
3092
3093
3094
3095
3096
3097
3098
3099
3100
// NextResultSet moves to the next result set and reports whether there is
// one. It clears lastcols, so Next must be called before Scan.
//
// It returns false if the Rows are already closed. It also returns false,
// and closes the Rows, if the driver rows do not implement
// driver.RowsNextResultSet or if the driver's NextResultSet returns an error
// (kept in lasterr).
func (rs *Rows) NextResultSet() bool {
	// The caller is done with the previous row, so release the read lock
	// that Scan may have kept for RawBytes destinations.
	rs.closemuRUnlockIfHeldByScan()

	var doClose bool
	// Registered first so it runs last, after the closemu read lock below
	// has been released (close needs the write lock).
	defer func() {
		if doClose {
			rs.Close()
		}
	}()
	rs.closemu.RLock()
	defer rs.closemu.RUnlock()

	if rs.closed {
		return false
	}

	rs.lastcols = nil
	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
	if !ok {
		doClose = true
		return false
	}

	// Hold the connection lock around the driver call.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	rs.lasterr = nextResultSet.NextResultSet()
	if rs.lasterr != nil {
		doClose = true
		return false
	}
	return true
}
3139
3140
3141
// Err returns the error, if any, encountered during iteration.
//
// A context error recorded by awaitDone is reported unless Next already
// ended the iteration by closing the Rows (hitEOF). Otherwise lasterr is
// returned, with io.EOF reported as nil.
func (rs *Rows) Err() error {
	// Once Next has returned its final false, a later cancellation of the
	// context should not turn a completed iteration into an error.
	if !rs.hitEOF {
		if errp := rs.contextDone.Load(); errp != nil {
			return *errp
		}
	}

	rs.closemu.RLock()
	defer rs.closemu.RUnlock()
	return rs.lasterrOrErrLocked(nil)
}
3157
3158
3159
3160
3161
3162
3163
3164 func (rs *Rows) rawbuf() []byte {
3165 if rs == nil {
3166
3167 return nil
3168 }
3169 return rs.raw
3170 }
3171
3172
3173
3174 func (rs *Rows) setrawbuf(b []byte) RawBytes {
3175 if rs == nil {
3176
3177 return RawBytes(b)
3178 }
3179 off := len(rs.raw)
3180 rs.raw = b
3181 return RawBytes(rs.raw[off:])
3182 }
3183
3184 var errRowsClosed = errors.New("sql: Rows are closed")
3185 var errNoRows = errors.New("sql: no Rows available")
3186
3187
3188
3189 func (rs *Rows) Columns() ([]string, error) {
3190 rs.closemu.RLock()
3191 defer rs.closemu.RUnlock()
3192 if rs.closed {
3193 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3194 }
3195 if rs.rowsi == nil {
3196 return nil, rs.lasterrOrErrLocked(errNoRows)
3197 }
3198 rs.dc.Lock()
3199 defer rs.dc.Unlock()
3200
3201 return rs.rowsi.Columns(), nil
3202 }
3203
3204
3205
3206 func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
3207 rs.closemu.RLock()
3208 defer rs.closemu.RUnlock()
3209 if rs.closed {
3210 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3211 }
3212 if rs.rowsi == nil {
3213 return nil, rs.lasterrOrErrLocked(errNoRows)
3214 }
3215 rs.dc.Lock()
3216 defer rs.dc.Unlock()
3217
3218 return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
3219 }
3220
3221
// ColumnType describes one result column, as returned by Rows.ColumnTypes.
// Which properties are available depends on which optional
// driver.RowsColumnType* interfaces the driver's rows implement (see
// rowsColumnInfoSetupConnLocked).
type ColumnType struct {
	name string

	// The has* fields record whether the driver reported the corresponding
	// property.
	hasNullable       bool
	hasLength         bool
	hasPrecisionScale bool

	nullable     bool
	length       int64
	databaseType string
	precision    int64
	scale        int64
	scanType     reflect.Type
}

// Name returns the column name reported by the driver's Columns.
func (ci *ColumnType) Name() string {
	return ci.name
}

// Length returns the column length reported by the driver. ok is false when
// the driver rows do not implement driver.RowsColumnTypeLength or the
// driver reported no length for this column.
func (ci *ColumnType) Length() (length int64, ok bool) {
	return ci.length, ci.hasLength
}

// DecimalSize returns the precision and scale reported by the driver. ok is
// false when the driver rows do not implement
// driver.RowsColumnTypePrecisionScale or the driver reported none.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
	return ci.precision, ci.scale, ci.hasPrecisionScale
}

// ScanType returns the Go type the driver suggests for scanning this column.
// If the driver rows do not implement driver.RowsColumnTypeScanType, it is
// reflect.TypeFor[any]().
func (ci *ColumnType) ScanType() reflect.Type {
	return ci.scanType
}

// Nullable reports whether the column may be NULL. ok is false when the
// driver rows do not implement driver.RowsColumnTypeNullable or the driver
// reported nothing.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
	return ci.nullable, ci.hasNullable
}

// DatabaseTypeName returns the database type name reported by the driver,
// or "" if the driver rows do not implement
// driver.RowsColumnTypeDatabaseTypeName.
func (ci *ColumnType) DatabaseTypeName() string {
	return ci.databaseType
}
3279
3280 func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
3281 names := rowsi.Columns()
3282
3283 list := make([]*ColumnType, len(names))
3284 for i := range list {
3285 ci := &ColumnType{
3286 name: names[i],
3287 }
3288 list[i] = ci
3289
3290 if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
3291 ci.scanType = prop.ColumnTypeScanType(i)
3292 } else {
3293 ci.scanType = reflect.TypeFor[any]()
3294 }
3295 if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
3296 ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
3297 }
3298 if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
3299 ci.length, ci.hasLength = prop.ColumnTypeLength(i)
3300 }
3301 if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
3302 ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
3303 }
3304 if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
3305 ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
3306 }
3307 }
3308 return list
3309 }
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
// Scan copies the columns of the current row into dest, converting each
// value with convertAssignRows.
//
// If the scan succeeds and any destination is a *RawBytes, closemu's read
// lock is kept (closemuScanHold) until the next Next, NextResultSet or
// Close, so the Rows cannot be closed while those bytes may still be in use.
// Calling Scan again while the lock is held returns an error.
func (rs *Rows) Scan(dest ...any) error {
	if rs.closemuScanHold {
		// Only reachable when Scan is called twice in a row without Next.
		return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
	}

	rs.closemu.RLock()
	rs.raw = rs.raw[:0]
	err := rs.scanLocked(dest...)
	if err == nil && scanArgsContainRawBytes(dest) {
		rs.closemuScanHold = true
	} else {
		rs.closemu.RUnlock()
	}
	return err
}

// scanLocked does the work of Scan. The caller holds closemu's read lock.
//
// It returns lasterr if it is a real error (not io.EOF), the closed error if
// the Rows are closed, an error if Next has not produced a row yet (lastcols
// is nil), or an error if the number of destinations does not match the
// number of columns. Otherwise it returns the first conversion error,
// wrapped with the column index and name.
func (rs *Rows) scanLocked(dest ...any) error {
	if rs.lasterr != nil && rs.lasterr != io.EOF {
		return rs.lasterr
	}
	if rs.closed {
		return rs.lasterrOrErrLocked(errRowsClosed)
	}

	if rs.lastcols == nil {
		return errors.New("sql: Scan called without calling Next")
	}
	if len(dest) != len(rs.lastcols) {
		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
	}

	for i, sv := range rs.lastcols {
		err := convertAssignRows(dest[i], sv, rs)
		if err != nil {
			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
		}
	}
	return nil
}
3412
3413
3414
3415 func (rs *Rows) closemuRUnlockIfHeldByScan() {
3416 if rs.closemuScanHold {
3417 rs.closemuScanHold = false
3418 rs.closemu.RUnlock()
3419 }
3420 }
3421
3422 func scanArgsContainRawBytes(args []any) bool {
3423 for _, a := range args {
3424 if _, ok := a.(*RawBytes); ok {
3425 return true
3426 }
3427 }
3428 return false
3429 }
3430
3431
3432
// rowsCloseHook returns a function that close calls with the Rows and a
// pointer to the error from rowsi.Close, which lets that error be inspected
// or replaced. By default it returns nil, meaning no hook.
// NOTE(review): this appears to be a test hook.
var rowsCloseHook = func() func(*Rows, *error) { return nil }

// Close closes the Rows (see close). Calling it on Rows that are already
// closed returns nil.
func (rs *Rows) Close() error {
	// Release a read lock still held by a RawBytes Scan, because close needs
	// the write lock.
	rs.closemuRUnlockIfHeldByScan()

	return rs.close(nil)
}

// close closes the Rows exactly once, under closemu's write lock. In order,
// it:
//   - records err in lasterr if nothing was recorded yet,
//   - closes rowsi under the connection lock,
//   - runs rowsCloseHook,
//   - cancels the awaitDone context,
//   - closes closeStmt,
//   - releases the connection with the error from rowsi.Close (possibly
//     replaced by the hook).
//
// It returns that error. Later calls return nil.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	// From here on, err is the driver's Close error.
	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	if rs.cancel != nil {
		rs.cancel()
	}

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)

	// Keep an earlier real error; otherwise remember the Close error.
	rs.lasterr = rs.lasterrOrErrLocked(err)
	return err
}
3479
3480
// Row is the result of a QueryRow/QueryRowContext call. err holds the error
// from running the query, if any; Scan and Err report it.
type Row struct {
	err  error
	rows *Rows
}

// Scan copies the columns of the first row into dest and closes the rows.
//
// It returns, in order of precedence:
//   - the stored query error, if any;
//   - an error if any dest is *RawBytes;
//   - the rows' error, or ErrNoRows, when there is no first row;
//   - the Scan error;
//   - otherwise, the result of closing the rows.
func (r *Row) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}

	// Close the rows on every path below. The explicit Close at the end
	// reports its error; this deferred second Close then returns nil.
	defer r.rows.Close()
	// Presumably rejected because the RawBytes could not be used safely once
	// the rows are closed on return.
	if scanArgsContainRawBytes(dest) {
		return errors.New("sql: RawBytes isn't allowed on Row.Scan")
	}

	if !r.rows.Next() {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return ErrNoRows
	}
	err := r.rows.Scan(dest...)
	if err != nil {
		return err
	}

	// Close explicitly so that any error from finishing the query is
	// returned.
	return r.rows.Close()
}

// Err returns the error, if any, that kept the query from running. Unlike
// Scan, it does not touch the rows.
func (r *Row) Err() error {
	return r.err
}
3536
3537
// Result summarizes an executed statement. For results produced by this
// package (driverResult), both methods return the driver's driver.Result
// values, obtained while holding the associated lock.
type Result interface {
	// LastInsertId returns the last inserted ID as reported by the driver.
	// NOTE(review): whether this is supported, and what it means, is
	// driver-specific.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of affected rows as reported by the
	// driver.
	// NOTE(review): whether this is supported is driver-specific.
	RowsAffected() (int64, error)
}
3551
// driverResult adapts a driver.Result to Result. The embedded Locker (the
// statement's connection lock, see resultFromStatement) is held around every
// call into the driver.
type driverResult struct {
	sync.Locker
	resi driver.Result
}

// LastInsertId returns resi.LastInsertId() while holding the lock.
func (dr driverResult) LastInsertId() (int64, error) {
	dr.Locker.Lock()
	defer dr.Locker.Unlock()
	id, err := dr.resi.LastInsertId()
	return id, err
}

// RowsAffected returns resi.RowsAffected() while holding the lock.
func (dr driverResult) RowsAffected() (int64, error) {
	dr.Locker.Lock()
	defer dr.Locker.Unlock()
	n, err := dr.resi.RowsAffected()
	return n, err
}
3568
// stack returns the current goroutine's stack trace, truncated to at most
// 2 KiB (2 << 10 bytes).
func stack() string {
	var frame [2 << 10]byte
	n := runtime.Stack(frame[:], false)
	return string(frame[:n])
}

// withLock runs fn while holding lk. The lock is released even if fn
// panics.
func withLock(lk sync.Locker, fn func()) {
	lk.Lock()
	defer lk.Unlock() // runs on normal return and on panic
	fn()
}
3580
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591 type connRequestSet struct {
3592
3593 s []connRequestAndIndex
3594 }
3595
3596 type connRequestAndIndex struct {
3597
3598 req chan connRequest
3599
3600
3601
3602 curIdx *int
3603 }
3604
3605
3606
3607 func (s *connRequestSet) CloseAndRemoveAll() {
3608 for _, v := range s.s {
3609 *v.curIdx = -1
3610 close(v.req)
3611 }
3612 s.s = nil
3613 }
3614
3615
3616 func (s *connRequestSet) Len() int { return len(s.s) }
3617
3618
3619
3620 type connRequestDelHandle struct {
3621 idx *int
3622 }
3623
3624
3625
3626
3627 func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
3628 idx := len(s.s)
3629
3630
3631
3632
3633
3634
3635
3636
3637 idxPtr := &idx
3638 s.s = append(s.s, connRequestAndIndex{v, idxPtr})
3639 return connRequestDelHandle{idxPtr}
3640 }
3641
3642
3643
3644
3645
3646 func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
3647 idx := *h.idx
3648 if idx < 0 {
3649 return false
3650 }
3651 s.deleteIndex(idx)
3652 return true
3653 }
3654
3655 func (s *connRequestSet) deleteIndex(idx int) {
3656
3657 *(s.s[idx].curIdx) = -1
3658
3659
3660 if idx < len(s.s)-1 {
3661 last := s.s[len(s.s)-1]
3662 *last.curIdx = idx
3663 s.s[idx] = last
3664 }
3665
3666 s.s[len(s.s)-1] = connRequestAndIndex{}
3667 s.s = s.s[:len(s.s)-1]
3668 }
3669
3670
3671
3672
3673 func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
3674 if len(s.s) == 0 {
3675 return nil, false
3676 }
3677 pick := rand.IntN(len(s.s))
3678 e := s.s[pick]
3679 s.deleteIndex(pick)
3680 return e.req, true
3681 }
3682
View as plain text