Skip to content

Commit 6dc12cf

Browse files
committed
#1934: Update unit tests to check all relevant fields in NodeLBData
1 parent 0eee330 commit 6dc12cf

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

src/vt/vrt/collection/balance/node_lb_data.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ NodeLBData::getUserData() const {
103103
return &lb_data_->user_defined_lb_info_;
104104
}
105105

106-
std::unordered_map<PhaseType, DataMapType> const*
106+
DataMapBufferType const*
107107
NodeLBData::getPhaseAttributes() const {
108108
return &lb_data_->node_user_attributes_;
109109
}
@@ -147,6 +147,7 @@ void NodeLBData::resizeLBDataHistory(uint32_t new_hist_len) {
147147
lb_data_->node_subphase_comm_.resize(new_hist_len);
148148
lb_data_->user_defined_lb_info_.resize(new_hist_len);
149149
lb_data_->user_defined_json_.resize(new_hist_len);
150+
lb_data_->node_user_attributes_.resize(new_hist_len);
150151
}
151152

152153
NodeLBData::node_migrate_.clear();

src/vt/vrt/collection/balance/node_lb_data.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ struct NodeLBData : runtime::component::Component<NodeLBData> {
194194
*
195195
* \return an observer pointer to the user-defined attributes
196196
*/
197-
std::unordered_map<PhaseType, DataMapType> const* getPhaseAttributes() const;
197+
DataMapBufferType const* getPhaseAttributes() const;
198198

199199
/**
200200
* \internal \brief Get stored object comm data for a specific phase

tests/unit/collection/test_lb_data_holder.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ TEST_F(TestLBDataHolder, test_lb_entity_attributes) {
271271
auto id = vt::vrt::collection::balance::ElementIDStruct{524291, 0};
272272

273273
LBDataHolder testObj(json);
274-
EXPECT_TRUE(testObj.node_user_attributes_.find(0) != testObj.node_user_attributes_.end());
274+
EXPECT_TRUE(testObj.node_user_attributes_.contains(0));
275275
EXPECT_TRUE(testObj.node_user_attributes_[0].find(id) != testObj.node_user_attributes_[0].end());
276276
auto attributes = testObj.node_user_attributes_[0][id];
277277
EXPECT_EQ(123, std::get<int>(attributes["intSample"]));

tests/unit/collection/test_lb_data_retention.cc

+17-8
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ void validatePersistedPhases(std::vector<PhaseType> expected_phases) {
6565
EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->node_data_.size());
6666
EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->node_subphase_comm_.size());
6767
EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->user_defined_json_.size());
68-
// EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->user_defined_lb_info_.size());
68+
EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->user_defined_lb_info_.size());
69+
EXPECT_EQ(expected_phases.size(), theNodeLBData()->getLBData()->node_user_attributes_.size());
6970
// Check if each phase is present
7071
for(auto&& phase : expected_phases) {
7172
EXPECT_TRUE(theNodeLBData()->getLBData()->node_comm_.contains(phase));
7273
EXPECT_TRUE(theNodeLBData()->getLBData()->node_data_.contains(phase));
7374
EXPECT_TRUE(theNodeLBData()->getLBData()->node_subphase_comm_.contains(phase));
7475
EXPECT_TRUE(theNodeLBData()->getLBData()->user_defined_json_.contains(phase));
75-
// EXPECT_TRUE(theNodeLBData()->getLBData()->user_defined_lb_info_.contains(phase));
76+
EXPECT_TRUE(theNodeLBData()->getLBData()->user_defined_lb_info_.contains(phase));
77+
EXPECT_TRUE(theNodeLBData()->getLBData()->node_user_attributes_.contains(phase));
7678
}
7779
#else
7880
(void)expected_phases;
@@ -81,21 +83,20 @@ void validatePersistedPhases(std::vector<PhaseType> expected_phases) {
8183
EXPECT_EQ(0, theNodeLBData()->getLBData()->node_subphase_comm_.size());
8284
EXPECT_EQ(0, theNodeLBData()->getLBData()->user_defined_json_.size());
8385
EXPECT_EQ(0, theNodeLBData()->getLBData()->user_defined_lb_info_.size());
86+
EXPECT_EQ(0, theNodeLBData()->getLBData()->node_user_attributes_.size());
8487
#endif
8588
}
8689

8790
struct TestCol : vt::Collection<TestCol,vt::Index1D> {
8891
unsigned int prev_calls_ = thePhase()->getCurrentPhase();
8992

90-
TestCol() {
91-
// Insert dummy lb info data
92-
valInsert("foo", 10, true, true, true);
93-
}
94-
9593
unsigned int prevCalls() { return prev_calls_++; }
9694

97-
static void colHandler(TestCol* col) {
95+
static void insertValue(TestCol* col) {
96+
col->valInsert("foo", 10, true, true, true);
97+
}
9898

99+
static void colHandler(TestCol* col) {
99100
auto& lb_data = col->lb_data_;
100101
auto load_phase_count = lb_data.getLoadPhaseCount();
101102
auto comm_phase_count = lb_data.getCommPhaseCount();
@@ -175,6 +176,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_last1) {
175176
for (int i=0; i<num_phases; ++i) {
176177
runInEpochCollective([&]{
177178
// Do some work.
179+
proxy.broadcastCollective<TestCol::insertValue>();
178180
proxy.broadcastCollective<TestCol::colHandler>();
179181
});
180182
// Go to the next phase.
@@ -214,6 +216,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_last2) {
214216
for (int i=0; i<num_phases; ++i) {
215217
runInEpochCollective([&]{
216218
// Do some work.
219+
proxy.broadcastCollective<TestCol::insertValue>();
217220
proxy.broadcastCollective<TestCol::colHandler>();
218221
});
219222
// Go to the next phase.
@@ -253,6 +256,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_last4) {
253256
for (int i=0; i<num_phases; ++i) {
254257
runInEpochCollective([&]{
255258
// Do some work.
259+
proxy.broadcastCollective<TestCol::insertValue>();
256260
proxy.broadcastCollective<TestCol::colHandler>();
257261
});
258262
// Go to the next phase.
@@ -295,6 +299,7 @@ TEST_F(TestLBDataRetention, test_lbdata_config_retention_higher) {
295299
for (uint32_t i=0; i<theConfig()->vt_lb_data_retention * 2; ++i) {
296300
runInEpochCollective([&]{
297301
// Do some work.
302+
proxy.broadcastCollective<TestCol::insertValue>();
298303
proxy.broadcastCollective<TestCol::colHandler>();
299304
});
300305
// Go to the next phase.
@@ -336,6 +341,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_1) {
336341
for (uint32_t i=0; i<first_stage_num_phases; ++i) {
337342
runInEpochCollective([&]{
338343
// Do some work.
344+
proxy.broadcastCollective<TestCol::insertValue>();
339345
proxy.broadcastCollective<TestCol::colHandler>();
340346
});
341347
// Go to the next phase.
@@ -354,6 +360,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_1) {
354360
for (uint32_t i=0; i<first_stage_num_phases; ++i) {
355361
runInEpochCollective([&]{
356362
// Do some work.
363+
proxy.broadcastCollective<TestCol::insertValue>();
357364
proxy.broadcastCollective<TestCol::colHandler>();
358365
});
359366
// Go to the next phase.
@@ -393,6 +400,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_2) {
393400
for (uint32_t i=0; i<first_stage_num_phases; ++i) {
394401
runInEpochCollective([&]{
395402
// Do some work.
403+
proxy.broadcastCollective<TestCol::insertValue>();
396404
proxy.broadcastCollective<TestCol::colHandler>();
397405
});
398406
// Go to the next phase.
@@ -416,6 +424,7 @@ TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_2) {
416424
for (uint32_t i=0; i<10; ++i) {
417425
runInEpochCollective([&]{
418426
// Do some work.
427+
proxy.broadcastCollective<TestCol::insertValue>();
419428
proxy.broadcastCollective<TestCol::colHandler>();
420429
});
421430
// Go to the next phase.

0 commit comments

Comments
 (0)