Spaces:
Runtime error
Runtime error
| """Data collection script.""" | |
| import os | |
| import hydra | |
| import numpy as np | |
| import random | |
| from cliport import tasks | |
| from cliport.dataset import RavensDataset | |
| from cliport.environments.environment import Environment | |
| import IPython | |
| import random | |
def _initial_seed(mode):
    """Return the seed to start from when the dataset has no episodes yet.

    The collection loop adds 2 before every episode. The resulting seeds are:
    even seeds starting at 0 for ``train``, odd seeds starting at 1 for
    ``val``, and odd seeds starting at 10001 for ``test``. The test set is
    offset by 10000 so that it stays disjoint from val.

    Raises:
        ValueError: if ``mode`` is not ``train``, ``val`` or ``test``.
    """
    start_seeds = {'train': -2, 'val': -1, 'test': -1 + 10000}
    if mode not in start_seeds:
        raise ValueError("Invalid mode. Valid options: train, val, test")
    return start_seeds[mode]


def _print_traceback():
    """Print the exception currently being handled, highlighted when possible."""
    import traceback
    tb_text = traceback.format_exc()
    try:
        # pygments is optional: its absence must not turn a logged episode
        # failure into a crash of the whole collection run.
        from pygments import highlight
        from pygments.formatters import TerminalFormatter
        from pygments.lexers import PythonLexer
    except ImportError:
        print(tb_text)
        return
    print(highlight(tb_text, PythonLexer(), TerminalFormatter()))


# ``main`` takes ``cfg``, while the script entry point calls ``main()`` with no
# arguments, so a config-injecting decorator is required here.
# NOTE(review): config_path/config_name follow the upstream cliport layout
# (cliport/cfg/data.yaml) — confirm for this project.
@hydra.main(config_path='./cfg', config_name='data')
def main(cfg):
    """Collect scripted-oracle demonstrations for one task and store them.

    Rolls out ``task.oracle(env)`` once per seed. Episodes whose total reward
    exceeds 0.99 are added to a ``RavensDataset`` at
    ``<data_dir>/<task>-<mode>``, but only when ``save_data`` is set.
    Collection stops once the dataset holds ``cfg['n']`` episodes, or after
    ``3 * cfg['n']`` attempts in this run, whichever comes first.

    Args:
        cfg: Config mapping. Keys read here: ``assets_root``, ``disp``,
            ``shared_memory``, ``record`` (including ``save_video``),
            ``task``, ``mode``, ``save_data``, ``data_dir``, ``n`` and,
            optionally, ``regenerate_data``.

    Raises:
        ValueError: if the dataset is empty and ``cfg['mode']`` is not one of
            ``train``, ``val``, ``test``.
        RuntimeError: if val seeds would run into the test seed range.
    """
    # Initialize environment and task.
    env = Environment(
        cfg['assets_root'],
        disp=cfg['disp'],
        shared_memory=cfg['shared_memory'],
        hz=480,
        record_cfg=cfg['record'],
    )
    task = tasks.names[cfg['task']]()
    task.mode = cfg['mode']
    record = cfg['record']['save_video']
    save_data = cfg['save_data']

    # Initialize scripted oracle agent and dataset.
    agent = task.oracle(env)
    data_path = os.path.join(cfg['data_dir'], "{}-{}".format(cfg['task'], task.mode))
    dataset = RavensDataset(data_path, cfg, n_demos=0, augment=False)
    print(f"Saving to: {data_path}")
    print(f"Mode: {task.mode}")

    # Resume from the largest seed already stored; otherwise start from the
    # mode-specific base seed (train even, val/test odd, test offset 10000).
    seed = dataset.max_seed
    if seed < 0:
        seed = _initial_seed(task.mode)

    # Treat regenerate_data as a boolean flag: a key that is present but
    # false must not discard the existing episode count.
    if cfg.get('regenerate_data', False):
        dataset.n_episodes = 0

    # Cap attempts so that repeatedly failing seeds cannot loop forever.
    max_eps = 3 * cfg['n']
    curr_run_eps = 0

    # Collect training data from oracle demonstrations.
    while dataset.n_episodes < cfg['n'] and curr_run_eps < max_eps:
        episode, total_reward = [], 0
        seed += 2
        curr_run_eps += 1

        # Set seeds.
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, cfg['n'], seed))

        # Unlikely, but a safety check to prevent leaks. It lives outside the
        # per-episode try below so it aborts the run instead of being
        # swallowed as an ordinary episode failure.
        if task.mode == 'val' and seed > (-1 + 10000):
            raise RuntimeError("!!! Seeds for val set will overlap with the test set !!!")

        recording = False  # True only while a video recording is open
        try:
            env.set_task(task)
            obs = env.reset()
            info = env.info
            reward = 0

            # Start video recording (NOTE: super slow)
            if record:
                recording = True
                env.start_rec(f'{dataset.n_episodes + 1:06d}')

            # Rollout expert policy
            for _ in range(task.max_steps):
                act = agent.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                obs, reward, done, info = env.step(act)
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                if done:
                    break

            if recording:
                recording = False
                env.end_rec()
        except Exception:
            # Best effort: log the failure, close a recording only if one was
            # actually started, and move on to the next seed.
            _print_traceback()
            if recording:
                env.end_rec()
            continue

        episode.append((obs, None, reward, info))

        # Only save completed demonstrations.
        if save_data and total_reward > 0.99:
            dataset.add(seed, episode)

        if hasattr(env, 'blender_recorder'):
            blender_path = '{}/blender_demo_{}.pkl'.format(data_path, dataset.n_episodes)
            print("blender pickle saved to ", blender_path)
            env.blender_recorder.save(blender_path)
if __name__ == "__main__":
    # Script entry point. ``main`` declares a ``cfg`` parameter, so this
    # zero-argument call relies on a config-injecting decorator on ``main``
    # (presumably ``hydra.main``) — NOTE(review): confirm.
    main()