diff --git a/data/envs/download_expert_scores.py b/data/envs/download_expert_scores.py index 4c3f06b..88b6c45 100644 --- a/data/envs/download_expert_scores.py +++ b/data/envs/download_expert_scores.py @@ -12,162 +12,162 @@ from tqdm import tqdm ENV_NAMES = [ - "atari-alien", - "atari-amidar", - "atari-assault", - "atari-asterix", - "atari-asteroids", - "atari-atlantis", - "atari-bankheist", - "atari-battlezone", - "atari-beamrider", - "atari-berzerk", - "atari-bowling", - "atari-boxing", - "atari-breakout", - "atari-centipede", - "atari-choppercommand", - "atari-crazyclimber", - "atari-defender", - "atari-demonattack", - "atari-doubledunk", - "atari-enduro", - "atari-fishingderby", - "atari-freeway", - "atari-frostbite", - "atari-gopher", - "atari-gravitar", - "atari-hero", - "atari-icehockey", - "atari-jamesbond", - "atari-kangaroo", - "atari-krull", - "atari-kungfumaster", - "atari-montezumarevenge", - "atari-mspacman", - "atari-namethisgame", - "atari-phoenix", - "atari-pitfall", - "atari-pong", - "atari-privateeye", - "atari-qbert", - "atari-riverraid", - "atari-roadrunner", - "atari-robotank", - "atari-seaquest", - "atari-skiing", - "atari-solaris", - "atari-spaceinvaders", - "atari-stargunner", - # "atari-surround", # Not in the dataset - "atari-tennis", - "atari-timepilot", - "atari-tutankham", - "atari-upndown", - "atari-venture", - "atari-videopinball", - "atari-wizardofwor", - "atari-yarsrevenge", - "atari-zaxxon", - "babyai-action-obj-door", - "babyai-blocked-unlock-pickup", - "babyai-boss-level-no-unlock", - "babyai-boss-level", - "babyai-find-obj-s5", - "babyai-go-to-door", - # "babyai-go-to-imp-unlock", # Not in the dataset - "babyai-go-to-local", - "babyai-go-to-obj-door", - "babyai-go-to-obj", - "babyai-go-to-red-ball-grey", - "babyai-go-to-red-ball-no-dists", - "babyai-go-to-red-ball", - "babyai-go-to-red-blue-ball", - "babyai-go-to-seq", - "babyai-go-to", - "babyai-key-corridor", - "babyai-key-in-box", - "babyai-mini-boss-level", - "babyai-move-two-across", - "babyai-one-room-s8", - "babyai-open-door", - "babyai-open-doors-order", - "babyai-open-red-door", - "babyai-open-two-doors", - "babyai-open", - "babyai-pickup-above", - "babyai-pickup-dist", - "babyai-pickup-loc", - "babyai-pickup", - "babyai-synth-loc", - "babyai-synth-seq", - "babyai-synth", - "babyai-unblock-pickup", - "babyai-unlock-local", - "babyai-unlock-pickup", - # "babyai-unlock-to-unlock", # Not in the dataset - # "babyai-unlock", # Not in the dataset + # "atari-alien", + # "atari-amidar", + # "atari-assault", + # "atari-asterix", + # "atari-asteroids", + # "atari-atlantis", + # "atari-bankheist", + # "atari-battlezone", + # "atari-beamrider", + # "atari-berzerk", + # "atari-bowling", + # "atari-boxing", + # "atari-breakout", + # "atari-centipede", + # "atari-choppercommand", + # "atari-crazyclimber", + # "atari-defender", + # "atari-demonattack", + # "atari-doubledunk", + # "atari-enduro", + # "atari-fishingderby", + # "atari-freeway", + # "atari-frostbite", + # "atari-gopher", + # "atari-gravitar", + # "atari-hero", + # "atari-icehockey", + # "atari-jamesbond", + # "atari-kangaroo", + # "atari-krull", + # "atari-kungfumaster", + # "atari-montezumarevenge", + # "atari-mspacman", + # "atari-namethisgame", + # "atari-phoenix", + # "atari-pitfall", + # "atari-pong", + # "atari-privateeye", + # "atari-qbert", + # "atari-riverraid", + # "atari-roadrunner", + # "atari-robotank", + # "atari-seaquest", + # "atari-skiing", + # "atari-solaris", + # "atari-spaceinvaders", + # "atari-stargunner", + # # "atari-surround", # Not in the dataset + # "atari-tennis", + # "atari-timepilot", + # "atari-tutankham", + # "atari-upndown", + # "atari-venture", + # "atari-videopinball", + # "atari-wizardofwor", + # "atari-yarsrevenge", + # "atari-zaxxon", + # "babyai-action-obj-door", + # "babyai-blocked-unlock-pickup", + # "babyai-boss-level-no-unlock", + # "babyai-boss-level", + # "babyai-find-obj-s5", + # "babyai-go-to-door", + # # "babyai-go-to-imp-unlock", # Not in the dataset + # "babyai-go-to-local", + # "babyai-go-to-obj-door", + # "babyai-go-to-obj", + # "babyai-go-to-red-ball-grey", + # "babyai-go-to-red-ball-no-dists", + # "babyai-go-to-red-ball", + # "babyai-go-to-red-blue-ball", + # "babyai-go-to-seq", + # "babyai-go-to", + # "babyai-key-corridor", + # "babyai-key-in-box", + # "babyai-mini-boss-level", + # "babyai-move-two-across", + # "babyai-one-room-s8", + # "babyai-open-door", + # "babyai-open-doors-order", + # "babyai-open-red-door", + # "babyai-open-two-doors", + # "babyai-open", + # "babyai-pickup-above", + # "babyai-pickup-dist", + # "babyai-pickup-loc", + # "babyai-pickup", + # "babyai-synth-loc", + # "babyai-synth-seq", + # "babyai-synth", + # "babyai-unblock-pickup", + # "babyai-unlock-local", + # "babyai-unlock-pickup", + # # "babyai-unlock-to-unlock", # Not in the dataset + # # "babyai-unlock", # Not in the dataset "metaworld-assembly", - "metaworld-basketball", - "metaworld-bin-picking", - "metaworld-box-close", - "metaworld-button-press-topdown-wall", - "metaworld-button-press-topdown", - "metaworld-button-press-wall", - "metaworld-button-press", - "metaworld-coffee-button", - "metaworld-coffee-pull", - "metaworld-coffee-push", - "metaworld-dial-turn", - "metaworld-disassemble", - "metaworld-door-close", - "metaworld-door-lock", - "metaworld-door-open", - "metaworld-door-unlock", - "metaworld-drawer-close", - "metaworld-drawer-open", - "metaworld-faucet-close", - "metaworld-faucet-open", - "metaworld-hammer", - "metaworld-hand-insert", - "metaworld-handle-press-side", - "metaworld-handle-press", - "metaworld-handle-pull-side", - "metaworld-handle-pull", - "metaworld-lever-pull", - "metaworld-peg-insert-side", - "metaworld-peg-unplug-side", - "metaworld-pick-out-of-hole", - "metaworld-pick-place-wall", - "metaworld-pick-place", - "metaworld-plate-slide-back-side", - "metaworld-plate-slide-back", - "metaworld-plate-slide-side", - "metaworld-plate-slide", - "metaworld-push-back", - "metaworld-push-wall", - "metaworld-push", - "metaworld-reach-wall", - "metaworld-reach", - "metaworld-shelf-place", - "metaworld-soccer", - "metaworld-stick-pull", - "metaworld-stick-push", - "metaworld-sweep-into", - "metaworld-sweep", - "metaworld-window-close", - "metaworld-window-open", - "mujoco-ant", - "mujoco-doublependulum", - "mujoco-halfcheetah", - "mujoco-hopper", + # "metaworld-basketball", + # "metaworld-bin-picking", + # "metaworld-box-close", + # "metaworld-button-press-topdown-wall", + # "metaworld-button-press-topdown", + # "metaworld-button-press-wall", + # "metaworld-button-press", + # "metaworld-coffee-button", + # "metaworld-coffee-pull", + # "metaworld-coffee-push", + # "metaworld-dial-turn", + # "metaworld-disassemble", + # "metaworld-door-close", + # "metaworld-door-lock", + # "metaworld-door-open", + # "metaworld-door-unlock", + # "metaworld-drawer-close", + # "metaworld-drawer-open", + # "metaworld-faucet-close", + # "metaworld-faucet-open", + # "metaworld-hammer", + # "metaworld-hand-insert", + # "metaworld-handle-press-side", + # "metaworld-handle-press", + # "metaworld-handle-pull-side", + # "metaworld-handle-pull", + # "metaworld-lever-pull", + # "metaworld-peg-insert-side", + # "metaworld-peg-unplug-side", + # "metaworld-pick-out-of-hole", + # "metaworld-pick-place-wall", + # "metaworld-pick-place", + # "metaworld-plate-slide-back-side", + # "metaworld-plate-slide-back", + # "metaworld-plate-slide-side", + # "metaworld-plate-slide", + # "metaworld-push-back", + # "metaworld-push-wall", + # "metaworld-push", + # "metaworld-reach-wall", + # "metaworld-reach", + # "metaworld-shelf-place", + # "metaworld-soccer", + # "metaworld-stick-pull", + # "metaworld-stick-push", + # "metaworld-sweep-into", + # "metaworld-sweep", + # "metaworld-window-close", + # "metaworld-window-open", + # "mujoco-ant", + # "mujoco-doublependulum", + # "mujoco-halfcheetah", + # "mujoco-hopper", # "mujoco-humanoid", # Not in the dataset - "mujoco-pendulum", - # "mujoco-pusher", # Not in the dataset - "mujoco-reacher", + # "mujoco-pendulum", + # # "mujoco-pusher", # Not in the dataset + # "mujoco-reacher", # "mujoco-standup", # Not in the dataset - "mujoco-swimmer", - "mujoco-walker", + # "mujoco-swimmer", + # "mujoco-walker", ] diff --git a/data/envs/metaworld/enjoy.py b/data/envs/metaworld/enjoy.py deleted file mode 100644 index 6ec026b..0000000 --- a/data/envs/metaworld/enjoy.py +++ /dev/null @@ -1,84 +0,0 @@ -import sys -from typing import Dict, Optional - -import gym -import metaworld # noqa: F401 -from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args -from sample_factory.enjoy import enjoy -from sample_factory.envs.env_utils import register_env - - -ENV_NAMES = [ - "assembly-v2", - "basketball-v2", - "bin-picking-v2", - "box-close-v2", - "button-press-topdown-v2", - "button-press-topdown-wall-v2", - "button-press-v2", - "button-press-wall-v2", - "coffee-button-v2", - "coffee-pull-v2", - "coffee-push-v2", - "dial-turn-v2", - "disassemble-v2", - "door-close-v2", - "door-lock-v2", - "door-open-v2", - "door-unlock-v2", - "drawer-close-v2", - "drawer-open-v2", - "faucet-close-v2", - "faucet-open-v2", - "hammer-v2", - "hand-insert-v2", - "handle-press-side-v2", - "handle-press-v2", - "handle-pull-side-v2", - "handle-pull-v2", - "lever-pull-v2", - "peg-insert-side-v2", - "peg-unplug-side-v2", - "pick-out-of-hole-v2", - "pick-place-v2", - "pick-place-wall-v2", - "plate-slide-back-side-v2", - "plate-slide-back-v2", - "plate-slide-side-v2", - "plate-slide-v2", - "push-back-v2", - "push-v2", - "push-wall-v2", - "reach-v2", - "reach-wall-v2", - "shelf-place-v2", - "soccer-v2", - "stick-pull-v2", - "stick-push-v2", - "sweep-into-v2", - "sweep-v2", - "window-close-v2", - "window-open-v2", -] - - -def make_custom_env( - full_env_name: str, - cfg: Optional[Dict] = None, - env_config: Optional[Dict] = None, - render_mode: Optional[str] = None, -) -> gym.Env: - return gym.make(full_env_name, render_mode=render_mode) - - -def main() -> int: - for env_name in ENV_NAMES: - register_env(env_name, make_custom_env) - parser, _ = parse_sf_args(argv=None, evaluation=True) - cfg = parse_full_cfg(parser) - status = enjoy(cfg) - return status - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py index e21b237..c2b1907 100644 --- a/data/envs/metaworld/generate_dataset.py +++ b/data/envs/metaworld/generate_dataset.py @@ -142,7 +142,8 @@ def create_dataset(cfg: Config, dataset_size: int = 100_000, split: str = "train # Actions shape should be [num_agents, num_actions] even if it's [1, 1] actions = preprocess_actions(env_info, actions) - + # Clamp actions to be in the range of the action space + actions = np.clip(actions, env.action_space.low, env.action_space.high) rnn_states = policy_outputs["new_rnn_states"] dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0]) dataset["continuous_actions"][-1].append(actions[0]) diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh index cfdae2f..8720089 100755 --- a/data/envs/metaworld/generate_dataset_all.sh +++ b/data/envs/metaworld/generate_dataset_all.sh @@ -1,7 +1,7 @@ #!/bin/bash ENVS=( - assembly + # assembly basketball bin-picking box-close @@ -9,51 +9,51 @@ ENVS=( button-press-topdown-wall button-press button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn - disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull - peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open + # coffee-button + # coffee-pull + # coffee-push + # dial-turn + # disassemble + # door-close + # door-lock + # door-open + # door-unlock + # drawer-close + # drawer-open + # faucet-close + # faucet-open + # hammer + # hand-insert + # handle-press-side + # handle-press + # handle-pull-side + # handle-pull + # lever-pull + # peg-insert-side + # peg-unplug-side + # pick-out-of-hole + # pick-place + # pick-place-wall + # plate-slide-back-side + # plate-slide-back + # plate-slide-side + # plate-slide + # push-back + # push + # push-wall + # reach + # reach-wall + # shelf-place + # soccer + # stick-pull + # stick-push + # sweep-into + # sweep + # window-close + # window-open ) for ENV in "${ENVS[@]}"; do - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/sample-factory-$ENV-v2 - python generate_dataset.py --env $ENV-v2 --experiment sample-factory-$ENV-v2 --train_dir=./train_dir + python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir done diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh index 9d71467..4fc1fc2 100755 --- a/data/envs/metaworld/push_all.sh +++ b/data/envs/metaworld/push_all.sh @@ -2,57 +2,57 @@ ENVS=( assembly - basketball - bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn - disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull - peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open + # basketball + # bin-picking + # box-close + # button-press-topdown + # button-press-topdown-wall + # button-press + # button-press-wall + # coffee-button + # coffee-pull + # coffee-push + # dial-turn + # disassemble + # door-close + # door-lock + # door-open + # door-unlock + # drawer-close + # drawer-open + # faucet-close + # faucet-open + # hammer + # hand-insert + # handle-press-side + # handle-press + # handle-pull-side + # handle-pull + # lever-pull + # peg-insert-side + # peg-unplug-side + # pick-out-of-hole + # pick-place + # pick-place-wall + # plate-slide-back-side + # plate-slide-back + # plate-slide-side + # plate-slide + # push-back + # push + # push-wall + # reach + # reach-wall + # shelf-place + # soccer + # stick-pull + # stick-push + # sweep-into + # sweep + # window-close + # window-open ) for ENV in "${ENVS[@]}"; do - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best + python push.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best done diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py index 46dc581..095414e 100644 --- a/data/envs/metaworld/train.py +++ b/data/envs/metaworld/train.py @@ -2,67 +2,13 @@ import argparse import sys from typing import Dict, Optional -import gym +import gymnasium as gym import metaworld # noqa: F401 from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl -ENV_NAMES = [ - "assembly-v2", - "basketball-v2", - "bin-picking-v2", - "box-close-v2", - "button-press-topdown-v2", - "button-press-topdown-wall-v2", - "button-press-v2", - "button-press-wall-v2", - "coffee-button-v2", - "coffee-pull-v2", - "coffee-push-v2", - "dial-turn-v2", - "disassemble-v2", - "door-close-v2", - "door-lock-v2", - "door-open-v2", - "door-unlock-v2", - "drawer-close-v2", - "drawer-open-v2", - "faucet-close-v2", - "faucet-open-v2", - "hammer-v2", - "hand-insert-v2", - "handle-press-side-v2", - "handle-press-v2", - "handle-pull-side-v2", - "handle-pull-v2", - "lever-pull-v2", - "peg-insert-side-v2", - "peg-unplug-side-v2", - "pick-out-of-hole-v2", - "pick-place-v2", - "pick-place-wall-v2", - "plate-slide-back-side-v2", - "plate-slide-back-v2", - "plate-slide-side-v2", - "plate-slide-v2", - "push-back-v2", - "push-v2", - "push-wall-v2", - "reach-v2", - "reach-wall-v2", - "shelf-place-v2", - "soccer-v2", - "stick-pull-v2", - "stick-push-v2", - "sweep-into-v2", - "sweep-v2", - "window-close-v2", - "window-open-v2", -] - - def make_custom_env( full_env_name: str, cfg: Optional[Dict] = None, @@ -79,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse num_workers=8, num_envs_per_worker=8, worker_num_splits=2, - train_for_env_steps=100_000_000, + train_for_env_steps=10_000_000, encoder_mlp_layers=[64, 64], env_frameskip=1, nonlinearity="tanh", @@ -116,11 +62,10 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse def main() -> int: - for env_name in ENV_NAMES: - register_env(env_name, make_custom_env) parser, _ = parse_sf_args(argv=None, evaluation=False) parser = override_defaults(parser) cfg = parse_full_cfg(parser) + register_env(cfg.env, make_custom_env) status = run_rl(cfg) return status diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py index 91b645c..3e2cae7 100644 --- a/gia/eval/evaluator.py +++ b/gia/eval/evaluator.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from gia.config.arguments import Arguments @@ -5,11 +7,12 @@ from gia.model import GiaModel class Evaluator: - def __init__(self, args: Arguments, task: str) -> None: + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None: self.args = args self.task = task + self.mean_random = mean_random - @torch.no_grad() + @torch.inference_mode() def evaluate(self, model: GiaModel) -> float: return self._evaluate(model) diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py index f1f83f5..3e8e182 100644 --- a/gia/eval/rl/envs/core.py +++ b/gia/eval/rl/envs/core.py @@ -177,6 +177,7 @@ def make(task_name: str, num_envs: int = 1): elif task_name.startswith("metaworld"): import gym + import metaworld env_id = TASK_TO_ENV_MAPPING[task_name] env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py index f0d0b9b..ca37721 100644 --- a/gia/eval/rl/gia_agent.py +++ b/gia/eval/rl/gia_agent.py @@ -9,7 +9,7 @@ from gia.datasets import GiaDataCollator, Prompter from gia.model.gia_model import GiaModel from gia.processing import GiaProcessor - +import sample_factory.envs.env_utils class GiaAgent: r""" An RL agent that uses Gia to generate actions. @@ -75,6 +75,11 @@ class GiaAgent: ) -> Tuple[Tuple[Tensor, Tensor], ...]: return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values) + def set_model(self, model: GiaModel) -> None: + self.model = model + self.device = next(model.parameters()).device + self._max_length = self.model.config.max_position_embeddings + def reset(self, num_envs: int = 1) -> None: if self.prompter is not None: prompts = self.prompter.generate_prompts(num_envs) diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..754c05d 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,7 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py index c5cc423..91189f3 100644 --- a/gia/eval/rl/rl_evaluator.py +++ b/gia/eval/rl/rl_evaluator.py @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent class RLEvaluator(Evaluator): + def __init__(self, args, task): + super().__init__(args, task) + self.agent = GiaAgent() + def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ? raise NotImplementedError diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json index 1b8ebee..ff7d030 100644 --- a/gia/eval/rl/scores_dict.json +++ b/gia/eval/rl/scores_dict.json @@ -929,8 +929,8 @@ }, "metaworld-assembly": { "expert": { - "mean": 311.29314618777823, - "std": 75.04282151450695 + "mean": 3523.81468486244, + "std": 63.22745220327798 }, "random": { "mean": 220.65601680730813,