From ff979195285b7cce7904d44441e6a1cd3c6c7914 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 26 Feb 2025 13:53:33 +0200 Subject: [PATCH 1/2] Updated MapBackData to use generics, updated mapBackData tests --- module/mempool/backData.go | 6 +- module/mempool/stdmap/backdata/mapBackData.go | 170 +++++++++--------- .../stdmap/backdata/mapBackData_test.go | 89 +++++---- 3 files changed, 140 insertions(+), 125 deletions(-) diff --git a/module/mempool/backData.go b/module/mempool/backData.go index ff9a3244e96..923bceb3dce 100644 --- a/module/mempool/backData.go +++ b/module/mempool/backData.go @@ -18,21 +18,21 @@ type BackData[K comparable, V any] interface { // Adjust adjusts the value using the provided function if the key is found. // It returns the updated value along with a boolean indicating whether an update occurred. - Adjust(key K, f func(value V) V) (V, bool) + Adjust(key K, f func(value V) (K, V)) (V, bool) // AdjustWithInit adjusts the value using the provided function if the key is found. // If the key is not found, it initializes the value using the given init function and then applies the adjustment. // // Args: // - key: The key for which the value should be adjusted. - // - adjust: the function that adjusts the value. + // - adjust: the function that adjusts the value and returns the updated key and value. // - init: A function that initializes the value if the key is not present. // // Returns: // - the adjusted value. // // - a bool which indicates whether the value was adjusted. - AdjustWithInit(key K, adjust func(value V) V, init func() V) (V, bool) + AdjustWithInit(key K, adjust func(value V) (K, V), init func() V) (V, bool) // GetWithInit returns the value for the given key. // If the key does not exist, it creates a new value using the init function, stores it, and returns it. diff --git a/module/mempool/stdmap/backdata/mapBackData.go b/module/mempool/stdmap/backdata/mapBackData.go index 24c34d79eb0..fb953d25960 100644 --- a/module/mempool/stdmap/backdata/mapBackData.go +++ b/module/mempool/stdmap/backdata/mapBackData.go @@ -1,145 +1,147 @@ package backdata -import ( - "github.com/onflow/flow-go/model/flow" -) - // MapBackData implements a map-based generic memory BackData backed by a Go map. // Note that this implementation is NOT thread-safe, and the higher-level Backend is responsible for concurrency management. -type MapBackData struct { +type MapBackData[K comparable, V any] struct { // NOTE: as a BackData implementation, MapBackData must be non-blocking. // Concurrency management is done by overlay Backend. - entities map[flow.Identifier]flow.Entity + dataMap map[K]V } -func NewMapBackData() *MapBackData { - bd := &MapBackData{ - entities: make(map[flow.Identifier]flow.Entity), +func NewMapBackData[K comparable, V any]() *MapBackData[K, V] { + bd := &MapBackData[K, V]{ + dataMap: make(map[K]V), } return bd } -// Has checks if backdata already contains the entity with the given identifier. -func (b *MapBackData) Has(entityID flow.Identifier) bool { - _, exists := b.entities[entityID] +// Has checks if backdata already contains the value with the given key. +func (b *MapBackData[K, V]) Has(key K) bool { + _, exists := b.dataMap[key] return exists } -// Add adds the given entity to the backdata. -func (b *MapBackData) Add(entityID flow.Identifier, entity flow.Entity) bool { - _, exists := b.entities[entityID] +// Add adds the given value to the backdata. +func (b *MapBackData[K, V]) Add(key K, value V) bool { + _, exists := b.dataMap[key] if exists { return false } - b.entities[entityID] = entity + b.dataMap[key] = value return true } -// Remove removes the entity with the given identifier. -func (b *MapBackData) Remove(entityID flow.Identifier) (flow.Entity, bool) { - entity, exists := b.entities[entityID] +// Remove removes the value with the given key. +func (b *MapBackData[K, V]) Remove(key K) (V, bool) { + value, exists := b.dataMap[key] if !exists { - return nil, false + var zero V + return zero, false } - delete(b.entities, entityID) - return entity, true + delete(b.dataMap, key) + return value, true } -// Adjust adjusts the entity using the given function if the given identifier can be found. -// Returns a bool which indicates whether the entity was updated as well as the updated entity. -func (b *MapBackData) Adjust(entityID flow.Identifier, f func(flow.Entity) flow.Entity) (flow.Entity, bool) { - entity, ok := b.entities[entityID] +// Adjust adjusts the value using the given function if the given key can be found. +// It returns the updated value along with a boolean indicating whether an update occurred. +func (b *MapBackData[K, V]) Adjust(key K, f func(V) (K, V)) (V, bool) { + value, ok := b.dataMap[key] if !ok { - return nil, false + var zero V + return zero, false } - newentity := f(entity) - newentityID := newentity.ID() + newKey, newValue := f(value) - delete(b.entities, entityID) - b.entities[newentityID] = newentity - return newentity, true + delete(b.dataMap, key) + b.dataMap[newKey] = newValue + return newValue, true } -// AdjustWithInit adjusts the entity using the given function if the given identifier can be found. When the -// entity is not found, it initializes the entity using the given init function and then applies the adjust function. +// AdjustWithInit adjusts the value using the provided function if the key is found. +// If the key is not found, it initializes the value using the given init function and then applies the adjustment. +// // Args: -// - entityID: the identifier of the entity to adjust. -// - adjust: the function that adjusts the entity. -// - init: the function that initializes the entity when it is not found. +// - key: The key for which the value should be adjusted. +// - adjust: the function that adjusts the value. +// - init: A function that initializes the value if the key is not present. +// // Returns: -// - the adjusted entity. +// - the adjusted value. // -// - a bool which indicates whether the entity was adjusted. -func (b *MapBackData) AdjustWithInit(entityID flow.Identifier, adjust func(flow.Entity) flow.Entity, init func() flow.Entity) (flow.Entity, bool) { - if b.Has(entityID) { - return b.Adjust(entityID, adjust) +// - a bool which indicates whether the value was adjusted. +func (b *MapBackData[K, V]) AdjustWithInit(key K, adjust func(V) (K, V), init func() V) (V, bool) { + if b.Has(key) { + return b.Adjust(key, adjust) } - b.Add(entityID, init()) - return b.Adjust(entityID, adjust) + b.Add(key, init()) + return b.Adjust(key, adjust) } -// GetWithInit returns the given entity from the backdata. If the entity does not exist, it creates a new entity -// using the factory function and stores it in the backdata. +// GetWithInit returns the value for the given key. +// If the key does not exist, it creates a new value using the init function, stores it, and returns it. +// // Args: -// - entityID: the identifier of the entity to get. -// - init: the function that initializes the entity when it is not found. +// - key: The key for which the value should be retrieved. +// - init: A function that initializes the value if the key is not present. +// // Returns: -// - the entity. -// - a bool which indicates whether the entity was found (or created). -func (b *MapBackData) GetWithInit(entityID flow.Identifier, init func() flow.Entity) (flow.Entity, bool) { - if b.Has(entityID) { - return b.ByID(entityID) +// - the value. +// - a bool which indicates whether the value was found (or created). +func (b *MapBackData[K, V]) GetWithInit(key K, init func() V) (V, bool) { + if b.Has(key) { + return b.ByID(key) } - b.Add(entityID, init()) - return b.ByID(entityID) + b.Add(key, init()) + return b.ByID(key) } -// ByID returns the given entity from the backdata. -func (b *MapBackData) ByID(entityID flow.Identifier) (flow.Entity, bool) { - entity, exists := b.entities[entityID] +// ByID returns the value for the given key. +func (b *MapBackData[K, V]) ByID(key K) (V, bool) { + value, exists := b.dataMap[key] if !exists { - return nil, false + var zero V + return zero, false } - return entity, true + return value, true } -// Size returns the size of the backdata, i.e., total number of stored (entityId, entity) -func (b *MapBackData) Size() uint { - return uint(len(b.entities)) +// Size returns the number of stored key-value pairs. +func (b *MapBackData[K, V]) Size() uint { + return uint(len(b.dataMap)) } -// All returns all entities stored in the backdata. -func (b *MapBackData) All() map[flow.Identifier]flow.Entity { - entities := make(map[flow.Identifier]flow.Entity) - for entityID, entity := range b.entities { - entities[entityID] = entity +// All returns all stored key-value pairs as a map. +func (b *MapBackData[K, V]) All() map[K]V { + values := make(map[K]V) + for key, value := range b.dataMap { + values[key] = value } - return entities + return values } -// Identifiers returns the list of identifiers of entities stored in the backdata. -func (b *MapBackData) Identifiers() flow.IdentifierList { - ids := make(flow.IdentifierList, len(b.entities)) +// Identifiers returns the list of keys of values stored in the backdata. +func (b *MapBackData[K, V]) Identifiers() []K { + keys := make([]K, len(b.dataMap)) i := 0 - for entityID := range b.entities { - ids[i] = entityID + for key := range b.dataMap { + keys[i] = key i++ } - return ids + return keys } -// Entities returns the list of entities stored in the backdata. -func (b *MapBackData) Entities() []flow.Entity { - entities := make([]flow.Entity, len(b.entities)) +// Entities returns the list of values stored in the backdata. +func (b *MapBackData[K, V]) Entities() []V { + values := make([]V, len(b.dataMap)) i := 0 - for _, entity := range b.entities { - entities[i] = entity + for _, value := range b.dataMap { + values[i] = value i++ } - return entities + return values } -// Clear removes all entities from the backdata. -func (b *MapBackData) Clear() { - b.entities = make(map[flow.Identifier]flow.Entity) +// Clear removes all values from the backdata. +func (b *MapBackData[K, V]) Clear() { + b.dataMap = make(map[K]V) } diff --git a/module/mempool/stdmap/backdata/mapBackData_test.go b/module/mempool/stdmap/backdata/mapBackData_test.go index a5a13b70bb8..3081c514a65 100644 --- a/module/mempool/stdmap/backdata/mapBackData_test.go +++ b/module/mempool/stdmap/backdata/mapBackData_test.go @@ -10,13 +10,13 @@ import ( ) func TestMapBackData_StoreAnd(t *testing.T) { - backData := NewMapBackData() + backData := NewMapBackData[flow.Identifier, unittest.MockEntity]() entities := unittest.EntityListFixture(100) // Add for _, e := range entities { // all entities must be stored successfully - require.True(t, backData.Add(e.ID(), e)) + require.True(t, backData.Add(e.ID(), *e)) } // ByID @@ -24,7 +24,7 @@ func TestMapBackData_StoreAnd(t *testing.T) { // all entities must be retrievable successfully actual, ok := backData.ByID(expected.ID()) require.True(t, ok) - require.Equal(t, expected, actual) + require.Equal(t, expected, &actual) } // All @@ -33,7 +33,7 @@ func TestMapBackData_StoreAnd(t *testing.T) { for _, expected := range entities { actual, ok := backData.ByID(expected.ID()) require.True(t, ok) - require.Equal(t, expected, actual) + require.Equal(t, expected, &actual) } // Identifiers @@ -44,30 +44,32 @@ func TestMapBackData_StoreAnd(t *testing.T) { } // Entities - actualEntities := backData.Entities() - require.Equal(t, len(entities), len(actualEntities)) - require.ElementsMatch(t, entities, actualEntities) + requireEntitiesMatch(t, entities, backData.Entities()) } // TestMapBackData_AdjustWithInit tests the AdjustWithInit method of the MapBackData. // Note that as the backdata is not inherently thread-safe, this test is not concurrent. func TestMapBackData_AdjustWithInit(t *testing.T) { - backData := NewMapBackData() + backData := NewMapBackData[flow.Identifier, unittest.MockEntity]() entities := unittest.EntityListFixture(100) ids := flow.GetIDs(entities) // AdjustWithInit for _, e := range entities { // all entities must be adjusted successfully - actual, ok := backData.AdjustWithInit(e.ID(), func(entity flow.Entity) flow.Entity { + actual, ok := backData.AdjustWithInit(e.ID(), func(entity unittest.MockEntity) (flow.Identifier, unittest.MockEntity) { // increment nonce of the entity - entity.(*unittest.MockEntity).Nonce++ - return entity - }, func() flow.Entity { - return e + entity.Nonce++ + return entity.ID(), entity + }, func() unittest.MockEntity { + return *e }) require.True(t, ok) - require.Equal(t, e, actual) + + // Manually update e to reflect the expected change + e.Nonce++ + + require.Equal(t, e, &actual) } // All @@ -77,7 +79,7 @@ func TestMapBackData_AdjustWithInit(t *testing.T) { actual, ok := backData.ByID(expected.ID()) require.True(t, ok) require.Equal(t, expected.ID(), actual.ID()) - require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce) + require.Equal(t, uint64(1), actual.Nonce) } // Identifiers @@ -89,9 +91,7 @@ func TestMapBackData_AdjustWithInit(t *testing.T) { } // Entities - actualEntities := backData.Entities() - require.Equal(t, len(entities), len(actualEntities)) - require.ElementsMatch(t, entities, actualEntities) + requireEntitiesMatch(t, entities, backData.Entities()) // ByID for _, e := range entities { @@ -99,36 +99,36 @@ func TestMapBackData_AdjustWithInit(t *testing.T) { actual, ok := backData.ByID(e.ID()) require.True(t, ok) require.Equal(t, e.ID(), actual.ID()) - require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce) + require.Equal(t, uint64(1), actual.Nonce) } // GetWithInit for _, e := range entities { // all entities must be retrieved successfully - actual, ok := backData.GetWithInit(e.ID(), func() flow.Entity { + actual, ok := backData.GetWithInit(e.ID(), func() unittest.MockEntity { require.Fail(t, "should not be called") // entity has already been initialized - return e + return *e }) require.True(t, ok) require.Equal(t, e.ID(), actual.ID()) - require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce) + require.Equal(t, uint64(1), actual.Nonce) } } // TestMapBackData_GetWithInit tests the GetWithInit method of the MapBackData. // Note that as the backdata is not inherently thread-safe, this test is not concurrent. func TestMapBackData_GetWithInit(t *testing.T) { - backData := NewMapBackData() + backData := NewMapBackData[flow.Identifier, unittest.MockEntity]() entities := unittest.EntityListFixture(100) // GetWithInit for _, e := range entities { // all entities must be initialized retrieved successfully - actual, ok := backData.GetWithInit(e.ID(), func() flow.Entity { - return e // initialize with the entity + actual, ok := backData.GetWithInit(e.ID(), func() unittest.MockEntity { + return *e // initialize with the entity }) require.True(t, ok) - require.Equal(t, e, actual) + require.Equal(t, e, &actual) } // All @@ -137,7 +137,7 @@ func TestMapBackData_GetWithInit(t *testing.T) { for _, expected := range entities { actual, ok := backData.ByID(expected.ID()) require.True(t, ok) - require.Equal(t, expected, actual) + require.Equal(t, expected, &actual) } // Identifiers @@ -148,20 +148,22 @@ func TestMapBackData_GetWithInit(t *testing.T) { } // Entities - actualEntities := backData.Entities() - require.Equal(t, len(entities), len(actualEntities)) - require.ElementsMatch(t, entities, actualEntities) + requireEntitiesMatch(t, entities, backData.Entities()) // Adjust for _, e := range entities { // all entities must be adjusted successfully - actual, ok := backData.Adjust(e.ID(), func(entity flow.Entity) flow.Entity { + actual, ok := backData.Adjust(e.ID(), func(entity unittest.MockEntity) (flow.Identifier, unittest.MockEntity) { // increment nonce of the entity - entity.(*unittest.MockEntity).Nonce++ - return entity + entity.Nonce++ + return entity.ID(), entity }) require.True(t, ok) - require.Equal(t, e, actual) + + // Manually update e to reflect the expected change + e.Nonce++ + + require.Equal(t, e, &actual) } // ByID; should return the latest version of the entity @@ -170,18 +172,29 @@ func TestMapBackData_GetWithInit(t *testing.T) { actual, ok := backData.ByID(e.ID()) require.True(t, ok) require.Equal(t, e.ID(), actual.ID()) - require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce) + require.Equal(t, uint64(1), actual.Nonce) } // GetWithInit; should return the latest version of the entity, than increment the nonce for _, e := range entities { // all entities must be retrieved successfully - actual, ok := backData.GetWithInit(e.ID(), func() flow.Entity { + actual, ok := backData.GetWithInit(e.ID(), func() unittest.MockEntity { require.Fail(t, "should not be called") // entity has already been initialized - return e + return *e }) require.True(t, ok) require.Equal(t, e.ID(), actual.ID()) - require.Equal(t, uint64(1), actual.(*unittest.MockEntity).Nonce) + require.Equal(t, uint64(1), actual.Nonce) + } +} + +// requireEntitiesMatch is a helper function to validate entity lists +func requireEntitiesMatch(t *testing.T, expected []*unittest.MockEntity, actualEntities []unittest.MockEntity) { + require.Equal(t, len(expected), len(actualEntities)) + + expectedEntities := make([]unittest.MockEntity, len(expected)) + for i, e := range expected { + expectedEntities[i] = *e } + require.ElementsMatch(t, expectedEntities, actualEntities) } From e5aa42e1d8f7ac219f660b68943a1aaf3cc8ebdc Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 26 Feb 2025 20:14:14 +0200 Subject: [PATCH 2/2] Updated MapBackData usages in Backend --- module/mempool/mutable_back_data.go | 4 ++-- module/mempool/stdmap/backend.go | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/module/mempool/mutable_back_data.go b/module/mempool/mutable_back_data.go index 7de7087b953..e04560aa868 100644 --- a/module/mempool/mutable_back_data.go +++ b/module/mempool/mutable_back_data.go @@ -11,7 +11,7 @@ type MutableBackData[K comparable, V any] interface { // Adjust adjusts the value using the given function if the given key can be found. // Returns a bool which indicates whether the value was updated as well as the updated value. - Adjust(key K, f func(value V) V) (V, bool) + Adjust(key K, f func(value V) (K, V)) (V, bool) // AdjustWithInit adjusts the value using the given function if the given key can be found. When the // value is not found, it initializes the value using the given init function and then applies the adjust function. @@ -23,5 +23,5 @@ type MutableBackData[K comparable, V any] interface { // - the adjusted value. // // - a bool which indicates whether the value was adjusted. - AdjustWithInit(key K, adjust func(value V) V, init func() V) (V, bool) + AdjustWithInit(key K, adjust func(value V) (K, V), init func() V) (V, bool) } diff --git a/module/mempool/stdmap/backend.go b/module/mempool/stdmap/backend.go index 54d78be4e7f..3d7cecb7cc9 100644 --- a/module/mempool/stdmap/backend.go +++ b/module/mempool/stdmap/backend.go @@ -13,7 +13,7 @@ import ( // Backend is a wrapper around the mutable backdata that provides concurrency-safe operations. type Backend struct { sync.RWMutex - mutableBackData mempool.MutableBackData + mutableBackData mempool.MutableBackData[flow.Identifier, flow.Entity] guaranteedCapacity uint batchEject BatchEjectFunc eject EjectFunc @@ -24,7 +24,7 @@ type Backend struct { // This is using EjectRandomFast() func NewBackend(options ...OptionFunc) *Backend { b := Backend{ - mutableBackData: backdata.NewMapBackData(), + mutableBackData: backdata.NewMapBackData[flow.Identifier, flow.Entity](), guaranteedCapacity: uint(math.MaxUint32), batchEject: EjectRandomFast, eject: nil, @@ -82,7 +82,7 @@ func (b *Backend) Remove(entityID flow.Identifier) bool { // Adjust will adjust the value item using the given function if the given key can be found. // Returns a bool which indicates whether the value was updated. -func (b *Backend) Adjust(entityID flow.Identifier, f func(flow.Entity) flow.Entity) (flow.Entity, bool) { +func (b *Backend) Adjust(entityID flow.Identifier, f func(flow.Entity) (flow.Identifier, flow.Entity)) (flow.Entity, bool) { // bs1 := binstat.EnterTime(binstat.BinStdmap + ".w_lock.(Backend)Adjust") b.Lock() // binstat.Leave(bs1) @@ -120,7 +120,7 @@ func (b *Backend) GetWithInit(entityID flow.Identifier, init func() flow.Entity) // - the adjusted entity. // // - a bool which indicates whether the entity was adjusted. -func (b *Backend) AdjustWithInit(entityID flow.Identifier, adjust func(flow.Entity) flow.Entity, init func() flow.Entity) (flow.Entity, bool) { +func (b *Backend) AdjustWithInit(entityID flow.Identifier, adjust func(flow.Entity) (flow.Identifier, flow.Entity), init func() flow.Entity) (flow.Entity, bool) { b.Lock() defer b.Unlock() @@ -141,7 +141,7 @@ func (b *Backend) ByID(entityID flow.Identifier) (flow.Entity, bool) { } // Run executes a function giving it exclusive access to the backdata -func (b *Backend) Run(f func(backdata mempool.BackData) error) error { +func (b *Backend) Run(f func(backdata mempool.BackData[flow.Identifier, flow.Entity]) error) error { // bs1 := binstat.EnterTime(binstat.BinStdmap + ".w_lock.(Backend)Run") b.Lock() // binstat.Leave(bs1)