-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathratelimiter.go
153 lines (133 loc) · 3.77 KB
/
ratelimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package ratelimit
import (
"context"
"fmt"
"sync"
"time"
"github.com/vivangkumar/ratelimit/internal/bucket"
)
// NowFunc is the signature of a function that returns the current time.
//
// RateLimiter reads the clock through a NowFunc (time.Now unless replaced)
// so that time can be mocked.
type NowFunc = func() time.Time
// RateLimiter represents a token bucket based rate limiter.
//
// It is built on top of a bucket that accepts a max size.
// The bucket is refilled at an interval determined by the limiter.
//
// Most callers should use either Wait or WaitN to wait for tokens
// to be available. Add and AddN attempt to take tokens and report
// whether they succeeded instead of waiting.
//
// A RateLimiter contains a sync.Mutex and must not be copied after
// first use; pass it around by pointer.
type RateLimiter struct {
// bucket is the underlying storage for the rate limiter.
bucket *bucket.Bucket
// refillDuration is the duration after which tokens are refilled.
//
// The duration is calculated based on the limit specified
// at creation time (per / limit in New).
refillDuration time.Duration
// now returns the current time. New sets it to time.Now before
// applying options; it exists so that time can be mocked.
now NowFunc
// m guards lastRefillUnixNs.
m sync.Mutex
// lastRefillUnixNs keeps track of the time when the last
// refresh of tokens took place.
//
// It is stored as Unix time in nanoseconds, as returned by
// now().UnixNano().
lastRefillUnixNs int64
}
// New constructs a rate limiter that accepts the max tokens (size) that
// the limiter holds, along with the limit per duration.
//
// For example: if the bucket is configured with a max of 100 tokens
// and limit is set to 10 over a duration of 1m, this implies that
// the bucket will be refilled with one token every ((1 * 60s) / limit)s.
// This refills the bucket with 1 token every 6s, while giving us a
// max "burst" of 100 tokens.
//
// Any opts are applied after the defaults have been set, so they may
// override them (for example to mock the clock).
//
// New returns an error if limit is zero, if per is not positive, or if
// limit is so large relative to per that the refill interval would be
// shorter than one nanosecond. Such a configuration would otherwise
// produce a zero or negative refill interval, which makes the refill
// arithmetic divide by zero and makes time.NewTicker panic.
func New(
	max uint64,
	limit uint64,
	per time.Duration,
	opts ...Opt,
) (*RateLimiter, error) {
	if limit == 0 {
		return nil, fmt.Errorf("limit must be positive")
	}
	if per <= 0 {
		return nil, fmt.Errorf("per must be positive, got %s", per)
	}
	// per is positive here, so converting it to uint64 is lossless. This
	// check also rejects limits above math.MaxInt64, which would turn
	// negative when converted to time.Duration below.
	if limit > uint64(per) {
		return nil, fmt.Errorf(
			"limit %d is too high for duration %s: refill interval would be under 1ns",
			limit, per,
		)
	}

	r := &RateLimiter{
		bucket:         bucket.New(max),
		refillDuration: per / time.Duration(limit),
		now:            time.Now,
	}

	for _, opt := range opts {
		opt(r)
	}

	// An option may have cleared the clock; fall back to the wall clock
	// rather than panicking on a nil function call.
	if r.now == nil {
		r.now = time.Now
	}

	// Read the clock only after options ran so that a mocked clock
	// determines the initial refill timestamp.
	r.lastRefillUnixNs = r.now().UnixNano()

	return r, nil
}
// refill credits the bucket with one token for every whole refill
// interval that has elapsed since the last recorded refill.
//
// This method is called when attempting to Add as we might have to
// refresh our token count before allowing the token to be taken.
//
// The elapsed-time check and the update of lastRefillUnixNs happen under
// a single critical section, so concurrent callers cannot both credit
// tokens for the same elapsed window.
//
// lastRefillUnixNs is advanced by exactly the number of whole intervals
// that were credited, not set to the current time. The leftover part of
// an interval therefore keeps counting towards the next token, and the
// long-run refill rate matches the configured one.
func (r *RateLimiter) refill() {
	interval := r.refillDuration.Nanoseconds()

	r.m.Lock()
	elapsed := r.now().UnixNano() - r.lastRefillUnixNs
	tokens := elapsed / interval
	if tokens > 0 {
		// tokens*interval <= elapsed, so this cannot overflow or move
		// the timestamp past the current time.
		r.lastRefillUnixNs += tokens * interval
	}
	r.m.Unlock()

	// A non-positive count means no full interval has passed (or the
	// clock went backwards); there is nothing to credit.
	if tokens > 0 {
		r.bucket.Refill(uint64(tokens))
	}
}
// Add tries to take exactly one token from the bucket.
//
// It is shorthand for AddN(1): the result is true when the token was
// taken and false when it was not, which signals that the rate limit
// has been reached.
//
// Add does not wait for a token; callers that get false should try
// again later.
func (r *RateLimiter) Add() bool {
	const single uint64 = 1
	return r.AddN(single)
}
// AddN tries to take n tokens from the bucket and reports whether it
// succeeded.
//
// The bucket is refilled first, then the request is handed to the
// bucket's TakeN, which decides the outcome.
func (r *RateLimiter) AddN(n uint64) bool {
	// Bring the token count up to date before attempting to take.
	r.refill()

	taken := r.bucket.TakeN(n)
	return taken
}
// Wait blocks until a single token has been acquired.
//
// It is shorthand for WaitN(ctx, 1) and returns whatever WaitN returns:
// nil once the token has been consumed, or an error if the context is
// cancelled or its deadline passes first.
func (r *RateLimiter) Wait(ctx context.Context) error {
	const single uint64 = 1
	err := r.WaitN(ctx, single)
	return err
}
// WaitN blocks until n tokens are available.
//
// It returns an error if the context is cancelled,
// if the wait time for the context is exceeded, or
// if the number of tokens requested exceeds the bucket's size.
//
// If the tokens are already available, WaitN takes them and returns
// immediately instead of waiting for the first refill interval.
//
// This method also consumes n tokens, if successful.
func (r *RateLimiter) WaitN(ctx context.Context, n uint64) error {
	if n > r.bucket.Size() {
		return fmt.Errorf("tokens requested exceeds max tokens")
	}

	// Do not consume tokens on behalf of a caller whose context is
	// already done.
	if err := ctx.Err(); err != nil {
		return err
	}

	// Fast path: no need to wait a full refill interval when the bucket
	// can satisfy the request right now.
	if r.AddN(n) {
		return nil
	}

	// Retry once per refillDuration, the interval at which a new token
	// can become available.
	t := time.NewTicker(r.refillDuration)
	defer t.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
			if r.AddN(n) {
				return nil
			}
		}
	}
}