diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 645f2e2..10fcaf7 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -38,6 +38,7 @@ Examples: func main() { timeout := flag.Duration("t", time.Second*100000, "") + packetTimeout := flag.Duration("p", 0, "") interval := flag.Duration("i", time.Second, "") count := flag.Int("c", -1, "") size := flag.Int("s", 24, "") @@ -89,6 +90,7 @@ func main() { pinger.Size = *size pinger.Interval = *interval pinger.Timeout = *timeout + pinger.PacketTimeout = *packetTimeout pinger.TTL = *ttl pinger.SetPrivileged(*privileged) diff --git a/ping.go b/ping.go index 4b5248c..cbcc98a 100644 --- a/ping.go +++ b/ping.go @@ -95,11 +95,12 @@ func New(addr string) *Pinger { var firstSequence = map[uuid.UUID]map[int]struct{}{} firstSequence[firstUUID] = make(map[int]struct{}) return &Pinger{ - Count: -1, - Interval: time.Second, - RecordRtts: true, - Size: timeSliceLength + trackerLength, - Timeout: time.Duration(math.MaxInt64), + Count: -1, + Interval: time.Second, + RecordRtts: true, + Size: timeSliceLength + trackerLength, + Timeout: time.Duration(math.MaxInt64), + PacketTimeout: 1 * time.Second, addr: addr, done: make(chan interface{}), @@ -130,6 +131,9 @@ type Pinger struct { // packets have been received. Timeout time.Duration + // PacketTimeout specifies a timeout per packet. + PacketTimeout time.Duration + // Count tells pinger to stop after sending (and receiving) Count echo // packets. If this option is not specified, pinger will operate until // interrupted. @@ -286,6 +290,10 @@ type Statistics struct { } func (p *Pinger) updateStatistics(pkt *Packet) { + if pkt.Rtt >= p.PacketTimeout { // ignore packets that timeout from stats + return + } + p.statsMu.Lock() defer p.statsMu.Unlock() @@ -512,8 +520,9 @@ func (p *Pinger) runLoop( logger = NoopLogger{} } - timeout := time.NewTicker(p.Timeout) + timeout := time.NewTimer(calculateTimeout(p.Count, p.Interval, p.PacketTimeout, p.Timeout)) interval := time.NewTicker(p.Interval) + defer func() { interval.Stop() timeout.Stop() @@ -555,6 +564,21 @@ func (p *Pinger) runLoop( } } +func calculateTimeout(count int, interval, packetTimeout, requestTimeout time.Duration) time.Duration { + // Handles the continuous ping case as requestTimeout is max int64 by default. + if count == -1 { + return requestTimeout + } + + // the last packet roundtrip time + its timeout and a buffer of 100ms is the maximum time needed to collect all packets + pTimeout := time.Duration(count-1)*interval + packetTimeout + 100*time.Millisecond + if pTimeout < requestTimeout { + return pTimeout + } + + return requestTimeout +} + func (p *Pinger) Stop() { p.lock.Lock() defer p.lock.Unlock() diff --git a/ping_test.go b/ping_test.go index 7f8c7e9..6dca792 100644 --- a/ping_test.go +++ b/ping_test.go @@ -817,3 +817,47 @@ func TestRunWithBackgroundContext(t *testing.T) { } AssertTrue(t, stats.PacketsRecv == 10) } + +func TestCalculateTimeout(t *testing.T) { + tests := []struct { + name string + count int + interval time.Duration + packetTimeout time.Duration + requestTimeout time.Duration + want time.Duration + }{ + { + name: "Test Continuous Ping", + count: -1, + interval: 1 * time.Second, + packetTimeout: 2 * time.Second, + requestTimeout: time.Duration(1<<63 - 1), + want: time.Duration(1<<63 - 1), + }, + { + name: "Test Packet Timeout Greater Than Request Timeout", + count: 5, + interval: 500 * time.Millisecond, + packetTimeout: 3 * time.Second, + requestTimeout: 2 * time.Second, + want: 2 * time.Second, + }, + { + name: "Test Packet Timeout Less Than Request Timeout", + count: 5, + interval: 500 * time.Millisecond, + packetTimeout: 2 * time.Second, + requestTimeout: 5 * time.Second, + want: 2*time.Second + 2*time.Second + 100*time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := calculateTimeout(tt.count, tt.interval, tt.packetTimeout, tt.requestTimeout); got != tt.want { + t.Errorf("calculateTimeout() = %v, want %v", got, tt.want) + } + }) + } +}