Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/bytealg"
10 "internal/godebug"
11 "internal/nettrace"
12 "net/netip"
13 "syscall"
14 "time"
15 )
16
const (
	// defaultTCPKeepAliveIdle is the package-wide default TCP keep-alive
	// idle duration.
	// NOTE(review): the code that applies this default is outside this
	// file's visible portion; confirm where it is consumed.
	defaultTCPKeepAliveIdle = 15 * time.Second

	// defaultTCPKeepAliveInterval is the package-wide default interval
	// between TCP keep-alive probes. It has the same value as
	// defaultTCPKeepAliveIdle.
	defaultTCPKeepAliveInterval = 15 * time.Second

	// defaultTCPKeepAliveCount is the package-wide default number of TCP
	// keep-alive probes.
	defaultTCPKeepAliveCount = 9

	// defaultMPTCPEnabledListen and defaultMPTCPEnabledDial are the
	// Multipath TCP defaults returned by mptcpStatusListen.get and
	// mptcpStatusDial.get when neither an explicit setting nor a GODEBUG
	// override applies: on for listeners, off for dialers.
	defaultMPTCPEnabledListen = true
	defaultMPTCPEnabledDial   = false
)
35
36
37
38
39
40
41
// multipathtcp is the GODEBUG setting named "multipathtcp". Its value is
// consulted by mptcpStatusDial.get and mptcpStatusListen.get when no
// explicit Multipath TCP choice has been made.
var multipathtcp = godebug.New("multipathtcp")

// mptcpStatusDial is a tri-state flag recording whether Multipath TCP was
// explicitly enabled, explicitly disabled, or left at the default for a
// Dialer.
type mptcpStatusDial uint8

const (
	// mptcpUseDefaultDial is the zero value: no explicit choice was made.
	mptcpUseDefaultDial mptcpStatusDial = iota
	// mptcpEnabledDial records an explicit request to use Multipath TCP.
	mptcpEnabledDial
	// mptcpDisabledDial records an explicit request not to use Multipath TCP.
	mptcpDisabledDial
)
54
55 func (m *mptcpStatusDial) get() bool {
56 switch *m {
57 case mptcpEnabledDial:
58 return true
59 case mptcpDisabledDial:
60 return false
61 }
62
63
64 if multipathtcp.Value() == "1" || multipathtcp.Value() == "3" {
65 multipathtcp.IncNonDefault()
66
67 return true
68 }
69
70 return defaultMPTCPEnabledDial
71 }
72
73 func (m *mptcpStatusDial) set(use bool) {
74 if use {
75 *m = mptcpEnabledDial
76 } else {
77 *m = mptcpDisabledDial
78 }
79 }
80
81
82
// mptcpStatusListen is a tri-state flag recording whether Multipath TCP was
// explicitly enabled, explicitly disabled, or left at the default for a
// ListenConfig.
type mptcpStatusListen uint8

const (
	// mptcpUseDefaultListen is the zero value: no explicit choice was made.
	mptcpUseDefaultListen mptcpStatusListen = iota
	// mptcpEnabledListen records an explicit request to use Multipath TCP.
	mptcpEnabledListen
	// mptcpDisabledListen records an explicit request not to use Multipath TCP.
	mptcpDisabledListen
)
91
92 func (m *mptcpStatusListen) get() bool {
93 switch *m {
94 case mptcpEnabledListen:
95 return true
96 case mptcpDisabledListen:
97 return false
98 }
99
100
101
102 if multipathtcp.Value() == "0" || multipathtcp.Value() == "3" {
103 multipathtcp.IncNonDefault()
104
105 return false
106 }
107
108 return defaultMPTCPEnabledListen
109 }
110
111 func (m *mptcpStatusListen) set(use bool) {
112 if use {
113 *m = mptcpEnabledListen
114 } else {
115 *m = mptcpDisabledListen
116 }
117 }
118
119
120
121
122
123
124
125
// A Dialer holds the options used when connecting to an address.
//
// A zero-value Dialer is usable: the package-level Dial function dials with
// one.
type Dialer struct {
	// Timeout, when non-zero, is added to the current time to form one of
	// the candidate deadlines for a dial (see the deadline method). The
	// earliest non-zero candidate wins.
	Timeout time.Duration

	// Deadline is an absolute point in time that bounds a dial. It is
	// combined with Timeout and the context's deadline; the earliest
	// non-zero value is used. The zero value means no deadline.
	Deadline time.Time

	// LocalAddr is the local address to use when dialing. It is passed as
	// a hint to address resolution, which rejects resolved addresses whose
	// network does not match it. If nil, no local address is supplied to
	// the underlying dial.
	LocalAddr Addr

	// DualStack is not read by any code visible in this file; dualStack
	// consults FallbackDelay instead.
	// NOTE(review): presumably retained for API compatibility - confirm.
	DualStack bool

	// FallbackDelay is how long dialParallel waits before starting the
	// fallback connection attempt. A value that is not positive selects a
	// default of 300ms; a negative value additionally makes dualStack
	// report false, so DialContext does not split addresses into primaries
	// and fallbacks.
	FallbackDelay time.Duration

	// KeepAlive is a keep-alive duration setting.
	// NOTE(review): it is not consumed in the visible portion of this file;
	// confirm its exact semantics against the TCP dialing code.
	KeepAlive time.Duration

	// KeepAliveConfig is a keep-alive configuration.
	// NOTE(review): it is not consumed in the visible portion of this file;
	// confirm how it interacts with KeepAlive.
	KeepAliveConfig KeepAliveConfig

	// Resolver optionally specifies the resolver used for address lookups.
	// When nil, DefaultResolver is used.
	Resolver *Resolver

	// Cancel is an optional channel. When it is non-nil, dialCtx derives a
	// cancelable context and cancels it once a receive from this channel
	// succeeds (for example because the channel was closed).
	Cancel <-chan struct{}

	// Control is an optional callback taking the network, the address and
	// a syscall.RawConn.
	// NOTE(review): it is not invoked in the visible portion of this file;
	// confirm when it is called and how it relates to ControlContext.
	Control func(network, address string, c syscall.RawConn) error

	// ControlContext is like Control but additionally receives a context.
	// NOTE(review): it is not invoked in the visible portion of this file;
	// confirm when it is called and its precedence over Control.
	ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error

	// mptcpStatus records the explicit Multipath TCP choice, if any. It is
	// read by MultipathTCP and written by SetMultipathTCP.
	mptcpStatus mptcpStatusDial
}
232
233 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
234
// minNonzeroTime returns the earlier of a and b, treating a zero time.Time
// as "unset": if only one argument is non-zero it is returned, and if both
// are zero the result is zero.
func minNonzeroTime(a, b time.Time) time.Time {
	switch {
	case a.IsZero():
		return b
	case b.IsZero(), a.Before(b):
		return a
	default:
		return b
	}
}
244
245
246
247
248
249
250
251 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
252 if d.Timeout != 0 {
253 earliest = now.Add(d.Timeout)
254 }
255 if d, ok := ctx.Deadline(); ok {
256 earliest = minNonzeroTime(earliest, d)
257 }
258 return minNonzeroTime(earliest, d.Deadline)
259 }
260
261 func (d *Dialer) resolver() *Resolver {
262 if d.Resolver != nil {
263 return d.Resolver
264 }
265 return DefaultResolver
266 }
267
268
269
270 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
271 if deadline.IsZero() {
272 return deadline, nil
273 }
274 timeRemaining := deadline.Sub(now)
275 if timeRemaining <= 0 {
276 return time.Time{}, errTimeout
277 }
278
279 timeout := timeRemaining / time.Duration(addrsRemaining)
280
281 const saneMinimum = 2 * time.Second
282 if timeout < saneMinimum {
283 if timeRemaining < saneMinimum {
284 timeout = timeRemaining
285 } else {
286 timeout = saneMinimum
287 }
288 }
289 return now.Add(timeout), nil
290 }
291
292 func (d *Dialer) fallbackDelay() time.Duration {
293 if d.FallbackDelay > 0 {
294 return d.FallbackDelay
295 } else {
296 return 300 * time.Millisecond
297 }
298 }
299
300 func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
301 i := bytealg.LastIndexByteString(network, ':')
302 if i < 0 {
303 switch network {
304 case "tcp", "tcp4", "tcp6":
305 case "udp", "udp4", "udp6":
306 case "ip", "ip4", "ip6":
307 if needsProto {
308 return "", 0, UnknownNetworkError(network)
309 }
310 case "unix", "unixgram", "unixpacket":
311 default:
312 return "", 0, UnknownNetworkError(network)
313 }
314 return network, 0, nil
315 }
316 afnet = network[:i]
317 switch afnet {
318 case "ip", "ip4", "ip6":
319 protostr := network[i+1:]
320 proto, i, ok := dtoi(protostr)
321 if !ok || i != len(protostr) {
322 proto, err = lookupProtocol(ctx, protostr)
323 if err != nil {
324 return "", 0, err
325 }
326 }
327 return afnet, proto, nil
328 }
329 return "", 0, UnknownNetworkError(network)
330 }
331
332
333
334
335 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
336 afnet, _, err := parseNetwork(ctx, network, true)
337 if err != nil {
338 return nil, err
339 }
340 if op == "dial" && addr == "" {
341 return nil, errMissingAddress
342 }
343 switch afnet {
344 case "unix", "unixgram", "unixpacket":
345 addr, err := ResolveUnixAddr(afnet, addr)
346 if err != nil {
347 return nil, err
348 }
349 if op == "dial" && hint != nil && addr.Network() != hint.Network() {
350 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
351 }
352 return addrList{addr}, nil
353 }
354 addrs, err := r.internetAddrList(ctx, afnet, addr)
355 if err != nil || op != "dial" || hint == nil {
356 return addrs, err
357 }
358 var (
359 tcp *TCPAddr
360 udp *UDPAddr
361 ip *IPAddr
362 wildcard bool
363 )
364 switch hint := hint.(type) {
365 case *TCPAddr:
366 tcp = hint
367 wildcard = tcp.isWildcard()
368 case *UDPAddr:
369 udp = hint
370 wildcard = udp.isWildcard()
371 case *IPAddr:
372 ip = hint
373 wildcard = ip.isWildcard()
374 }
375 naddrs := addrs[:0]
376 for _, addr := range addrs {
377 if addr.Network() != hint.Network() {
378 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
379 }
380 switch addr := addr.(type) {
381 case *TCPAddr:
382 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
383 continue
384 }
385 naddrs = append(naddrs, addr)
386 case *UDPAddr:
387 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
388 continue
389 }
390 naddrs = append(naddrs, addr)
391 case *IPAddr:
392 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
393 continue
394 }
395 naddrs = append(naddrs, addr)
396 }
397 }
398 if len(naddrs) == 0 {
399 return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
400 }
401 return naddrs, nil
402 }
403
404
405
406
407
408 func (d *Dialer) MultipathTCP() bool {
409 return d.mptcpStatus.get()
410 }
411
412
413
414
415
416
417
418 func (d *Dialer) SetMultipathTCP(use bool) {
419 d.mptcpStatus.set(use)
420 }
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470 func Dial(network, address string) (Conn, error) {
471 var d Dialer
472 return d.Dial(network, address)
473 }
474
475
476
477
478
479
480
481
482
483
484
485 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
486 d := Dialer{Timeout: timeout}
487 return d.Dial(network, address)
488 }
489
490
// sysDialer carries a copy of a Dialer's configuration together with the
// network and address strings of one particular dial call.
type sysDialer struct {
	Dialer
	// network and address are the values originally passed by the caller;
	// they are used when building error values and trace callbacks.
	network, address string
	// testHookDialTCP has the shape of a TCP dial function.
	// NOTE(review): presumably a test-only replacement for the real TCP
	// dial; it is not referenced in the visible portion of this file.
	testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
}
496
497
498
499
500
501
502
503
504 func (d *Dialer) Dial(network, address string) (Conn, error) {
505 return d.DialContext(context.Background(), network, address)
506 }
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526 func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
527 ctx, cancel := d.dialCtx(ctx)
528 defer cancel()
529
530
531 resolveCtx := ctx
532 if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
533 shadow := *trace
534 shadow.ConnectStart = nil
535 shadow.ConnectDone = nil
536 resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
537 }
538
539 addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
540 if err != nil {
541 return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
542 }
543
544 sd := &sysDialer{
545 Dialer: *d,
546 network: network,
547 address: address,
548 }
549
550 var primaries, fallbacks addrList
551 if d.dualStack() && network == "tcp" {
552 primaries, fallbacks = addrs.partition(isIPv4)
553 } else {
554 primaries = addrs
555 }
556
557 return sd.dialParallel(ctx, primaries, fallbacks)
558 }
559
560 func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
561 if ctx == nil {
562 panic("nil context")
563 }
564 deadline := d.deadline(ctx, time.Now())
565 var cancel1, cancel2 context.CancelFunc
566 if !deadline.IsZero() {
567 testHookStepTime()
568 if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
569 var subCtx context.Context
570 subCtx, cancel1 = context.WithDeadline(ctx, deadline)
571 ctx = subCtx
572 }
573 }
574 if oldCancel := d.Cancel; oldCancel != nil {
575 subCtx, cancel2 := context.WithCancel(ctx)
576 go func() {
577 select {
578 case <-oldCancel:
579 cancel2()
580 case <-subCtx.Done():
581 }
582 }()
583 ctx = subCtx
584 }
585 return ctx, func() {
586 if cancel1 != nil {
587 cancel1()
588 }
589 if cancel2 != nil {
590 cancel2()
591 }
592 }
593 }
594
595
596
597
598
599
600
601
602
603 func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
604 ctx, cancel := d.dialCtx(ctx)
605 defer cancel()
606 return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
607 }
608
609
610
611
612
613
614
615
616
617 func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
618 ctx, cancel := d.dialCtx(ctx)
619 defer cancel()
620 return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
621 }
622
623
624
625
626
627
628
629
630
631 func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
632 ctx, cancel := d.dialCtx(ctx)
633 defer cancel()
634 return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
635 }
636
637
638
639
640
641
642
643
644
645 func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
646 ctx, cancel := d.dialCtx(ctx)
647 defer cancel()
648 return dialUnix(ctx, d, network, laddr, raddr)
649 }
650
651
652
653
654
// dialParallel races two serial dial attempts: one over primaries, started
// immediately, and one over fallbacks, started after the fallback delay or
// as soon as the primary attempt fails. It returns the first connection
// that is established. If both attempts fail, the error from the primary
// attempt is returned. When fallbacks is empty it simply dials primaries
// serially.
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
	if len(fallbacks) == 0 {
		return sd.dialSerial(ctx, primaries)
	}

	// returned is closed when this function returns, so a racer that
	// finishes late can tell that nobody is receiving results anymore.
	returned := make(chan struct{})
	defer close(returned)

	type dialResult struct {
		Conn
		error
		primary bool
		done    bool
	}
	results := make(chan dialResult) // unbuffered

	// startRacer dials one of the two address lists and delivers the
	// outcome, or closes the connection if dialParallel already returned.
	startRacer := func(ctx context.Context, primary bool) {
		ras := primaries
		if !primary {
			ras = fallbacks
		}
		c, err := sd.dialSerial(ctx, ras)
		select {
		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
		case <-returned:
			if c != nil {
				c.Close()
			}
		}
	}

	var primary, fallback dialResult

	// Start the primary racer right away.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go startRacer(primaryCtx, true)

	// Arm the timer that will start the fallback racer.
	fallbackTimer := time.NewTimer(sd.fallbackDelay())
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go startRacer(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// Stop reported that the timer was still pending, so the
				// fallback racer has not been started yet. The primary
				// attempt has just failed, so fire the timer immediately
				// instead of waiting out the rest of the delay.
				fallbackTimer.Reset(0)
			}
		}
	}
}
726
727
728
729 func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
730 var firstErr error
731
732 for i, ra := range ras {
733 select {
734 case <-ctx.Done():
735 return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
736 default:
737 }
738
739 dialCtx := ctx
740 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
741 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
742 if err != nil {
743
744 if firstErr == nil {
745 firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
746 }
747 break
748 }
749 if partialDeadline.Before(deadline) {
750 var cancel context.CancelFunc
751 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
752 defer cancel()
753 }
754 }
755
756 c, err := sd.dialSingle(dialCtx, ra)
757 if err == nil {
758 return c, nil
759 }
760 if firstErr == nil {
761 firstErr = err
762 }
763 }
764
765 if firstErr == nil {
766 firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
767 }
768 return nil, firstErr
769 }
770
771
772
773 func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
774 trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
775 if trace != nil {
776 raStr := ra.String()
777 if trace.ConnectStart != nil {
778 trace.ConnectStart(sd.network, raStr)
779 }
780 if trace.ConnectDone != nil {
781 defer func() { trace.ConnectDone(sd.network, raStr, err) }()
782 }
783 }
784 la := sd.LocalAddr
785 switch ra := ra.(type) {
786 case *TCPAddr:
787 la, _ := la.(*TCPAddr)
788 if sd.MultipathTCP() {
789 c, err = sd.dialMPTCP(ctx, la, ra)
790 } else {
791 c, err = sd.dialTCP(ctx, la, ra)
792 }
793 case *UDPAddr:
794 la, _ := la.(*UDPAddr)
795 c, err = sd.dialUDP(ctx, la, ra)
796 case *IPAddr:
797 la, _ := la.(*IPAddr)
798 c, err = sd.dialIP(ctx, la, ra)
799 case *UnixAddr:
800 la, _ := la.(*UnixAddr)
801 c, err = sd.dialUnix(ctx, la, ra)
802 default:
803 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
804 }
805 if err != nil {
806 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err}
807 }
808 return c, nil
809 }
810
811
// ListenConfig holds the options used when listening on an address.
//
// A zero-value ListenConfig is usable: the package-level Listen and
// ListenPacket functions listen with one.
type ListenConfig struct {
	// Control is an optional callback taking the network, the address and
	// a syscall.RawConn.
	// NOTE(review): it is not invoked in the visible portion of this file;
	// confirm when it is called.
	Control func(network, address string, c syscall.RawConn) error

	// KeepAlive is a keep-alive duration setting.
	// NOTE(review): it is not consumed in the visible portion of this file;
	// confirm its exact semantics against the TCP listener code.
	KeepAlive time.Duration

	// KeepAliveConfig is a keep-alive configuration.
	// NOTE(review): it is not consumed in the visible portion of this file;
	// confirm how it interacts with KeepAlive.
	KeepAliveConfig KeepAliveConfig

	// mptcpStatus records the explicit Multipath TCP choice, if any. It is
	// read by MultipathTCP and written by SetMultipathTCP.
	mptcpStatus mptcpStatusListen
}
848
849
850
851
852
853 func (lc *ListenConfig) MultipathTCP() bool {
854 return lc.mptcpStatus.get()
855 }
856
857
858
859
860
861
862
863 func (lc *ListenConfig) SetMultipathTCP(use bool) {
864 lc.mptcpStatus.set(use)
865 }
866
867
868
869
870
871
872
873
874 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
875 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
876 if err != nil {
877 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
878 }
879 sl := &sysListener{
880 ListenConfig: *lc,
881 network: network,
882 address: address,
883 }
884 var l Listener
885 la := addrs.first(isIPv4)
886 switch la := la.(type) {
887 case *TCPAddr:
888 if sl.MultipathTCP() {
889 l, err = sl.listenMPTCP(ctx, la)
890 } else {
891 l, err = sl.listenTCP(ctx, la)
892 }
893 case *UnixAddr:
894 l, err = sl.listenUnix(ctx, la)
895 default:
896 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
897 }
898 if err != nil {
899 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
900 }
901 return l, nil
902 }
903
904
905
906
907
908
909
910
911 func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
912 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
913 if err != nil {
914 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
915 }
916 sl := &sysListener{
917 ListenConfig: *lc,
918 network: network,
919 address: address,
920 }
921 var c PacketConn
922 la := addrs.first(isIPv4)
923 switch la := la.(type) {
924 case *UDPAddr:
925 c, err = sl.listenUDP(ctx, la)
926 case *IPAddr:
927 c, err = sl.listenIP(ctx, la)
928 case *UnixAddr:
929 c, err = sl.listenUnixgram(ctx, la)
930 default:
931 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
932 }
933 if err != nil {
934 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
935 }
936 return c, nil
937 }
938
939
// sysListener carries a copy of a ListenConfig together with the network
// and address strings of one particular listen call.
type sysListener struct {
	ListenConfig
	// network and address are the values originally passed by the caller;
	// network is used when building error values.
	network, address string
}
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966 func Listen(network, address string) (Listener, error) {
967 var lc ListenConfig
968 return lc.Listen(context.Background(), network, address)
969 }
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996 func ListenPacket(network, address string) (PacketConn, error) {
997 var lc ListenConfig
998 return lc.ListenPacket(context.Background(), network, address)
999 }
1000
View as plain text