diff --git a/gia/eval/callback.py b/gia/eval/callback.py
index 5c3a080..4b6198f 100644
--- a/gia/eval/callback.py
+++ b/gia/eval/callback.py
@@ -2,10 +2,10 @@ import glob
 import json
 import subprocess
 
-import wandb
 from accelerate import Accelerator
 from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+import wandb
 
 from gia.config import Arguments
 from gia.eval.utils import is_slurm_available
diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
index 91b645c..3e2cae7 100644
--- a/gia/eval/evaluator.py
+++ b/gia/eval/evaluator.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from gia.config.arguments import Arguments
@@ -5,11 +7,12 @@ from gia.model import GiaModel
 
 
 class Evaluator:
-    def __init__(self, args: Arguments, task: str) -> None:
+    def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None:
         self.args = args
         self.task = task
+        self.mean_random = mean_random
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def evaluate(self, model: GiaModel) -> float:
         return self._evaluate(model)
 
diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
index ec5e5b2..eeaf7cb 100644
--- a/gia/eval/rl/envs/core.py
+++ b/gia/eval/rl/envs/core.py
@@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1):
 
     elif task_name.startswith("metaworld"):
         import gymnasium as gym
-        import metaworld
 
         env_id = TASK_TO_ENV_MAPPING[task_name]
         env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
index f0d0b9b..39dc0d2 100644
--- a/gia/eval/rl/gia_agent.py
+++ b/gia/eval/rl/gia_agent.py
@@ -54,7 +54,7 @@ class GiaAgent:
         self.action_space = action_space
         self.deterministic = deterministic
         self.device = next(model.parameters()).device
-        self._max_length = self.model.config.max_position_embeddings - 10
+        self._max_length = self.model.config.max_position_embeddings - 100  # TODO: fix this
 
         if isinstance(observation_space, spaces.Box):
             self._observation_key = "continuous_observations"
@@ -75,6 +75,11 @@ class GiaAgent:
     ) -> Tuple[Tuple[Tensor, Tensor], ...]:
         return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
 
+    def set_model(self, model: GiaModel) -> None:
+        self.model = model
+        self.device = next(model.parameters()).device
+        self._max_length = self.model.config.max_position_embeddings
+
     def reset(self, num_envs: int = 1) -> None:
         if self.prompter is not None:
             prompts = self.prompter.generate_prompts(num_envs)
diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
index f8531ee..754c05d 100644
--- a/gia/eval/rl/gym_evaluator.py
+++ b/gia/eval/rl/gym_evaluator.py
@@ -1,7 +1,7 @@
 import gym
 from gym.vector.vector_env import VectorEnv
 
-from gia.eval.mappings import TASK_TO_ENV_MAPPING
+# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING
 from gia.eval.rl.rl_evaluator import RLEvaluator
 
 
diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
index c5cc423..91189f3 100644
--- a/gia/eval/rl/rl_evaluator.py
+++ b/gia/eval/rl/rl_evaluator.py
@@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent
 
 
 class RLEvaluator(Evaluator):
+    def __init__(self, args, task):
+        super().__init__(args, task)
+        self.agent = GiaAgent()
+
     def _build_env(self) -> VectorEnv:  # TODO: maybe just a gym.Env ?
         raise NotImplementedError
 
diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json
index 1b8ebee..ff7d030 100644
--- a/gia/eval/rl/scores_dict.json
+++ b/gia/eval/rl/scores_dict.json
@@ -929,8 +929,8 @@
     },
     "metaworld-assembly": {
         "expert": {
-            "mean": 311.29314618777823,
-            "std": 75.04282151450695
+            "mean": 3523.81468486244,
+            "std": 63.22745220327798
         },
         "random": {
             "mean": 220.65601680730813,