Skip to content

Commit a77ec7e

Browse files
committed
#2201: temperedlb: implement basic memory information consumption, threshold variable for user
1 parent 1765bd5 commit a77ec7e

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

src/vt/vrt/collection/balance/temperedlb/temperedlb.cc

+112
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,15 @@ Default: false
274274
instead of the processor-average load.
275275
)"
276276
},
277+
{
278+
"memory_threshold",
279+
R"(
280+
Values: <double>
281+
Default: 0
282+
Description: The memory threshold TemperedLB should strictly stay under which is
283+
respected if memory information is present in the user-defined data.
284+
)"
285+
}
277286
};
278287
return keys_help;
279288
}
@@ -378,6 +387,7 @@ void TemperedLB::inputParams(balance::ConfigEntry* config) {
378387
deterministic_ = config->getOrDefault<bool>("deterministic", deterministic_);
379388
rollback_ = config->getOrDefault<bool>("rollback", rollback_);
380389
target_pole_ = config->getOrDefault<bool>("targetpole", target_pole_);
390+
mem_thresh_ = config->getOrDefault<double>("memory_threshold", mem_thresh_);
381391

382392
balance::LBArgsEnumConverter<CriterionEnum> criterion_converter_(
383393
"criterion", "CriterionEnum", {
@@ -509,6 +519,98 @@ void TemperedLB::runLB(LoadType total_load) {
509519
}
510520
}
511521

522+
void TemperedLB::readClustersMemoryData() {
523+
if (user_data_) {
524+
for (auto const& [obj, data_map] : *user_data_) {
525+
SharedIDType shared_id = -1;
526+
BytesType shared_bytes = 0;
527+
BytesType working_bytes = 0;
528+
for (auto const& [key, variant] : data_map) {
529+
if (key == "shared_id") {
530+
// Because of how JSON is stored this is always a double, even though
531+
// it should be an integer
532+
if (double const* val = std::get_if<double>(&variant)) {
533+
shared_id = static_cast<int>(*val);
534+
} else {
535+
vtAbort("\"shared_id\" in variant does not match integer");
536+
}
537+
}
538+
if (key == "shared_bytes") {
539+
if (BytesType const* val = std::get_if<BytesType>(&variant)) {
540+
shared_bytes = *val;
541+
} else {
542+
vtAbort("\"shared_bytes\" in variant does not match double");
543+
}
544+
}
545+
if (key == "task_working_bytes") {
546+
if (BytesType const* val = std::get_if<BytesType>(&variant)) {
547+
working_bytes = *val;
548+
} else {
549+
vtAbort("\"working_bytes\" in variant does not match double");
550+
}
551+
}
552+
if (key == "rank_working_bytes") {
553+
if (BytesType const* val = std::get_if<BytesType>(&variant)) {
554+
rank_bytes_ = *val;
555+
} else {
556+
vtAbort("\"rank_bytes\" in variant does not match double");
557+
}
558+
}
559+
// @todo: for now, skip "task_serialized_bytes" and
560+
// "task_footprint_bytes"
561+
}
562+
563+
// @todo: switch to debug print at some point
564+
vt_print(
565+
temperedlb, "obj={} shared_block={} bytes={}\n",
566+
obj, shared_id, shared_bytes
567+
);
568+
569+
obj_shared_block_[obj] = shared_id;
570+
obj_working_bytes_[obj] = working_bytes;
571+
shared_block_size_[shared_id] = shared_bytes;
572+
has_memory_data_ = true;
573+
}
574+
}
575+
}
576+
577+
/**
 * \brief Compute the memory usage for the current object assignment
 *
 * The result is the sum of the rank working bytes (\c rank_bytes_), the sizes
 * of all shared blocks returned by \c getSharedBlocksHere(), and the largest
 * working-bytes value among the objects in \c cur_objs_.
 *
 * A shared block without a recorded size, or an object without recorded
 * working bytes, contributes nothing and triggers a warning print.
 *
 * \return the total memory usage
 */
TemperedLB::BytesType TemperedLB::computeMemoryUsage() const {
  // Compute bytes used by shared blocks mapped here based on object mapping
  auto const blocks_here = getSharedBlocksHere();

  BytesType total_shared_bytes = 0;
  for (auto const& block_id : blocks_here) {
    auto const size_iter = shared_block_size_.find(block_id);
    if (size_iter != shared_block_size_.end()) {
      total_shared_bytes += size_iter->second;
    } else {
      // Never dereference end(): an unknown block is reported and skipped
      vt_print(
        temperedlb, "Warning: size not found for shared block: {}\n", block_id
      );
    }
  }

  // Compute max object size
  // @todo: Slight issue here that this will only count migratable objects
  // (those contained in cur_objs), for our current use case this is not a
  // problem, but it should include the max of non-migratable
  BytesType max_object_working_bytes = 0;
  for (auto const& [obj_id, _] : cur_objs_) {
    auto const bytes_iter = obj_working_bytes_.find(obj_id);
    if (bytes_iter != obj_working_bytes_.end()) {
      max_object_working_bytes =
        std::max(max_object_working_bytes, bytes_iter->second);
    } else {
      vt_print(
        temperedlb, "Warning: working bytes not found for object: {}\n", obj_id
      );
    }
  }

  return rank_bytes_ + total_shared_bytes + max_object_working_bytes;
}
603+
604+
std::set<TemperedLB::SharedIDType> TemperedLB::getSharedBlocksHere() const {
605+
std::set<SharedIDType> blocks_here;
606+
for (auto const& [obj, _] : cur_objs_) {
607+
if (obj_shared_block_.find(obj) != obj_shared_block_.end()) {
608+
blocks_here.insert(obj_shared_block_.find(obj)->second);
609+
}
610+
}
611+
return blocks_here;
612+
}
613+
512614
void TemperedLB::doLBStages(LoadType start_imb) {
513615
decltype(this->cur_objs_) best_objs;
514616
LoadType best_load = 0;
@@ -517,6 +619,9 @@ void TemperedLB::doLBStages(LoadType start_imb) {
517619

518620
auto this_node = theContext()->getNode();
519621

622+
// Read in memory information if it's available before we do any trials
623+
readClustersMemoryData();
624+
520625
for (trial_ = 0; trial_ < num_trials_; ++trial_) {
521626
// Clear out data structures
522627
selected_.clear();
@@ -554,6 +659,13 @@ void TemperedLB::doLBStages(LoadType start_imb) {
554659
LoadType(this_new_load_)
555660
);
556661

662+
vt_print(
663+
temperedlb,
664+
"Current memory info: total memory usage={}, shared blocks here={}, "
665+
"memory_threshold={}\n", computeMemoryUsage(), getSharedBlocksHere().size(),
666+
mem_thresh_
667+
);
668+
557669
if (isOverloaded(this_new_load_)) {
558670
is_overloaded_ = true;
559671
} else if (isUnderloaded(this_new_load_)) {

src/vt/vrt/collection/balance/temperedlb/temperedlb.h

+41
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ struct TemperedLB : BaseLB {
6868
using ReduceMsgType = vt::collective::ReduceNoneMsg;
6969
using QuantityType = std::map<lb::StatisticQuantity, double>;
7070
using StatisticMapType = std::unordered_map<lb::Statistic, QuantityType>;
71+
using SharedIDType = int;
72+
using BytesType = double;
7173

7274
TemperedLB() = default;
7375
TemperedLB(TemperedLB const&) = delete;
@@ -120,6 +122,27 @@ struct TemperedLB : BaseLB {
120122

121123
void setupDone();
122124

125+
/**
126+
* \brief Read the memory data from the user-defined json blocks into data
127+
* structures
128+
*/
129+
void readClustersMemoryData();
130+
131+
/**
132+
* \brief Compute the memory usage for current assignment
133+
*
134+
* \return the total memory usage
135+
*/
136+
BytesType computeMemoryUsage() const;
137+
138+
/**
139+
* \brief Get the shared blocks that are located on this node with the current
140+
* object assignment
141+
*
142+
* \return the set of shared block IDs located here
143+
*/
144+
std::set<SharedIDType> getSharedBlocksHere() const;
145+
123146
private:
124147
uint16_t f_ = 0;
125148
uint8_t k_max_ = 0;
@@ -184,6 +207,24 @@ struct TemperedLB : BaseLB {
184207
std::mt19937 gen_sample_;
185208
StatisticMapType stats;
186209
LoadType this_load = 0.0f;
210+
211+
212+
//////////////////////////////////////////////////////////////////////////////
213+
// All the memory info (may or may not be present)
214+
//////////////////////////////////////////////////////////////////////////////
215+
216+
/// Whether we have memory information
217+
bool has_memory_data_ = false;
218+
/// Working bytes for this rank
219+
BytesType rank_bytes_ = 0;
220+
/// Shared ID for each object
221+
std::unordered_map<ObjIDType, SharedIDType> obj_shared_block_;
222+
/// Shared block size in bytes
223+
std::unordered_map<SharedIDType, BytesType> shared_block_size_;
224+
/// Working bytes for each object
225+
std::unordered_map<ObjIDType, BytesType> obj_working_bytes_;
226+
/// User-defined memory threshold
227+
BytesType mem_thresh_ = 0;
187228
};
188229

189230
}}}} /* end namespace vt::vrt::collection::lb */

0 commit comments

Comments
 (0)