koukyo1994 committed
Commit
2441869
1 Parent(s): a81bee7

add inference.py

Files changed (1)
  1. inference.py +185 -0
inference.py ADDED
@@ -0,0 +1,185 @@
import argparse
import json
import random
from pathlib import Path

import imageio
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel
from tqdm import tqdm


# Constants
IMAGE_SIZE = (288, 512)
N_FRAMES_PER_ROUND = 25
MAX_NUM_FRAMES = 50
N_TOKENS_PER_FRAME = 576
TRAJ_TEMPLATE_PATH = Path("./assets/template_trajectory.json")
PATH_START_ID = 9
PATH_POINT_INTERVAL = 10
N_ACTION_TOKENS = 6

# Change here if you want to use your own images
CONDITIONING_FRAMES_DIR = Path("./assets/conditioning_frames")
CONDITIONING_FRAMES_PATH_LIST = [
    CONDITIONING_FRAMES_DIR / "001.png",
    CONDITIONING_FRAMES_DIR / "002.png",
    CONDITIONING_FRAMES_DIR / "003.png"
]


def set_random_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def preprocess_image(image: Image.Image, size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    # Resize to (H, W) and scale pixel values to [-1, 1], returning a (1, C, H, W) tensor
    H, W = size
    image = image.convert("RGB")
    image = image.resize((W, H))
    image_array = np.array(image)
    image_array = (image_array / 127.5 - 1.0).astype(np.float32)
    return torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0).float()


def to_np_images(images: torch.Tensor) -> np.ndarray:
    # Map model outputs in [-1, 1] back to uint8 images in (N, H, W, C) layout
    images = images.detach().cpu()
    images = torch.clamp(images, -1., 1.)
    images = (images + 1.) / 2.
    images = images.permute(0, 2, 3, 1).numpy()
    return (255 * images).astype(np.uint8)


def load_images(file_path_list: list[Path], size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    images = []
    for file_path in file_path_list:
        image = Image.open(file_path)
        image = preprocess_image(image, size)
        images.append(image)
    return torch.cat(images, dim=0)


def save_images_to_mp4(images: np.ndarray, output_path: Path, fps: int = 10):
    writer = imageio.get_writer(output_path, fps=fps)
    for img in images:
        writer.append_data(img)
    writer.close()


def determine_num_rounds(num_frames: int, num_overlapping_frames: int, n_initial_frames: int) -> int:
    # Each round generates up to N_FRAMES_PER_ROUND frames, of which `num_overlapping_frames`
    # are carried over from the previous round as conditioning
    n_rounds = (num_frames - n_initial_frames) // (N_FRAMES_PER_ROUND - num_overlapping_frames)
    if (num_frames - n_initial_frames) % (N_FRAMES_PER_ROUND - num_overlapping_frames) > 0:
        n_rounds += 1
    return n_rounds


def prepare_action(
    traj_template: dict,
    cmd: str,
    path_start_id: int,
    path_point_interval: int,
    n_action_tokens: int = 5,
    start_index: int = 0,
    n_frames: int = 25
) -> torch.Tensor:
    # Build per-frame action tensors from the trajectory template: subsample the trajectory
    # points, swap the first two coordinate columns, and append the timestamp of each point
    trajs = traj_template[cmd]["instruction_trajs"]
    actions = []
    timesteps = np.arange(0.0, 3.0, 0.05)
    for i in range(start_index, start_index + n_frames):
        traj = trajs[i][path_start_id::path_point_interval][:n_action_tokens]
        action = np.array(traj)
        timestep = timesteps[path_start_id::path_point_interval][:n_action_tokens]
        action = np.concatenate([
            action[:, [1, 0]],
            timestep.reshape(-1, 1)
        ], axis=1)
        actions.append(torch.tensor(action))
    return torch.cat(actions, dim=0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--cmd", type=str, default="curving_to_left/curving_to_left_moderate")
    parser.add_argument("--num_frames", type=int, default=25)
    parser.add_argument("--num_overlapping_frames", type=int, default=3)
    args = parser.parse_args()

    assert args.num_frames <= MAX_NUM_FRAMES, f"`num_frames` should be less than or equal to {MAX_NUM_FRAMES}"
    assert args.num_overlapping_frames < N_FRAMES_PER_ROUND, f"`num_overlapping_frames` should be less than {N_FRAMES_PER_ROUND}"

    set_random_seed(args.seed)
    if args.output_dir is None:
        output_dir = Path(f"./outputs/{args.cmd}")
    else:
        output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoModel.from_pretrained("turing-motors/Terra", subfolder="lfq_tokenizer_B_256", trust_remote_code=True).to(device).eval()
    model = AutoModel.from_pretrained("turing-motors/Terra", subfolder="world_model", trust_remote_code=True).to(device).eval()

    # Tokenize the conditioning frames into discrete tokens for the world model
    conditioning_frames = load_images(CONDITIONING_FRAMES_PATH_LIST, IMAGE_SIZE).to(device)
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        input_ids = tokenizer.tokenize(conditioning_frames).detach().unsqueeze(0)

    num_rounds = determine_num_rounds(args.num_frames, args.num_overlapping_frames, len(CONDITIONING_FRAMES_PATH_LIST))
    print(f"Number of generation rounds: {num_rounds}")

    with open(TRAJ_TEMPLATE_PATH) as f:
        traj_template = json.load(f)

    # Generate frames autoregressively, round by round
    all_outputs = []
    for round in range(num_rounds):
        start_index = round * (N_FRAMES_PER_ROUND - args.num_overlapping_frames)
        num_frames_for_round = min(N_FRAMES_PER_ROUND, args.num_frames - start_index)
        actions = prepare_action(
            traj_template, args.cmd, PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS, start_index, num_frames_for_round
        ).unsqueeze(0).to(device).float()
        if round == 0:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - len(CONDITIONING_FRAMES_PATH_LIST))
        else:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - args.num_overlapping_frames)
        progress_bar = tqdm(total=num_generated_tokens, desc=f"Round {round + 1}")
        with torch.inference_mode(), torch.autocast(device_type="cuda"):
            output_tokens = model.generate(
                input_ids=input_ids,
                actions=actions,
                do_sample=True,
                max_length=N_TOKENS_PER_FRAME * num_frames_for_round,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                pad_token_id=None,
                eos_token_id=None,
                progress_bar=progress_bar
            )
        # Store the full sequence on the first round; afterwards drop the overlapping
        # frames that were already stored in the previous round
        if round == 0:
            all_outputs.append(output_tokens[0])
        else:
            all_outputs.append(output_tokens[0, args.num_overlapping_frames * N_TOKENS_PER_FRAME:])
        # The last overlapping frames become the conditioning for the next round
        input_ids = output_tokens[:, -args.num_overlapping_frames * N_TOKENS_PER_FRAME:]
        progress_bar.close()

    output_ids = torch.cat(all_outputs)

    # Calculate the shape of the latent tensor from the tokenizer's downsampling factor
    downsample_ratio = 1
    for coef in tokenizer.config.encoder_decoder_config["ch_mult"]:
        downsample_ratio *= coef
    h = IMAGE_SIZE[0] // downsample_ratio
    w = IMAGE_SIZE[1] // downsample_ratio
    c = tokenizer.config.encoder_decoder_config["z_channels"]
    latent_shape = (len(output_ids) // N_TOKENS_PER_FRAME, h, w, c)

    # Decode the latent tensor to images and save them as a video
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        reconstructed = tokenizer.decode_tokens(output_ids, latent_shape)
    reconstructed_images = to_np_images(reconstructed)
    save_images_to_mp4(reconstructed_images, output_dir / "generated.mp4", fps=10)
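
For reference, a typical invocation from the repository root might look like the following (all flags are optional; the values shown here are the defaults or limits defined in the argparse block above, and the asset layout is assumed from the paths hard-coded at the top of the script):

    python inference.py --cmd curving_to_left/curving_to_left_moderate --num_frames 50 --output_dir ./outputs/demo

The generated clip is written to `<output_dir>/generated.mp4`; when `--output_dir` is omitted, it defaults to `./outputs/<cmd>`.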