diff --git a/ping.go b/ping.go index 3a72159..c891cce 100644 --- a/ping.go +++ b/ping.go @@ -78,6 +78,10 @@ const ( trackerLength = len(uuid.UUID{}) protocolICMP = 1 protocolIPv6ICMP = 58 + + networkIP = "ip" + networkIPv4 = "ip4" + networkIPv6 = "ip6" ) var ( @@ -107,7 +111,7 @@ func New(addr string) *Pinger { trackerUUIDs: []uuid.UUID{firstUUID}, ipaddr: nil, ipv4: false, - network: "ip", + network: networkIP, protocol: "udp", awaitingSequences: firstSequence, TTL: 64, @@ -130,6 +134,9 @@ type Pinger struct { // packets have been received. Timeout time.Duration + // ResolveTimeout specifies a timeout to resolve an IP address or domain name + ResolveTimeout time.Duration + // Count tells pinger to stop after sending (and receiving) Count echo // packets. If this option is not specified, pinger will operate until // interrupted. @@ -338,9 +345,42 @@ func (p *Pinger) Resolve() error { if len(p.addr) == 0 { return errors.New("addr cannot be empty") } - ipaddr, err := net.ResolveIPAddr(p.network, p.addr) - if err != nil { - return err + var ( + ipaddr *net.IPAddr + err error + ) + if p.ResolveTimeout > time.Duration(0) { + var ( + ctx = context.Background() + ips []net.IP + ) + ctx, cancel := context.WithTimeout(ctx, p.ResolveTimeout) + defer cancel() + ips, err = net.DefaultResolver.LookupIP(ctx, p.network, p.addr) + if err != nil { + return err + } + if len(ips) == 0 { + return fmt.Errorf("lookup %s failed: no addresses found", p.addr) + } + ipaddr = &net.IPAddr{IP: ips[0]} + for _, ip := range ips { + if p.network == networkIPv6 { + if ip.To4() == nil && ip.To16() != nil { + ipaddr = &net.IPAddr{IP: ip} + break + } + continue + } + if ip.To4() != nil { + ipaddr = &net.IPAddr{IP: ip} + } + } + } else { + ipaddr, err = net.ResolveIPAddr(p.network, p.addr) + if err != nil { + return err + } } p.ipv4 = isIPv4(ipaddr.IP) @@ -374,12 +414,12 @@ func (p *Pinger) Addr() string { // * "ip6" will select IPv6. func (p *Pinger) SetNetwork(n string) { switch n { - case "ip4": - p.network = "ip4" - case "ip6": - p.network = "ip6" + case networkIPv4: + p.network = networkIPv4 + case networkIPv6: + p.network = networkIPv6 default: - p.network = "ip" + p.network = networkIP } } diff --git a/ping_test.go b/ping_test.go index 722bf95..bbbfe9b 100644 --- a/ping_test.go +++ b/ping_test.go @@ -816,3 +816,21 @@ func TestRunWithBackgroundContext(t *testing.T) { } AssertTrue(t, stats.PacketsRecv == 10) } + +func TestSetResolveTimeout(t *testing.T) { + p := New("www.google.com") + p.Count = 3 + p.Timeout = 5 * time.Second + p.ResolveTimeout = 2 * time.Second + err := p.Resolve() + AssertNoError(t, err) + + err = p.SetAddr("www.google.com ") + AssertError(t, err, "") + + err = p.SetAddr("127.0.0.1 ") + AssertError(t, err, "") + + err = p.SetAddr("127.0.0.1") + AssertNoError(t, err) +}