Skip to content

Commit

Permalink
Add multi tasks to training script
Browse files Browse the repository at this point in the history
  • Loading branch information
Kajiih committed May 31, 2024
1 parent 2f8be61 commit 038c3db
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
27 changes: 21 additions & 6 deletions examples/benchmark/tasks_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ class AvailableTask(StrEnum):
"""Available tasks for training."""

# Complex tasks
PREPARE_MEAL = TaskType.PREPARE_MEAL
RELAX_ON_SOFA = TaskType.RELAX_ON_SOFA
READ_BOOK_IN_BED = TaskType.READ_BOOK_IN_BED
SETUP_BATH = TaskType.SETUP_BATH
MULTI_TASK = "MultiTask"
CLEAN_UP_KITCHEN = TaskType.CLEAN_UP_KITCHEN
CLEAN_UP_LIVING_ROOM = TaskType.CLEAN_UP_LIVING_ROOM
CLEAN_UP_BEDROOM = TaskType.CLEAN_UP_BEDROOM
CLEAN_UP_BATHROOM = TaskType.CLEAN_UP_BATHROOM
MULTI_TASK_4 = "MultiTask4"
MULTI_TASK_8 = "MultiTask8"
PREPARE_MEAL = TaskType.PREPARE_MEAL
RELAX_ON_SOFA = TaskType.RELAX_ON_SOFA
READ_BOOK_IN_BED = TaskType.READ_BOOK_IN_BED
SETUP_BATH = TaskType.SETUP_BATH

# Gradual tasks
# 1 item
Expand Down Expand Up @@ -457,14 +458,28 @@ def keep_only_n_scenes(task_blueprint_config: dict[str, Any], nb_scenes: int) ->
def get_task_blueprint_config(task: AvailableTask, nb_scenes: int) -> list[dict[str, Any]]:
"""Return the scenes for the task."""
match task:
case AvailableTask.MULTI_TASK:
case AvailableTask.MULTI_TASK_4:
return [
keep_only_n_scenes(task_blueprints_configs[task], nb_scenes)
for task in (
AvailableTask.CLEAN_UP_KITCHEN,
AvailableTask.CLEAN_UP_LIVING_ROOM,
AvailableTask.CLEAN_UP_BEDROOM,
AvailableTask.CLEAN_UP_BATHROOM,
)
]
case AvailableTask.MULTI_TASK_8:
return [
keep_only_n_scenes(task_blueprints_configs[task], nb_scenes)
for task in (
AvailableTask.PREPARE_MEAL,
AvailableTask.RELAX_ON_SOFA,
AvailableTask.READ_BOOK_IN_BED,
AvailableTask.SETUP_BATH,
AvailableTask.CLEAN_UP_KITCHEN,
AvailableTask.CLEAN_UP_LIVING_ROOM,
AvailableTask.CLEAN_UP_BEDROOM,
AvailableTask.CLEAN_UP_BATHROOM,
)
]
case _:
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmark/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main(
do_eval (bool): Evaluate the agent. !! Don't eval with a different environment in a Docker container, both rendering windows might be mixed up.
randomize_agent_position (bool): Randomize the agent position in the environment.
"""
is_single_task = task != AvailableTask.MULTI_TASK
is_single_task = task not in {AvailableTask.MULTI_TASK_4, AvailableTask.MULTI_TASK_8}
if is_single_task:
model_config["policy_type"] = "CnnPolicy"
else:
Expand Down

0 comments on commit 038c3db

Please sign in to comment.