diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py index e21b237..c2b1907 100644 --- a/data/envs/metaworld/generate_dataset.py +++ b/data/envs/metaworld/generate_dataset.py @@ -142,7 +142,8 @@ def create_dataset(cfg: Config, dataset_size: int = 100_000, split: str = "train # Actions shape should be [num_agents, num_actions] even if it's [1, 1] actions = preprocess_actions(env_info, actions) - + # Clamp actions to be in the range of the action space + actions = np.clip(actions, env.action_space.low, env.action_space.high) rnn_states = policy_outputs["new_rnn_states"] dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0]) dataset["continuous_actions"][-1].append(actions[0]) diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh index cfdae2f..5db8c4b 100755 --- a/data/envs/metaworld/generate_dataset_all.sh +++ b/data/envs/metaworld/generate_dataset_all.sh @@ -2,58 +2,58 @@ ENVS=( assembly - basketball - bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn - disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull - peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open + # basketball + # bin-picking + # box-close + # button-press-topdown + # button-press-topdown-wall + # button-press + # button-press-wall + # coffee-button + # coffee-pull + # coffee-push + # dial-turn + # disassemble + # door-close + # door-lock + # door-open + # door-unlock + # drawer-close + # drawer-open + # faucet-close + # faucet-open + # hammer + # hand-insert + # handle-press-side + # handle-press + # handle-pull-side + # handle-pull + # lever-pull + # peg-insert-side + # peg-unplug-side + # pick-out-of-hole + # pick-place + # pick-place-wall + # plate-slide-back-side + # plate-slide-back + # plate-slide-side + # plate-slide + # push-back + # push + # push-wall + # reach + # reach-wall + # shelf-place + # soccer + # stick-pull + # stick-push + # sweep-into + # sweep + # window-close + # window-open ) for ENV in "${ENVS[@]}"; do - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/sample-factory-$ENV-v2 - python generate_dataset.py --env $ENV-v2 --experiment sample-factory-$ENV-v2 --train_dir=./train_dir + python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir done diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh index 9d71467..5b05c6d 100755 --- a/data/envs/metaworld/push_all.sh +++ b/data/envs/metaworld/push_all.sh @@ -2,57 +2,57 @@ ENVS=( assembly - basketball - bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn - disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull - peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open + # basketball + # bin-picking + # box-close + # button-press-topdown + # button-press-topdown-wall + # button-press + # button-press-wall + # coffee-button + # coffee-pull + # coffee-push + # dial-turn + # disassemble + # door-close + # door-lock + # door-open + # door-unlock + # drawer-close + # drawer-open + # faucet-close + # faucet-open + # hammer + # hand-insert + # handle-press-side + # handle-press + # handle-pull-side + # handle-pull + # lever-pull + # peg-insert-side + # peg-unplug-side + # pick-out-of-hole + # pick-place + # pick-place-wall + # plate-slide-back-side + # plate-slide-back + # plate-slide-side + # plate-slide + # push-back + # push + # push-wall + # reach + # reach-wall + # shelf-place + # soccer + # stick-pull + # stick-push + # sweep-into + # sweep + # window-close + # window-open ) for ENV in "${ENVS[@]}"; do - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best + python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best done diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh index dbf328a..1b3c4c8 100755 --- a/data/envs/metaworld/train_all.sh +++ b/data/envs/metaworld/train_all.sh @@ -1,7 +1,7 @@ #!/bin/bash ENVS=( - assembly + # assembly basketball bin-picking box-close diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py index 91b645c..196a601 100644 --- a/gia/eval/evaluator.py +++ b/gia/eval/evaluator.py @@ -2,14 +2,16 @@ import torch from gia.config.arguments import Arguments from gia.model import GiaModel +from typing import Optional class Evaluator: - def __init__(self, args: Arguments, task: str) -> None: + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None: self.args = args self.task = task + self.mean_random = mean_random - @torch.no_grad() + @torch.inference_mode() def evaluate(self, model: GiaModel) -> float: return self._evaluate(model) diff --git a/gia/eval/mappings.py b/gia/eval/mappings.py deleted file mode 100644 index e7ba9d3..0000000 --- a/gia/eval/mappings.py +++ /dev/null @@ -1,11 +0,0 @@ -TASK_TO_ENV_MAPPING = { - "mujoco-ant": "Ant-v4", - "mujoco-halfcheetah": "HalfCheetah-v4", - "mujoco-hopper": "Hopper-v4", - "mujoco-doublependulum": "InvertedDoublePendulum-v4", - "mujoco-pendulum": "InvertedPendulum-v4", - "mujoco-reacher": "Reacher-v4", - "mujoco-swimmer": "Swimmer-v4", - "mujoco-walker": "Walker2d-v4", - # Atari etc... -} diff --git a/gia/eval/rl/__init__.py b/gia/eval/rl/__init__.py index 36d890b..85a788d 100644 --- a/gia/eval/rl/__init__.py +++ b/gia/eval/rl/__init__.py @@ -1,4 +1,4 @@ from .gym_evaluator import GymEvaluator +from .envs.core import make - -__all__ = ["GymEvaluator"] +__all__ = ["GymEvaluator", "make"] diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py index f0d0b9b..04b9637 100644 --- a/gia/eval/rl/gia_agent.py +++ b/gia/eval/rl/gia_agent.py @@ -75,6 +75,11 @@ class GiaAgent: ) -> Tuple[Tuple[Tensor, Tensor], ...]: return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values) + def set_model(self, model: GiaModel) -> None: + self.model = model + self.device = next(model.parameters()).device + self._max_length = self.model.config.max_position_embeddings + def reset(self, num_envs: int = 1) -> None: if self.prompter is not None: prompts = self.prompter.generate_prompts(num_envs) diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..754c05d 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,7 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py index c5cc423..ca0c7da 100644 --- a/gia/eval/rl/rl_evaluator.py +++ b/gia/eval/rl/rl_evaluator.py @@ -8,6 +8,9 @@ from gia.eval.rl.gia_agent import GiaAgent class RLEvaluator(Evaluator): + def __init__(self, args, task): + super().__init__(args, task) + self.agent = GiaAgent() def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ? raise NotImplementedError