Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/bytealg"
10 "internal/godebug"
11 "internal/nettrace"
12 "net/netip"
13 "syscall"
14 "time"
15 )
16
// Default TCP keep-alive parameters.
//
// NOTE(review): these are not read in this file; presumably they are the
// fallbacks used when keep-alive settings are left at their zero value —
// confirm where they are consumed.
const (
	// defaultTCPKeepAliveIdle is the default keep-alive idle duration.
	defaultTCPKeepAliveIdle = 15 * time.Second

	// defaultTCPKeepAliveInterval is the default keep-alive probe interval.
	defaultTCPKeepAliveInterval = 15 * time.Second

	// defaultTCPKeepAliveCount is the default keep-alive probe count.
	defaultTCPKeepAliveCount = 9
)

// Default Multipath TCP policy, applied by mptcpStatusListen.get and
// mptcpStatusDial.get when neither an explicit SetMultipathTCP call nor
// the multipathtcp GODEBUG setting decides: listeners use MPTCP, dialers
// do not.
const (
	defaultMPTCPEnabledListen = true
	defaultMPTCPEnabledDial   = false
)
35
36
37
38
39
40
41
42 var multipathtcp = godebug.New("multipathtcp")
43
44
45
46 type mptcpStatusDial uint8
47
48 const (
49
50 mptcpUseDefaultDial mptcpStatusDial = iota
51 mptcpEnabledDial
52 mptcpDisabledDial
53 )
54
55 func (m *mptcpStatusDial) get() bool {
56 switch *m {
57 case mptcpEnabledDial:
58 return true
59 case mptcpDisabledDial:
60 return false
61 }
62
63
64 if multipathtcp.Value() == "1" || multipathtcp.Value() == "3" {
65 multipathtcp.IncNonDefault()
66
67 return true
68 }
69
70 return defaultMPTCPEnabledDial
71 }
72
73 func (m *mptcpStatusDial) set(use bool) {
74 if use {
75 *m = mptcpEnabledDial
76 } else {
77 *m = mptcpDisabledDial
78 }
79 }
80
81
82
83 type mptcpStatusListen uint8
84
85 const (
86
87 mptcpUseDefaultListen mptcpStatusListen = iota
88 mptcpEnabledListen
89 mptcpDisabledListen
90 )
91
92 func (m *mptcpStatusListen) get() bool {
93 switch *m {
94 case mptcpEnabledListen:
95 return true
96 case mptcpDisabledListen:
97 return false
98 }
99
100
101
102 if multipathtcp.Value() == "0" || multipathtcp.Value() == "3" {
103 multipathtcp.IncNonDefault()
104
105 return false
106 }
107
108 return defaultMPTCPEnabledListen
109 }
110
111 func (m *mptcpStatusListen) set(use bool) {
112 if use {
113 *m = mptcpEnabledListen
114 } else {
115 *m = mptcpDisabledListen
116 }
117 }
118
119
120
121
122
123
124
125
126 type Dialer struct {
127
128
129
130
131
132
133
134
135
136
137
138
139 Timeout time.Duration
140
141
142
143
144
145 Deadline time.Time
146
147
148
149
150
151 LocalAddr Addr
152
153
154
155
156
157
158
159
160 DualStack bool
161
162
163
164
165
166
167
168
169
170
171
172
173 FallbackDelay time.Duration
174
175
176
177
178
179
180
181
182
183
184
185 KeepAlive time.Duration
186
187
188
189
190
191
192
193
194 KeepAliveConfig KeepAliveConfig
195
196
197 Resolver *Resolver
198
199
200
201
202
203
204 Cancel <-chan struct{}
205
206
207
208
209
210
211
212
213
214
215
216 Control func(network, address string, c syscall.RawConn) error
217
218
219
220
221
222
223
224
225
226
227
228 ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error
229
230
231
232
233 mptcpStatus mptcpStatusDial
234 }
235
236 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
237
// minNonzeroTime returns the earlier of a and b, treating a zero time as
// "unset". If a is zero, b is returned unchanged (even if also zero). If
// the two are equal, b is returned.
func minNonzeroTime(a, b time.Time) time.Time {
	switch {
	case a.IsZero():
		return b
	case !b.IsZero() && !a.Before(b):
		return b
	default:
		return a
	}
}
247
248
249
250
251
252
253
254 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
255 if d.Timeout != 0 {
256 earliest = now.Add(d.Timeout)
257 }
258 if d, ok := ctx.Deadline(); ok {
259 earliest = minNonzeroTime(earliest, d)
260 }
261 return minNonzeroTime(earliest, d.Deadline)
262 }
263
264 func (d *Dialer) resolver() *Resolver {
265 if d.Resolver != nil {
266 return d.Resolver
267 }
268 return DefaultResolver
269 }
270
271
272
273 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
274 if deadline.IsZero() {
275 return deadline, nil
276 }
277 timeRemaining := deadline.Sub(now)
278 if timeRemaining <= 0 {
279 return time.Time{}, errTimeout
280 }
281
282 timeout := timeRemaining / time.Duration(addrsRemaining)
283
284 const saneMinimum = 2 * time.Second
285 if timeout < saneMinimum {
286 if timeRemaining < saneMinimum {
287 timeout = timeRemaining
288 } else {
289 timeout = saneMinimum
290 }
291 }
292 return now.Add(timeout), nil
293 }
294
295 func (d *Dialer) fallbackDelay() time.Duration {
296 if d.FallbackDelay > 0 {
297 return d.FallbackDelay
298 } else {
299 return 300 * time.Millisecond
300 }
301 }
302
303 func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
304 i := bytealg.LastIndexByteString(network, ':')
305 if i < 0 {
306 switch network {
307 case "tcp", "tcp4", "tcp6":
308 case "udp", "udp4", "udp6":
309 case "ip", "ip4", "ip6":
310 if needsProto {
311 return "", 0, UnknownNetworkError(network)
312 }
313 case "unix", "unixgram", "unixpacket":
314 default:
315 return "", 0, UnknownNetworkError(network)
316 }
317 return network, 0, nil
318 }
319 afnet = network[:i]
320 switch afnet {
321 case "ip", "ip4", "ip6":
322 protostr := network[i+1:]
323 proto, i, ok := dtoi(protostr)
324 if !ok || i != len(protostr) {
325 proto, err = lookupProtocol(ctx, protostr)
326 if err != nil {
327 return "", 0, err
328 }
329 }
330 return afnet, proto, nil
331 }
332 return "", 0, UnknownNetworkError(network)
333 }
334
335
336
337
338 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
339 afnet, _, err := parseNetwork(ctx, network, true)
340 if err != nil {
341 return nil, err
342 }
343 if op == "dial" && addr == "" {
344 return nil, errMissingAddress
345 }
346 switch afnet {
347 case "unix", "unixgram", "unixpacket":
348 addr, err := ResolveUnixAddr(afnet, addr)
349 if err != nil {
350 return nil, err
351 }
352 if op == "dial" && hint != nil && addr.Network() != hint.Network() {
353 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
354 }
355 return addrList{addr}, nil
356 }
357 addrs, err := r.internetAddrList(ctx, afnet, addr)
358 if err != nil || op != "dial" || hint == nil {
359 return addrs, err
360 }
361 var (
362 tcp *TCPAddr
363 udp *UDPAddr
364 ip *IPAddr
365 wildcard bool
366 )
367 switch hint := hint.(type) {
368 case *TCPAddr:
369 tcp = hint
370 wildcard = tcp.isWildcard()
371 case *UDPAddr:
372 udp = hint
373 wildcard = udp.isWildcard()
374 case *IPAddr:
375 ip = hint
376 wildcard = ip.isWildcard()
377 }
378 naddrs := addrs[:0]
379 for _, addr := range addrs {
380 if addr.Network() != hint.Network() {
381 return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
382 }
383 switch addr := addr.(type) {
384 case *TCPAddr:
385 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
386 continue
387 }
388 naddrs = append(naddrs, addr)
389 case *UDPAddr:
390 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
391 continue
392 }
393 naddrs = append(naddrs, addr)
394 case *IPAddr:
395 if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
396 continue
397 }
398 naddrs = append(naddrs, addr)
399 }
400 }
401 if len(naddrs) == 0 {
402 return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
403 }
404 return naddrs, nil
405 }
406
407
408
409
410
411 func (d *Dialer) MultipathTCP() bool {
412 return d.mptcpStatus.get()
413 }
414
415
416
417
418
419
420
421 func (d *Dialer) SetMultipathTCP(use bool) {
422 d.mptcpStatus.set(use)
423 }
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473 func Dial(network, address string) (Conn, error) {
474 var d Dialer
475 return d.Dial(network, address)
476 }
477
478
479
480
481
482
483
484
485
486
487
488 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
489 d := Dialer{Timeout: timeout}
490 return d.Dial(network, address)
491 }
492
493
494 type sysDialer struct {
495 Dialer
496 network, address string
497 testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
498 }
499
500
501
502
503
504
505
506
507 func (d *Dialer) Dial(network, address string) (Conn, error) {
508 return d.DialContext(context.Background(), network, address)
509 }
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529 func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
530 ctx, cancel := d.dialCtx(ctx)
531 defer cancel()
532
533
534 resolveCtx := ctx
535 if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
536 shadow := *trace
537 shadow.ConnectStart = nil
538 shadow.ConnectDone = nil
539 resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
540 }
541
542 addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
543 if err != nil {
544 return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
545 }
546
547 sd := &sysDialer{
548 Dialer: *d,
549 network: network,
550 address: address,
551 }
552
553 var primaries, fallbacks addrList
554 if d.dualStack() && network == "tcp" {
555 primaries, fallbacks = addrs.partition(isIPv4)
556 } else {
557 primaries = addrs
558 }
559
560 return sd.dialParallel(ctx, primaries, fallbacks)
561 }
562
563 func (d *Dialer) dialCtx(ctx context.Context) (context.Context, context.CancelFunc) {
564 if ctx == nil {
565 panic("nil context")
566 }
567 deadline := d.deadline(ctx, time.Now())
568 var cancel1, cancel2 context.CancelFunc
569 if !deadline.IsZero() {
570 testHookStepTime()
571 if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
572 var subCtx context.Context
573 subCtx, cancel1 = context.WithDeadline(ctx, deadline)
574 ctx = subCtx
575 }
576 }
577 if oldCancel := d.Cancel; oldCancel != nil {
578 var subCtx context.Context
579 subCtx, cancel2 = context.WithCancel(ctx)
580 go func() {
581 select {
582 case <-oldCancel:
583 cancel2()
584 case <-subCtx.Done():
585 }
586 }()
587 ctx = subCtx
588 }
589 return ctx, func() {
590 if cancel1 != nil {
591 cancel1()
592 }
593 if cancel2 != nil {
594 cancel2()
595 }
596 }
597 }
598
599
600
601
602
603
604
605
606
607 func (d *Dialer) DialTCP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*TCPConn, error) {
608 ctx, cancel := d.dialCtx(ctx)
609 defer cancel()
610 return dialTCP(ctx, d, network, TCPAddrFromAddrPort(laddr), TCPAddrFromAddrPort(raddr))
611 }
612
613
614
615
616
617
618
619
620
621 func (d *Dialer) DialUDP(ctx context.Context, network string, laddr netip.AddrPort, raddr netip.AddrPort) (*UDPConn, error) {
622 ctx, cancel := d.dialCtx(ctx)
623 defer cancel()
624 return dialUDP(ctx, d, network, UDPAddrFromAddrPort(laddr), UDPAddrFromAddrPort(raddr))
625 }
626
627
628
629
630
631
632
633
634
635 func (d *Dialer) DialIP(ctx context.Context, network string, laddr netip.Addr, raddr netip.Addr) (*IPConn, error) {
636 ctx, cancel := d.dialCtx(ctx)
637 defer cancel()
638 return dialIP(ctx, d, network, ipAddrFromAddr(laddr), ipAddrFromAddr(raddr))
639 }
640
641
642
643
644
645
646
647
648
649 func (d *Dialer) DialUnix(ctx context.Context, network string, laddr *UnixAddr, raddr *UnixAddr) (*UnixConn, error) {
650 ctx, cancel := d.dialCtx(ctx)
651 defer cancel()
652 return dialUnix(ctx, d, network, laddr, raddr)
653 }
654
655
656
657
658
659 func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
660 if len(fallbacks) == 0 {
661 return sd.dialSerial(ctx, primaries)
662 }
663
664 returned := make(chan struct{})
665 defer close(returned)
666
667 type dialResult struct {
668 Conn
669 error
670 primary bool
671 done bool
672 }
673 results := make(chan dialResult)
674
675 startRacer := func(ctx context.Context, primary bool) {
676 ras := primaries
677 if !primary {
678 ras = fallbacks
679 }
680 c, err := sd.dialSerial(ctx, ras)
681 select {
682 case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
683 case <-returned:
684 if c != nil {
685 c.Close()
686 }
687 }
688 }
689
690 var primary, fallback dialResult
691
692
693 primaryCtx, primaryCancel := context.WithCancel(ctx)
694 defer primaryCancel()
695 go startRacer(primaryCtx, true)
696
697
698 fallbackTimer := time.NewTimer(sd.fallbackDelay())
699 defer fallbackTimer.Stop()
700
701 for {
702 select {
703 case <-fallbackTimer.C:
704 fallbackCtx, fallbackCancel := context.WithCancel(ctx)
705 defer fallbackCancel()
706 go startRacer(fallbackCtx, false)
707
708 case res := <-results:
709 if res.error == nil {
710 return res.Conn, nil
711 }
712 if res.primary {
713 primary = res
714 } else {
715 fallback = res
716 }
717 if primary.done && fallback.done {
718 return nil, primary.error
719 }
720 if res.primary && fallbackTimer.Stop() {
721
722
723
724
725 fallbackTimer.Reset(0)
726 }
727 }
728 }
729 }
730
731
732
733 func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
734 var firstErr error
735
736 for i, ra := range ras {
737 select {
738 case <-ctx.Done():
739 return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
740 default:
741 }
742
743 dialCtx := ctx
744 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
745 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
746 if err != nil {
747
748 if firstErr == nil {
749 firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
750 }
751 break
752 }
753 if partialDeadline.Before(deadline) {
754 var cancel context.CancelFunc
755 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
756 defer cancel()
757 }
758 }
759
760 c, err := sd.dialSingle(dialCtx, ra)
761 if err == nil {
762 return c, nil
763 }
764 if firstErr == nil {
765 firstErr = err
766 }
767 }
768
769 if firstErr == nil {
770 firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
771 }
772 return nil, firstErr
773 }
774
775
776
777 func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
778 trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
779 if trace != nil {
780 raStr := ra.String()
781 if trace.ConnectStart != nil {
782 trace.ConnectStart(sd.network, raStr)
783 }
784 if trace.ConnectDone != nil {
785 defer func() { trace.ConnectDone(sd.network, raStr, err) }()
786 }
787 }
788 la := sd.LocalAddr
789 switch ra := ra.(type) {
790 case *TCPAddr:
791 la, _ := la.(*TCPAddr)
792 if sd.MultipathTCP() {
793 c, err = sd.dialMPTCP(ctx, la, ra)
794 } else {
795 c, err = sd.dialTCP(ctx, la, ra)
796 }
797 case *UDPAddr:
798 la, _ := la.(*UDPAddr)
799 c, err = sd.dialUDP(ctx, la, ra)
800 case *IPAddr:
801 la, _ := la.(*IPAddr)
802 c, err = sd.dialIP(ctx, la, ra)
803 case *UnixAddr:
804 la, _ := la.(*UnixAddr)
805 c, err = sd.dialUnix(ctx, la, ra)
806 default:
807 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
808 }
809 if err != nil {
810 return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err}
811 }
812 return c, nil
813 }
814
815
816 type ListenConfig struct {
817
818
819
820
821
822
823
824
825 Control func(network, address string, c syscall.RawConn) error
826
827
828
829
830
831
832
833
834
835
836 KeepAlive time.Duration
837
838
839
840
841
842
843
844
845 KeepAliveConfig KeepAliveConfig
846
847
848
849
850 mptcpStatus mptcpStatusListen
851 }
852
853
854
855
856
857 func (lc *ListenConfig) MultipathTCP() bool {
858 return lc.mptcpStatus.get()
859 }
860
861
862
863
864
865
866
867 func (lc *ListenConfig) SetMultipathTCP(use bool) {
868 lc.mptcpStatus.set(use)
869 }
870
871
872
873
874
875
876
877
878 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
879 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
880 if err != nil {
881 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
882 }
883 sl := &sysListener{
884 ListenConfig: *lc,
885 network: network,
886 address: address,
887 }
888 var l Listener
889 la := addrs.first(isIPv4)
890 switch la := la.(type) {
891 case *TCPAddr:
892 if sl.MultipathTCP() {
893 l, err = sl.listenMPTCP(ctx, la)
894 } else {
895 l, err = sl.listenTCP(ctx, la)
896 }
897 case *UnixAddr:
898 l, err = sl.listenUnix(ctx, la)
899 default:
900 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
901 }
902 if err != nil {
903 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
904 }
905 return l, nil
906 }
907
908
909
910
911
912
913
914
915 func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
916 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
917 if err != nil {
918 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
919 }
920 sl := &sysListener{
921 ListenConfig: *lc,
922 network: network,
923 address: address,
924 }
925 var c PacketConn
926 la := addrs.first(isIPv4)
927 switch la := la.(type) {
928 case *UDPAddr:
929 c, err = sl.listenUDP(ctx, la)
930 case *IPAddr:
931 c, err = sl.listenIP(ctx, la)
932 case *UnixAddr:
933 c, err = sl.listenUnixgram(ctx, la)
934 default:
935 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
936 }
937 if err != nil {
938 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
939 }
940 return c, nil
941 }
942
943
944 type sysListener struct {
945 ListenConfig
946 network, address string
947 }
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970 func Listen(network, address string) (Listener, error) {
971 var lc ListenConfig
972 return lc.Listen(context.Background(), network, address)
973 }
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000 func ListenPacket(network, address string) (PacketConn, error) {
1001 var lc ListenConfig
1002 return lc.ListenPacket(context.Background(), network, address)
1003 }
1004
View as plain text