|
|
|
|
|
|
|
|
|
@@ -2,10 +2,10 @@ import glob |
|
import json |
|
import subprocess |
|
|
|
-import wandb |
|
from accelerate import Accelerator |
|
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
|
|
|
+import wandb |
|
from gia.config import Arguments |
|
from gia.eval.utils import is_slurm_available |
|
|
|
|
|
|
|
|
|
|
|
@@ -1,3 +1,5 @@ |
|
+from typing import Optional |
|
+ |
|
import torch |
|
|
|
from gia.config.arguments import Arguments |
|
@@ -5,11 +7,12 @@ from gia.model import GiaModel |
|
|
|
|
|
class Evaluator: |
|
- def __init__(self, args: Arguments, task: str) -> None: |
|
+ def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None: |
|
self.args = args |
|
self.task = task |
|
+ self.mean_random = mean_random |
|
|
|
- @torch.no_grad() |
|
+ @torch.inference_mode() |
|
def evaluate(self, model: GiaModel) -> float: |
|
return self._evaluate(model) |
|
|
|
|
|
|
|
|
|
|
|
@@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1): |
|
|
|
elif task_name.startswith("metaworld"): |
|
import gymnasium as gym |
|
- import metaworld |
|
|
|
env_id = TASK_TO_ENV_MAPPING[task_name] |
|
env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) |
|
|
|
|
|
|
|
|
|
@@ -54,7 +54,7 @@ class GiaAgent: |
|
self.action_space = action_space |
|
self.deterministic = deterministic |
|
self.device = next(model.parameters()).device |
|
- self._max_length = self.model.config.max_position_embeddings - 10 |
|
+ self._max_length = self.model.config.max_position_embeddings - 100 # TODO: fix this |
|
|
|
if isinstance(observation_space, spaces.Box): |
|
self._observation_key = "continuous_observations" |
|
@@ -75,6 +75,11 @@ class GiaAgent: |
|
) -> Tuple[Tuple[Tensor, Tensor], ...]: |
|
return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values) |
|
|
|
+ def set_model(self, model: GiaModel) -> None: |
|
+ self.model = model |
|
+ self.device = next(model.parameters()).device |
|
+        self._max_length = self.model.config.max_position_embeddings - 100  # TODO: fix this (kept consistent with __init__'s offset)
|
+ |
|
def reset(self, num_envs: int = 1) -> None: |
|
if self.prompter is not None: |
|
prompts = self.prompter.generate_prompts(num_envs) |
|
|
|
|
|
|
|
|
|
@@ -1,7 +1,7 @@ |
|
import gym |
|
from gym.vector.vector_env import VectorEnv |
|
|
|
-from gia.eval.mappings import TASK_TO_ENV_MAPPING |
|
+# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING |
|
from gia.eval.rl.rl_evaluator import RLEvaluator |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent |
|
|
|
|
|
class RLEvaluator(Evaluator): |
|
+ def __init__(self, args, task): |
|
+ super().__init__(args, task) |
|
+        self.agent = GiaAgent()  # FIXME(review): GiaAgent.__init__ appears to require model/observation_space/action_space — confirm an argless call is valid
|
+ |
|
def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ? |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
@@ -929,8 +929,8 @@ |
|
}, |
|
"metaworld-assembly": { |
|
"expert": { |
|
- "mean": 311.29314618777823, |
|
- "std": 75.04282151450695 |
|
+ "mean": 3523.81468486244, |
|
+ "std": 63.22745220327798 |
|
}, |
|
"random": { |
|
"mean": 220.65601680730813, |
|
|