@@ -274,6 +274,15 @@ Default: false
274
274
instead of the processor-average load.
275
275
)"
276
276
},
277
+ {
278
+ " memory_threshold" ,
279
+ R"(
280
+ Values: <double>
281
+ Default: 0
282
+ Description: The memory threshold TemperedLB should strictly stay under, which is
283
+ respected if memory information is present in the user-defined data.
284
+ )"
285
+ }
277
286
};
278
287
return keys_help;
279
288
}
@@ -378,6 +387,7 @@ void TemperedLB::inputParams(balance::ConfigEntry* config) {
378
387
deterministic_ = config->getOrDefault <bool >(" deterministic" , deterministic_);
379
388
rollback_ = config->getOrDefault <bool >(" rollback" , rollback_);
380
389
target_pole_ = config->getOrDefault <bool >(" targetpole" , target_pole_);
390
+ mem_thresh_ = config->getOrDefault <double >(" memory_threshold" , mem_thresh_);
381
391
382
392
balance::LBArgsEnumConverter<CriterionEnum> criterion_converter_ (
383
393
" criterion" , " CriterionEnum" , {
@@ -509,6 +519,98 @@ void TemperedLB::runLB(LoadType total_load) {
509
519
}
510
520
}
511
521
522
+ void TemperedLB::readClustersMemoryData () {
523
+ if (user_data_) {
524
+ for (auto const & [obj, data_map] : *user_data_) {
525
+ SharedIDType shared_id = -1 ;
526
+ BytesType shared_bytes = 0 ;
527
+ BytesType working_bytes = 0 ;
528
+ for (auto const & [key, variant] : data_map) {
529
+ if (key == " shared_id" ) {
530
+ // Because of how JSON is stored this is always a double, even though
531
+ // it should be an integer
532
+ if (double const * val = std::get_if<double >(&variant)) {
533
+ shared_id = static_cast <int >(*val);
534
+ } else {
535
+ vtAbort (" \" shared_id\" in variant does not match integer" );
536
+ }
537
+ }
538
+ if (key == " shared_bytes" ) {
539
+ if (BytesType const * val = std::get_if<BytesType>(&variant)) {
540
+ shared_bytes = *val;
541
+ } else {
542
+ vtAbort (" \" shared_bytes\" in variant does not match double" );
543
+ }
544
+ }
545
+ if (key == " task_working_bytes" ) {
546
+ if (BytesType const * val = std::get_if<BytesType>(&variant)) {
547
+ working_bytes = *val;
548
+ } else {
549
+ vtAbort (" \" working_bytes\" in variant does not match double" );
550
+ }
551
+ }
552
+ if (key == " rank_working_bytes" ) {
553
+ if (BytesType const * val = std::get_if<BytesType>(&variant)) {
554
+ rank_bytes_ = *val;
555
+ } else {
556
+ vtAbort (" \" rank_bytes\" in variant does not match double" );
557
+ }
558
+ }
559
+ // @todo: for now, skip "task_serialized_bytes" and
560
+ // "task_footprint_bytes"
561
+ }
562
+
563
+ // @todo: switch to debug print at some point
564
+ vt_print (
565
+ temperedlb, " obj={} shared_block={} bytes={}\n " ,
566
+ obj, shared_id, shared_bytes
567
+ );
568
+
569
+ obj_shared_block_[obj] = shared_id;
570
+ obj_working_bytes_[obj] = working_bytes;
571
+ shared_block_size_[shared_id] = shared_bytes;
572
+ has_memory_data_ = true ;
573
+ }
574
+ }
575
+ }
576
+
577
+ TemperedLB::BytesType TemperedLB::computeMemoryUsage () const {
578
+ // Compute bytes used by shared blocks mapped here based on object mapping
579
+ auto const blocks_here = getSharedBlocksHere ();
580
+
581
+ double total_shared_bytes = 0 ;
582
+ for (auto const & block_id : blocks_here) {
583
+ total_shared_bytes += shared_block_size_.find (block_id)->second ;
584
+ }
585
+
586
+ // Compute max object size
587
+ // @todo: Slight issue here that this will only count migratable objects
588
+ // (those contained in cur_objs), for our current use case this is not a
589
+ // problem, but it should include the max of non-migratable
590
+ double max_object_working_bytes = 0 ;
591
+ for (auto const & [obj_id, _] : cur_objs_) {
592
+ if (obj_working_bytes_.find (obj_id) != obj_working_bytes_.end ()) {
593
+ max_object_working_bytes =
594
+ std::max (max_object_working_bytes, obj_working_bytes_.find (obj_id)->second );
595
+ } else {
596
+ vt_print (
597
+ temperedlb, " Warning: working bytes not found for object: {}\n " , obj_id
598
+ );
599
+ }
600
+ }
601
+ return rank_bytes_ + total_shared_bytes + max_object_working_bytes;
602
+ }
603
+
604
+ std::set<TemperedLB::SharedIDType> TemperedLB::getSharedBlocksHere () const {
605
+ std::set<SharedIDType> blocks_here;
606
+ for (auto const & [obj, _] : cur_objs_) {
607
+ if (obj_shared_block_.find (obj) != obj_shared_block_.end ()) {
608
+ blocks_here.insert (obj_shared_block_.find (obj)->second );
609
+ }
610
+ }
611
+ return blocks_here;
612
+ }
613
+
512
614
void TemperedLB::doLBStages (LoadType start_imb) {
513
615
decltype (this ->cur_objs_ ) best_objs;
514
616
LoadType best_load = 0 ;
@@ -517,6 +619,9 @@ void TemperedLB::doLBStages(LoadType start_imb) {
517
619
518
620
auto this_node = theContext ()->getNode ();
519
621
622
+ // Read in memory information if it's available before we do any trials
623
+ readClustersMemoryData ();
624
+
520
625
for (trial_ = 0 ; trial_ < num_trials_; ++trial_) {
521
626
// Clear out data structures
522
627
selected_.clear ();
@@ -554,6 +659,13 @@ void TemperedLB::doLBStages(LoadType start_imb) {
554
659
LoadType (this_new_load_)
555
660
);
556
661
662
+ vt_print (
663
+ temperedlb,
664
+ " Current memory info: total memory usage={}, shared blocks here={}, "
665
+ " memory_threshold={}\n " , computeMemoryUsage (), getSharedBlocksHere ().size (),
666
+ mem_thresh_
667
+ );
668
+
557
669
if (isOverloaded (this_new_load_)) {
558
670
is_overloaded_ = true ;
559
671
} else if (isUnderloaded (this_new_load_)) {
0 commit comments