Skip to content

Commit

Permalink
Further use enums in actions and environment
Browse files Browse the repository at this point in the history
  • Loading branch information
Kajiih committed Feb 12, 2024
1 parent aba8091 commit 388151c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
14 changes: 9 additions & 5 deletions src/rl_ai2thor/envs/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from enum import StrEnum
from typing import TYPE_CHECKING, Any

from rl_ai2thor.envs.tasks import ObjFixedPropId, ObjVariablePropId
from rl_ai2thor.utils.general_utils import nested_dict_get

if TYPE_CHECKING:
from rl_ai2thor.envs.ai2thor_envs import ITHOREnv
from rl_ai2thor.envs.tasks import ObjFixedPropId
from rl_ai2thor.utils.ai2thor_types import EventLike


Expand Down Expand Up @@ -135,6 +135,7 @@ class ActionCategory(StrEnum):

# === Action Classes ===
# TODO: Change perform to not need the environment
# TODO: Add target value to object required property
@dataclass
class EnvironmentAction:
"""
Expand Down Expand Up @@ -368,7 +369,11 @@ def __call__(self, env: ITHOREnv) -> bool:
bool: Whether the agent has visible running water in its field of view.
"""
for obj in env.last_event.metadata["objects"]:
if obj["visible"] and obj["isToggled"] and obj["objectType"] in {"Faucet", "ShowerHead"}:
if (
obj[ObjVariablePropId.VISIBLE]
and obj[ObjVariablePropId.IS_TOGGLED]
and obj[ObjFixedPropId.OBJECT_TYPE] in {"Faucet", "ShowerHead"}
):
return True
return False

Expand Down Expand Up @@ -403,7 +408,7 @@ def __call__(self, env: ITHOREnv) -> bool:
"""
return (
len(env.last_event.metadata["inventoryObjects"]) > 0
and env.last_event.metadata["inventoryObjects"][0]["objectType"] == self.object_type
and env.last_event.metadata["inventoryObjects"][0][ObjFixedPropId.OBJECT_TYPE] == self.object_type
)

def _base_error_message(self, action: EnvironmentAction) -> str:
Expand Down Expand Up @@ -770,8 +775,7 @@ def __init__(self, ai2thor_action: str) -> None:
clean_object_action,
]

ACTION_CATEGORIES = {action.action_category for action in ALL_ACTIONS}
ACTIONS_BY_CATEGORY = {category: [] for category in ACTION_CATEGORIES}
ACTIONS_BY_CATEGORY = {category: [] for category in ActionCategory}
for action in ALL_ACTIONS:
category = action.action_category
ACTIONS_BY_CATEGORY[category].append(action)
Expand Down
27 changes: 16 additions & 11 deletions src/rl_ai2thor/envs/ai2thor_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from numpy.typing import ArrayLike

from rl_ai2thor.envs.actions import (
ACTION_CATEGORIES,
ACTIONS_BY_CATEGORY,
ACTIONS_BY_NAME,
ALL_ACTIONS,
ActionCategory,
EnvActionName,
EnvironmentAction,
)
from rl_ai2thor.envs.reward import GraphTaskRewardHandler
from rl_ai2thor.envs.tasks import GraphTask, PlaceObject
from rl_ai2thor.envs.tasks import GraphTask, PlaceObject, ObjFixedPropId
from rl_ai2thor.utils.ai2thor_types import EventLike
from rl_ai2thor.utils.general_utils import ROOT_DIR, update_nested_dict

Expand Down Expand Up @@ -79,7 +80,7 @@ def _create_action_space(self) -> None:

# Get the available actions from the environment mode config
for action_category in self.config["action_categories"]:
if action_category in ACTION_CATEGORIES:
if action_category in ActionCategory:
if self.config["action_categories"][action_category]:
# Enable all actions in the category
for action in ACTIONS_BY_CATEGORY[action_category]:
Expand All @@ -90,20 +91,20 @@ def _create_action_space(self) -> None:
# Handle specific cases
# Simple movement actions
if self.config["simple_movement_actions"]:
self.action_availablities["MoveBack"] = False
self.action_availablities["MoveLeft"] = False
self.action_availablities["MoveRight"] = False
self.action_availablities[EnvActionName.MOVE_BACK] = False
self.action_availablities[EnvActionName.MOVE_LEFT] = False
self.action_availablities[EnvActionName.MOVE_RIGHT] = False
# Done actions
if self.config["use_done_action"]:
self.action_availablities["Done"] = True
self.action_availablities[EnvActionName.DONE] = True
# Partial openness
if (
self.config["partial_openness"]
and self.config["action_categories"]["open_close_actions"]
and not self.config["discrete_actions"]
):
self.action_availablities["OpenObject"] = False
self.action_availablities["CloseObject"] = False
self.action_availablities[EnvActionName.OPEN_OBJECT] = False
self.action_availablities[EnvActionName.CLOSE_OBJECT] = False

available_actions = [action_name for action_name, available in self.action_availablities.items() if available]
self.action_idx_to_name = dict(enumerate(available_actions))
Expand Down Expand Up @@ -290,8 +291,12 @@ def _sample_task(self, event: EventLike) -> GraphTask:
"""
# Temporarily return only a PlaceObject task
# Sample a receptacle and an object to place
scene_pickupable_objects = [obj["objectType"] for obj in event.metadata["objects"] if obj["pickupable"]]
scene_receptacles = [obj["objectType"] for obj in event.metadata["objects"] if obj["receptacle"]]
scene_pickupable_objects = [
obj[ObjFixedPropId.OBJECT_TYPE] for obj in event.metadata["objects"] if obj[ObjFixedPropId.PICKUPABLE]
]
scene_receptacles = [
obj[ObjFixedPropId.OBJECT_TYPE] for obj in event.metadata["objects"] if obj[ObjFixedPropId.RECEPTACLE]
]

np_rng: np.random.Generator = self._np_random # type: ignore
object_to_place = np_rng.choice(scene_pickupable_objects)
Expand Down

0 comments on commit 388151c

Please sign in to comment.