@@ -26,6 +26,54 @@ class engine;
26
26
using primitive_id = std::string;
27
27
using memory_ptr = std::shared_ptr<memory>;
28
28
29
+ template <typename Key, typename Hash = std::hash<Key>, typename KeyEqual = std::equal_to<Key>>
30
+ class memory_restricter {
31
+ private:
32
+ const std::unordered_set<Key, Hash, KeyEqual>& set1; // Const reference to external set
33
+ std::unordered_set<Key, Hash, KeyEqual> set2; // Internal modifiable set
34
+ static std::unordered_set<Key, Hash, KeyEqual> empty_set; // Static empty set for default
35
+
36
+ public:
37
+ // Default constructor initializes set1 with an empty set
38
+ memory_restricter ()
39
+ : set1(empty_set) {}
40
+
41
+ // Constructor to initialize with a const reference for set1
42
+ explicit memory_restricter (const std::unordered_set<Key, Hash, KeyEqual>& externalSet)
43
+ : set1(externalSet) {}
44
+
45
+ // Insert into set2 (set1 is read-only)
46
+ void insert (const Key& key) {
47
+ if (set1.find (key) == set1.end ())
48
+ set2.insert (key);
49
+ }
50
+
51
+ // Check existence in either set
52
+ bool contains (const Key& key) const {
53
+ return set1.find (key) != set1.end () || set2.find (key) != set2.end ();
54
+ }
55
+
56
+ // Total size of both sets
57
+ size_t size () const {
58
+ return set1.size () + set2.size ();
59
+ }
60
+
61
+ // Check if both sets are empty
62
+ bool empty () const {
63
+ return set1.empty () && set2.empty ();
64
+ }
65
+
66
+ // Iterate over both sets
67
+ void for_each (void (*func)(const Key&)) const {
68
+ for (const auto & key : set1) func (key);
69
+ for (const auto & key : set2) func (key);
70
+ }
71
+ }; // end of memory_restricter
72
+
73
+ // Define the static empty_set
74
+ template <typename Key, typename Hash, typename KeyEqual>
75
+ std::unordered_set<Key, Hash, KeyEqual> memory_restricter<Key, Hash, KeyEqual>::empty_set(0 ); // minimize its memory usage
76
+
29
77
struct memory_user {
30
78
size_t _unique_id;
31
79
uint32_t _network_id;
@@ -112,7 +160,7 @@ struct padded_pool_comparer {
112
160
113
161
class memory_pool {
114
162
memory_ptr alloc_memory (const layout& layout, allocation_type type, bool reset = true );
115
- static bool has_conflict (const memory_set&, const std::unordered_set <uint32_t >&, uint32_t network_id);
163
+ static bool has_conflict (const memory_set&, const memory_restricter <uint32_t >&, uint32_t network_id);
116
164
117
165
std::multimap<uint64_t , memory_record> _non_padded_pool;
118
166
std::map<layout, std::list<memory_record>, padded_pool_comparer> _padded_pool;
@@ -127,7 +175,7 @@ class memory_pool {
127
175
const primitive_id& id,
128
176
size_t unique_id,
129
177
uint32_t network_id,
130
- const std::unordered_set <uint32_t >& restrictions,
178
+ const memory_restricter <uint32_t >& restrictions,
131
179
allocation_type type,
132
180
bool reusable = true ,
133
181
bool reset = true ,
@@ -137,21 +185,16 @@ class memory_pool {
137
185
const primitive_id& prim_id,
138
186
size_t unique_id,
139
187
uint32_t network_id,
140
- const std::unordered_set <uint32_t >&,
188
+ const memory_restricter <uint32_t >&,
141
189
allocation_type type,
142
190
bool reset = true ,
143
191
bool is_dynamic = false );
144
192
memory_ptr get_from_padded_pool (const layout& layout,
145
193
const primitive_id& prim_id,
146
194
size_t unique_id,
147
195
uint32_t network_id,
148
- const std::unordered_set <uint32_t >& restrictions,
196
+ const memory_restricter <uint32_t >& restrictions,
149
197
allocation_type type);
150
- memory_ptr get_from_across_networks_pool (const layout& layout,
151
- const primitive_id& id,
152
- size_t unique_id,
153
- uint32_t network_id,
154
- allocation_type type);
155
198
void clear_pool_for_network (uint32_t network_id);
156
199
void release_memory (memory* memory, const size_t & unique_id, primitive_id prim_id, uint32_t network_id);
157
200
0 commit comments