diff --git a/examples/benchmark/tasks_info.py b/examples/benchmark/tasks_info.py index 9c56172..64b9c1e 100644 --- a/examples/benchmark/tasks_info.py +++ b/examples/benchmark/tasks_info.py @@ -12,15 +12,16 @@ class AvailableTask(StrEnum): """Available tasks for training.""" # Complex tasks - PREPARE_MEAL = TaskType.PREPARE_MEAL - RELAX_ON_SOFA = TaskType.RELAX_ON_SOFA - READ_BOOK_IN_BED = TaskType.READ_BOOK_IN_BED - SETUP_BATH = TaskType.SETUP_BATH - MULTI_TASK = "MultiTask" CLEAN_UP_KITCHEN = TaskType.CLEAN_UP_KITCHEN CLEAN_UP_LIVING_ROOM = TaskType.CLEAN_UP_LIVING_ROOM CLEAN_UP_BEDROOM = TaskType.CLEAN_UP_BEDROOM CLEAN_UP_BATHROOM = TaskType.CLEAN_UP_BATHROOM + MULTI_TASK_4 = "MultiTask4" + MULTI_TASK_8 = "MultiTask8" + PREPARE_MEAL = TaskType.PREPARE_MEAL + RELAX_ON_SOFA = TaskType.RELAX_ON_SOFA + READ_BOOK_IN_BED = TaskType.READ_BOOK_IN_BED + SETUP_BATH = TaskType.SETUP_BATH # Gradual tasks # 1 item @@ -457,7 +458,17 @@ def keep_only_n_scenes(task_blueprint_config: dict[str, Any], nb_scenes: int) -> def get_task_blueprint_config(task: AvailableTask, nb_scenes: int) -> list[dict[str, Any]]: """Return the scenes for the task.""" match task: - case AvailableTask.MULTI_TASK: + case AvailableTask.MULTI_TASK_4: + return [ + keep_only_n_scenes(task_blueprints_configs[task], nb_scenes) + for task in ( + AvailableTask.CLEAN_UP_KITCHEN, + AvailableTask.CLEAN_UP_LIVING_ROOM, + AvailableTask.CLEAN_UP_BEDROOM, + AvailableTask.CLEAN_UP_BATHROOM, + ) + ] + case AvailableTask.MULTI_TASK_8: return [ keep_only_n_scenes(task_blueprints_configs[task], nb_scenes) for task in ( @@ -465,6 +476,10 @@ def get_task_blueprint_config(task: AvailableTask, nb_scenes: int) -> list[dict[ AvailableTask.RELAX_ON_SOFA, AvailableTask.READ_BOOK_IN_BED, AvailableTask.SETUP_BATH, + AvailableTask.CLEAN_UP_KITCHEN, + AvailableTask.CLEAN_UP_LIVING_ROOM, + AvailableTask.CLEAN_UP_BEDROOM, + AvailableTask.CLEAN_UP_BATHROOM, ) ] case _: diff --git a/examples/benchmark/train.py b/examples/benchmark/train.py index 9742c14..35dc680 100644 --- a/examples/benchmark/train.py +++ b/examples/benchmark/train.py @@ -124,7 +124,7 @@ def main( do_eval (bool): Evaluate the agent. !! Don't eval with a different environment in a Docker container, both rendering windows might be mixed up. randomize_agent_position (bool): Randomize the agent position in the environment. """ - is_single_task = task != AvailableTask.MULTI_TASK + is_single_task = task not in {AvailableTask.MULTI_TASK_4, AvailableTask.MULTI_TASK_8} if is_single_task: model_config["policy_type"] = "CnnPolicy" else: