|
|
|
|
|
|
|
|
|
@@ -4,7 +4,7 @@ ENVS=( |
|
assembly |
|
basketball |
|
bin-picking |
|
- box-close |
|
+ #box-close |
|
button-press-topdown |
|
button-press-topdown-wall |
|
button-press |
|
|
|
|
|
|
|
|
|
@@ -2,10 +2,10 @@ import glob |
|
import json |
|
import subprocess |
|
|
|
-import wandb |
|
from accelerate import Accelerator |
|
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
|
|
|
+import wandb |
|
from gia.config import Arguments |
|
from gia.eval.utils import is_slurm_available |
|
|
|
|
|
|
|
|
|
|
|
@@ -180,7 +180,7 @@ def make(task_name: str, num_envs: int = 1): |
|
import metaworld |
|
|
|
env_id = TASK_TO_ENV_MAPPING[task_name] |
|
- env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) |
|
+ env = gym.make(env_id) |
|
|
|
else: |
|
raise ValueError(f"Unknown task name: {task_name}") |
|
|
|
|
|
|
|
|
|
@@ -54,7 +54,7 @@ class GiaAgent: |
|
self.action_space = action_space |
|
self.deterministic = deterministic |
|
self.device = next(model.parameters()).device |
|
- self._max_length = self.model.config.max_position_embeddings - 10 |
|
+ self._max_length = self.model.config.max_position_embeddings - 100 |
|
|
|
if isinstance(observation_space, spaces.Box): |
|
self._observation_key = "continuous_observations" |
|
|
|
|
|
|
|
|
|
@@ -1,7 +1,6 @@ |
|
import gym |
|
from gym.vector.vector_env import VectorEnv |
|
|
|
-from gia.eval.mappings import TASK_TO_ENV_MAPPING |
|
from gia.eval.rl.rl_evaluator import RLEvaluator |
|
|
|
|
|
|