diff --git a/routefinder/baselines/utils.py b/routefinder/baselines/utils.py index e2a2bee..624fc0f 100644 --- a/routefinder/baselines/utils.py +++ b/routefinder/baselines/utils.py @@ -19,6 +19,9 @@ def mtvrp2anyvrp(td: TensorDict) -> TensorDict: td_.set("cost_matrix", cost_mat) backhaul_class = td.get("backhaul_class", torch.ones(td_.batch_size[0], 1)) td_.set("backhaul_class", backhaul_class) + # if there is no num_depots, we assume it is 1 + if "num_depots" not in td_.keys(): + td_.set("num_depots", torch.full((td_.batch_size[0], 1), 1)) return td_