diff --git a/src/rl_ai2thor/envs/actions.py b/src/rl_ai2thor/envs/actions.py index da31eeb..63f812a 100644 --- a/src/rl_ai2thor/envs/actions.py +++ b/src/rl_ai2thor/envs/actions.py @@ -26,11 +26,11 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any +from rl_ai2thor.envs.tasks import ObjFixedPropId, ObjVariablePropId from rl_ai2thor.utils.general_utils import nested_dict_get if TYPE_CHECKING: from rl_ai2thor.envs.ai2thor_envs import ITHOREnv - from rl_ai2thor.envs.tasks import ObjFixedPropId from rl_ai2thor.utils.ai2thor_types import EventLike @@ -135,6 +135,7 @@ class ActionCategory(StrEnum): # === Action Classes === # TODO: Change perform to not need the environment +# TODO: Add target value to object required property @dataclass class EnvironmentAction: """ @@ -368,7 +369,11 @@ def __call__(self, env: ITHOREnv) -> bool: bool: Whether the agent has visible running water in its field of view. """ for obj in env.last_event.metadata["objects"]: - if obj["visible"] and obj["isToggled"] and obj["objectType"] in {"Faucet", "ShowerHead"}: + if ( + obj[ObjVariablePropId.VISIBLE] + and obj[ObjVariablePropId.IS_TOGGLED] + and obj[ObjFixedPropId.OBJECT_TYPE] in {"Faucet", "ShowerHead"} + ): return True return False @@ -403,7 +408,7 @@ def __call__(self, env: ITHOREnv) -> bool: """ return ( len(env.last_event.metadata["inventoryObjects"]) > 0 - and env.last_event.metadata["inventoryObjects"][0]["objectType"] == self.object_type + and env.last_event.metadata["inventoryObjects"][0][ObjFixedPropId.OBJECT_TYPE] == self.object_type ) def _base_error_message(self, action: EnvironmentAction) -> str: @@ -770,8 +775,7 @@ def __init__(self, ai2thor_action: str) -> None: clean_object_action, ] -ACTION_CATEGORIES = {action.action_category for action in ALL_ACTIONS} -ACTIONS_BY_CATEGORY = {category: [] for category in ACTION_CATEGORIES} +ACTIONS_BY_CATEGORY = {category: [] for category in ActionCategory} for action in ALL_ACTIONS: category = action.action_category ACTIONS_BY_CATEGORY[category].append(action) diff --git a/src/rl_ai2thor/envs/ai2thor_envs.py b/src/rl_ai2thor/envs/ai2thor_envs.py index 6c92c71..01760c6 100644 --- a/src/rl_ai2thor/envs/ai2thor_envs.py +++ b/src/rl_ai2thor/envs/ai2thor_envs.py @@ -15,14 +15,15 @@ from numpy.typing import ArrayLike from rl_ai2thor.envs.actions import ( - ACTION_CATEGORIES, ACTIONS_BY_CATEGORY, ACTIONS_BY_NAME, ALL_ACTIONS, + ActionCategory, + EnvActionName, EnvironmentAction, ) from rl_ai2thor.envs.reward import GraphTaskRewardHandler -from rl_ai2thor.envs.tasks import GraphTask, PlaceObject +from rl_ai2thor.envs.tasks import GraphTask, PlaceObject, ObjFixedPropId from rl_ai2thor.utils.ai2thor_types import EventLike from rl_ai2thor.utils.general_utils import ROOT_DIR, update_nested_dict @@ -79,7 +80,7 @@ def _create_action_space(self) -> None: # Get the available actions from the environment mode config for action_category in self.config["action_categories"]: - if action_category in ACTION_CATEGORIES: + if action_category in ActionCategory: if self.config["action_categories"][action_category]: # Enable all actions in the category for action in ACTIONS_BY_CATEGORY[action_category]: @@ -90,20 +91,20 @@ def _create_action_space(self) -> None: # Handle specific cases # Simple movement actions if self.config["simple_movement_actions"]: - self.action_availablities["MoveBack"] = False - self.action_availablities["MoveLeft"] = False - self.action_availablities["MoveRight"] = False + self.action_availablities[EnvActionName.MOVE_BACK] = False + self.action_availablities[EnvActionName.MOVE_LEFT] = False + self.action_availablities[EnvActionName.MOVE_RIGHT] = False # Done actions if self.config["use_done_action"]: - self.action_availablities["Done"] = True + self.action_availablities[EnvActionName.DONE] = True # Partial openness if ( self.config["partial_openness"] and self.config["action_categories"]["open_close_actions"] and not self.config["discrete_actions"] ): - self.action_availablities["OpenObject"] = False - self.action_availablities["CloseObject"] = False + self.action_availablities[EnvActionName.OPEN_OBJECT] = False + self.action_availablities[EnvActionName.CLOSE_OBJECT] = False available_actions = [action_name for action_name, available in self.action_availablities.items() if available] self.action_idx_to_name = dict(enumerate(available_actions)) @@ -290,8 +291,12 @@ def _sample_task(self, event: EventLike) -> GraphTask: """ # Temporarily return only a PlaceObject task # Sample a receptacle and an object to place - scene_pickupable_objects = [obj["objectType"] for obj in event.metadata["objects"] if obj["pickupable"]] - scene_receptacles = [obj["objectType"] for obj in event.metadata["objects"] if obj["receptacle"]] + scene_pickupable_objects = [ + obj[ObjFixedPropId.OBJECT_TYPE] for obj in event.metadata["objects"] if obj[ObjFixedPropId.PICKUPABLE] + ] + scene_receptacles = [ + obj[ObjFixedPropId.OBJECT_TYPE] for obj in event.metadata["objects"] if obj[ObjFixedPropId.RECEPTACLE] + ] np_rng: np.random.Generator = self._np_random # type: ignore object_to_place = np_rng.choice(scene_pickupable_objects)