From 617c4ee66333dd355abe691dc273e39a0374852f Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Wed, 5 Feb 2025 20:08:07 +0900 Subject: [PATCH] [BugFix] set num_depots to 1 by default if not found in dict --- routefinder/baselines/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/routefinder/baselines/utils.py b/routefinder/baselines/utils.py index e2a2bee..624fc0f 100644 --- a/routefinder/baselines/utils.py +++ b/routefinder/baselines/utils.py @@ -19,6 +19,9 @@ def mtvrp2anyvrp(td: TensorDict) -> TensorDict: td_.set("cost_matrix", cost_mat) backhaul_class = td.get("backhaul_class", torch.ones(td_.batch_size[0], 1)) td_.set("backhaul_class", backhaul_class) + # if there is no num_depots, we assume it is 1 + if "num_depots" not in td_.keys(): + td_.set("num_depots", torch.full((td_.batch_size[0], 1), 1)) return td_