diff --git a/go.mod b/go.mod index 2112e0c..338d6c0 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/abihf/cache-loader go 1.14 -require github.com/hashicorp/golang-lru v0.5.4 +require ( + github.com/hashicorp/golang-lru v0.5.4 + github.com/stretchr/testify v1.5.1 +) diff --git a/go.sum b/go.sum index 703ee30..beba631 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,12 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/inmemory_cache.go b/inmemory_cache.go new file mode 100644 index 0000000..0273e39 --- /dev/null +++ b/inmemory_cache.go @@ -0,0 +1,23 @@ +package loader + +import "sync" + +type inMemoryCache struct { + sync.Map +} + +func InMemoryCache() Cache { + return &inMemoryCache{sync.Map{}} +} + +func (c *inMemoryCache) Add(key, value interface{}) { + c.Map.Store(key, value) +} + +func (c *inMemoryCache) Get(key interface{}) (value interface{}, ok bool) { + return c.Map.Load(key) +} + +func (c *inMemoryCache) Remove(key interface{}) { + c.Map.Delete(key) +} diff --git a/loader.go b/loader.go index d64d241..8dcc147 100644 --- a/loader.go +++ b/loader.go @@ -2,6 +2,7 @@ package loader import ( "sync" + "sync/atomic" "time" ) @@ -49,19 +50,17 @@ func (l *Loader) Get(key interface{}) (interface{}, error) { item.mutex.Lock() defer item.mutex.Unlock() - if item.expire.Before(time.Now()) && !item.isFetching { - item.isFetching = true // so other thread don't fetch + // if the item is expired and it's not doing refetch + if item.expire.Before(time.Now()) && atomic.CompareAndSwapInt32(&item.isFetching, 0, 1) { go l.refetch(key, item) } return item.value, nil } - item := &cacheItem{isFetching: true, mutex: sync.Mutex{}} + item := &cacheItem{isFetching: 0, mutex: sync.Mutex{}} item.mutex.Lock() defer item.mutex.Unlock() - defer func() { - item.isFetching = false - }() + l.cache.Add(key, item) l.mutex.Unlock() @@ -76,16 +75,17 @@ func (l *Loader) Get(key interface{}) (interface{}, error) { } func (l *Loader) refetch(key interface{}, item *cacheItem) { - item.isFetching = true // to make sure, lol - defer func() { - item.isFetching = false - }() + defer atomic.StoreInt32(&item.isFetching, 0) value, err := l.fn(key) if err != nil { l.cache.Remove(key) return } + + item.mutex.Lock() + defer item.mutex.Unlock() + item.value = value item.updateExpire(l.ttl) } @@ -95,7 +95,7 @@ type cacheItem struct { expire time.Time mutex sync.Mutex - isFetching bool + isFetching int32 } func (i *cacheItem) updateExpire(ttl time.Duration) { diff --git a/loader_test.go b/loader_test.go new file mode 100644 index 0000000..211d6a6 --- /dev/null +++ b/loader_test.go @@ -0,0 +1,120 @@ +package loader + +import ( + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConcurrencySingleKey(t *testing.T) { + var counter int32 + fetch := func(key interface{}) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(100 * time.Millisecond) + return key, nil + } + l := New(fetch, 500*time.Millisecond, InMemoryCache()) + type result struct { + dur time.Duration + val interface{} + } + c := make(chan *result, 3) + + var start time.Time + var dur time.Duration + + start = time.Now() + for i := 0; i < 3; i++ { + go func() { + start := time.Now() + val, _ := l.Get("x") + c <- &result{val: val, dur: time.Now().Sub(start)} + }() + time.Sleep(10 * time.Millisecond) + } + for i := 0; i < 3; i++ { + res := <-c + assert.InDelta(t, 100, res.dur.Milliseconds(), 25, "each get should within 1s") + assert.Equal(t, "x", res.val, "Value must be x") + } + dur = time.Now().Sub(start) + assert.InDelta(t, 100, dur.Milliseconds(), 25, "all get should within 1s") + + start = time.Now() + val, _ := l.Get("x") + dur = time.Now().Sub(start) + assert.Less(t, dur.Milliseconds(), int64(50), "After cached get must be fast") + assert.Equal(t, "x", val, "Value must still be x") + + assert.Equal(t, int32(1), counter, "fetch must be called once") +} + +func TestConcurrencyMultiKey(t *testing.T) { + var counter int32 + fetch := func(key interface{}) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(100 * time.Millisecond) + return key, nil + } + l := New(fetch, 500*time.Millisecond, InMemoryCache()) + type result struct { + dur time.Duration + val interface{} + } + c := make(chan *result, 3) + + var start time.Time + var dur time.Duration + + start = time.Now() + for i := 0; i < 3; i++ { + go func(i int) { + start := time.Now() + val, _ := l.Get(fmt.Sprint(i)) + c <- &result{val: val, dur: time.Now().Sub(start)} + }(i) + time.Sleep(10 * time.Millisecond) + } + for i := 0; i < 3; i++ { + res := <-c + assert.InDelta(t, 100, res.dur.Milliseconds(), 25, "each get should within 1s") + assert.Equal(t, fmt.Sprint(i), res.val, "Value must be valid") + } + dur = time.Now().Sub(start) + assert.InDelta(t, 100, dur.Milliseconds(), 25, "all get should within 1s") + + start = time.Now() + val, _ := l.Get("1") + dur = time.Now().Sub(start) + assert.Less(t, dur.Milliseconds(), int64(50), "After cached get must be fast") + assert.Equal(t, "1", val, "Value must still the same") + + assert.Equal(t, int32(3), counter, "fetch must be called once") +} + +func TestExpire(t *testing.T) { + var counter int32 + fetch := func(key interface{}) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(10 * time.Millisecond) + return fmt.Sprintf("%d %s", counter, key), nil + } + l := New(fetch, 500*time.Millisecond, InMemoryCache()) + val, _ := l.Get("x") + assert.Equal(t, "1 x", val, "First call") + assert.Equal(t, int32(1), counter, "fetch called once") + + time.Sleep(550 * time.Millisecond) + val, _ = l.Get("x") + assert.Equal(t, "1 x", val, "Use stale value") + val, _ = l.Get("x") + assert.Equal(t, "1 x", val, "Still use stale value") + + time.Sleep(100 * time.Millisecond) + val, _ = l.Get("x") + assert.Equal(t, "2 x", val, "Use updated value") + assert.Equal(t, int32(2), counter, "fetch called twice") +} diff --git a/lru.go b/lru.go index 3c9dc9c..7412c7b 100644 --- a/lru.go +++ b/lru.go @@ -23,6 +23,9 @@ func (c *LRUCache) Remove(key interface{}) { // NewLRU creates Loader with lru based cache func NewLRU(fn LoadFunc, ttl time.Duration, size int) *Loader { - cache, _ := lru.New(size) + cache, err := lru.New(size) + if err != nil { + panic(err) + } return New(fn, ttl, &LRUCache{cache}) }