pablovela5620 committed on
Commit
9bf332b
1 Parent(s): e37e4d1

Upload gradio_app.py with huggingface_hub

Files changed (1)
  1. gradio_app.py +331 -0
gradio_app.py ADDED
@@ -0,0 +1,331 @@
+ import PIL
+ import PIL.Image
+ from PIL.Image import Image
+
+ from src.rr_logging_utils import (
+     log_camera,
+     create_svd_blueprint,
+ )
+ from src.pose_utils import generate_camera_parameters
+ from src.camera_parameters import PinholeParameters
+ from src.depth_utils import image_to_depth
+ from src.image_warping import image_depth_warping
+ from src.sigma_utils import load_lambda_ts
+ from src.nerfstudio_data import frames_to_nerfstudio
+
+ import gradio as gr
+ from gradio_rerun import Rerun
+
+ import rerun as rr
+ import rerun.blueprint as rrb
+
+ import numpy as np
+ import torch
+ from pathlib import Path
+ import threading
+ from queue import SimpleQueue
+ import trimesh
+ import subprocess
+
+ import mmcv
+ from uuid import uuid4
+
+ from typing import Final, Literal
+
+ from jaxtyping import Float64, Float32, UInt8
+
+ from monopriors.relative_depth_models import (
+     get_relative_predictor,
+ )
+
+ from src.custom_diffusers_pipeline.svd import StableVideoDiffusionPipeline
+ from src.custom_diffusers_pipeline.scheduler import EulerDiscreteScheduler
+
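+ # `spaces` is only importable on Hugging Face ZeroGPU hardware; fall back to
+ # plain CUDA everywhere else.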
+ try:
+     import spaces  # type: ignore
+
+     IN_SPACES = True
+ except ImportError:
+     print("Not running on Zero")
+     IN_SPACES = False
+
+
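+ # SVD img2vid-xt natively works at 1024x576; NEAR/FAR bound the depth range
+ # passed to image_to_depth below.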
+ SVD_HEIGHT: Final[int] = 576
+ SVD_WIDTH: Final[int] = 1024
+ NEAR: Final[float] = 0.0001
+ FAR: Final[float] = 500.0
+
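+ # gr.NO_RELOAD: don't re-instantiate the depth predictor and SVD pipeline on
+ # `gradio` dev-mode hot reloads.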
+ if gr.NO_RELOAD:
+     DepthAnythingV2Predictor = get_relative_predictor("DepthAnythingV2Predictor")(
+         device="cuda"
+     )
+     SVD_PIPE = StableVideoDiffusionPipeline.from_pretrained(
+         "stabilityai/stable-video-diffusion-img2vid-xt",
+         torch_dtype=torch.float16,
+         variant="fp16",
+     )
+     SVD_PIPE.to("cuda")
+     scheduler = EulerDiscreteScheduler.from_config(SVD_PIPE.scheduler.config)
+     SVD_PIPE.scheduler = scheduler
+
+
+ def svd_render_threaded(
+     image_o: PIL.Image.Image,
+     masks: Float64[torch.Tensor, "b 72 128"],
+     cond_image: list[PIL.Image.Image],
+     lambda_ts: Float64[torch.Tensor, "n b"],
+     num_denoise_iters: Literal[2, 25, 50, 100],
+     weight_clamp: float,
+     log_queue: SimpleQueue,
+ ):
+     frames: list[PIL.Image.Image] = SVD_PIPE(
+         [image_o],
+         log_queue=log_queue,
+         temp_cond=cond_image,
+         mask=masks,
+         lambda_ts=lambda_ts,
+         weight_clamp=weight_clamp,
+         num_frames=25,
+         decode_chunk_size=8,
+         num_inference_steps=num_denoise_iters,
+     ).frames[0]
+
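+     # hand the decoded frames back to the main thread; their arrival on the
+     # queue signals that denoising is complete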
+     log_queue.put(frames)
+
+
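+ # On ZeroGPU, spaces.GPU schedules the wrapped function onto a GPU worker.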
+ if IN_SPACES:
+     svd_render_threaded = spaces.GPU(svd_render_threaded)
+
+
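+ # Generator behind the "Warp Images" button: warps the input image along the
+ # camera trajectory, denoises with SVD, and yields
+ # (rerun stream, video, file list, status) tuples so the UI updates live.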
+ @rr.thread_local_stream("warped_image")
+ def gradio_warped_image(
+     image_path: str,
+     num_denoise_iters: Literal[2, 25, 50, 100],
+     direction: Literal["left", "right"],
+     degrees_per_frame: int | float,
+     major_radius: float = 60.0,
+     minor_radius: float = 70.0,
+     num_frames: int = 25,  # Stable Video Diffusion generates 25 frames
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     # ensure that the degrees per frame is a float
+     degrees_per_frame = float(degrees_per_frame)
+
+     image_path: Path = Path(image_path) if isinstance(image_path, str) else image_path
+     assert image_path.exists(), f"Image file not found: {image_path}"
+     save_path: Path = image_path.parent / f"{image_path.stem}_{uuid4()}"
+
+     # setup rerun logging
+     stream = rr.binary_stream()
+     parent_log_path = Path("world")
+     rr.log(f"{parent_log_path}", rr.ViewCoordinates.LDB, static=True)
+     blueprint: rrb.Blueprint = create_svd_blueprint(parent_log_path)
+     rr.send_blueprint(blueprint)
+
+     # Load image and resize to SVD dimensions
+     rgb_original: Image = PIL.Image.open(image_path)
+     rgb_resized: Image = rgb_original.resize(
+         (SVD_WIDTH, SVD_HEIGHT), PIL.Image.Resampling.NEAREST
+     )
+     rgb_np_original: UInt8[np.ndarray, "h w 3"] = np.array(rgb_original)
+     rgb_np_hw3: UInt8[np.ndarray, "h w 3"] = np.array(rgb_resized)
+
+     # generate initial camera parameters for video trajectory
+     camera_list: list[PinholeParameters] = generate_camera_parameters(
+         num_frames=num_frames,
+         image_width=SVD_WIDTH,
+         image_height=SVD_HEIGHT,
+         degrees_per_frame=degrees_per_frame,
+         major_radius=major_radius,
+         minor_radius=minor_radius,
+         direction=direction,
+     )
+
+     assert len(camera_list) == num_frames, "Number of camera parameters mismatch"
+
+     # Estimate depth map and pointcloud for the input image
+     depth: Float32[np.ndarray, "h w"]
+     trimesh_pc: trimesh.PointCloud
+     depth_original: Float32[np.ndarray, "original_h original_w"]
+     trimesh_pc_original: trimesh.PointCloud
+
+     depth, trimesh_pc, depth_original, trimesh_pc_original = image_to_depth(
+         rgb_np_original=rgb_np_original,
+         rgb_np_hw3=rgb_np_hw3,
+         cam_params=camera_list[0],
+         near=NEAR,
+         far=FAR,
+         depth_predictor=DepthAnythingV2Predictor,
+     )
+
+     rr.log(
+         f"{parent_log_path}/point_cloud",
+         rr.Points3D(
+             positions=trimesh_pc.vertices,
+             colors=trimesh_pc.colors,
+         ),
+         static=True,
+     )
+
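+     # cond_image/masks accumulate the warped views and their validity masks,
+     # which condition the SVD denoising pass below (temp_cond / mask).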
+     start_cam: PinholeParameters = camera_list[0]
+     cond_image: list[PIL.Image.Image] = []
+     masks: list[Float64[torch.Tensor, "1 72 128"]] = []
+
+     # Perform image depth warping to the generated camera parameters
+     cam_log_path: Path = parent_log_path / "warped_camera"
+     current_cam: PinholeParameters
+     for frame_id, current_cam in enumerate(camera_list):
+         rr.set_time_sequence("frame_id", frame_id)
+         if frame_id == 0:
+             log_camera(cam_log_path, current_cam, rgb_np_hw3, depth)
+         else:
+             # clear logged depth from the previous frame
+             rr.log(f"{cam_log_path}/pinhole/depth", rr.Clear(recursive=False))
+             # do image warping
+             warped_frame2, mask_erosion_tensor = image_depth_warping(
+                 image=rgb_np_hw3,
+                 depth=depth,
+                 cam_T_world_44_s=start_cam.extrinsics.cam_T_world,
+                 cam_T_world_44_t=current_cam.extrinsics.cam_T_world,
+                 K=current_cam.intrinsics.k_matrix,
+             )
+             cond_image.append(warped_frame2)
+             masks.append(mask_erosion_tensor)
+             log_camera(cam_log_path, current_cam, np.asarray(warped_frame2))
+         yield stream.read(), None, [], ""
+
+     masks: Float64[torch.Tensor, "b 72 128"] = torch.cat(masks)
+     # load sigmas to optimize for timestep
+     progress(0.1, desc="Optimizing timesteps for diffusion")
+     lambda_ts: Float64[torch.Tensor, "n b"] = load_lambda_ts(num_denoise_iters)
+     progress(0.15, desc="Starting diffusion")
+
+     # to allow logging from a separate thread
+     log_queue: SimpleQueue = SimpleQueue()
+     handle = threading.Thread(
+         target=svd_render_threaded,
+         kwargs={
+             "image_o": rgb_resized,
+             "masks": masks,
+             "cond_image": cond_image,
+             "lambda_ts": lambda_ts,
+             "num_denoise_iters": num_denoise_iters,
+             "weight_clamp": 0.2,
+             "log_queue": log_queue,
+         },
+     )
+
+     handle.start()
+     i = 0
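+     # Drain the queue: forward rerun log calls coming from the worker thread to
+     # the viewer until the decoded frames arrive, which ends the loop.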
+     while True:
+         msg = log_queue.get()
+         match msg:
+             case frames if all(isinstance(frame, PIL.Image.Image) for frame in frames):
+                 break
+             case entity_path, entity, times:
+                 i += 1
+                 rr.reset_time()
+                 for timeline, time in times:
+                     if isinstance(time, int):
+                         rr.set_time_sequence(timeline, time)
+                     else:
+                         rr.set_time_seconds(timeline, time)
+                 static = False
+                 if entity_path == "diffusion_step":
+                     static = True
+                 rr.log(entity_path, entity, static=static)
+                 yield stream.read(), None, [], f"{i} out of {num_denoise_iters}"
+             case _:
+                 raise AssertionError(f"Unexpected message from worker: {msg}")
+     handle.join()
+
+     # log each generated frame alongside its camera on the `frame_id` timeline
+     frame: PIL.Image.Image
+     for frame_id, (frame, cam_params) in enumerate(zip(frames, camera_list)):
+         rr.set_time_sequence("frame_id", frame_id)
+         cam_log_path = parent_log_path / "generated_camera"
+         generated_rgb_np: UInt8[np.ndarray, "h w 3"] = np.array(frame)
+         log_camera(cam_log_path, cam_params, generated_rgb_np, depth=None)
+         yield stream.read(), None, [], "finished"
+
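+     # Export the input, generated frames, and camera poses in nerfstudio format
+     # (images + transforms) for downstream use.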
+     frames_to_nerfstudio(
+         rgb_np_original, frames, trimesh_pc_original, camera_list, save_path
+     )
+     # zip up nerfstudio data
+     zip_file_path = save_path / "nerfstudio.zip"
+     progress(0.95, desc="Zipping up camera data in nerfstudio format")
+     # Run the zip command
+     subprocess.run(["zip", "-r", str(zip_file_path), str(save_path)], check=True)
+     video_file_path = save_path / "output.mp4"
+     mmcv.frames2video(str(save_path), str(video_file_path), fps=7)
+     print(f"Video saved to {video_file_path}")
+     yield stream.read(), video_file_path, [str(zip_file_path)], "finished"
+
+
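+ # UI: image input with Settings/Outputs tabs, a shared Rerun viewer, and the
+ # click wiring for the warp button.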
+ with gr.Blocks() as demo:
+     with gr.Tab("Streaming"):
+         with gr.Row():
+             img = gr.Image(interactive=True, label="Image", type="filepath")
+             with gr.Tab(label="Settings"):
+                 with gr.Column():
+                     warp_img_btn = gr.Button("Warp Images")
+                     num_iters = gr.Radio(
+                         choices=[2, 25, 50, 100],
+                         value=2,
+                         label="Number of iterations",
+                         type="value",
+                     )
+                     cam_direction = gr.Radio(
+                         choices=["left", "right"],
+                         value="left",
+                         label="Camera direction",
+                         type="value",
+                     )
+                     degrees_per_frame = gr.Slider(
+                         minimum=0.25,
+                         maximum=1.0,
+                         step=0.05,
+                         value=0.3,
+                         label="Degrees per frame",
+                     )
+                     iteration_num = gr.Textbox(
+                         value="",
+                         label="Current Diffusion Step",
+                     )
+             with gr.Tab(label="Outputs"):
+                 video_output = gr.Video(interactive=False)
+                 image_files_output = gr.File(interactive=False, file_count="multiple")
+
+         # Rerun 0.16 has issues when embedded in a Gradio tab, so we share a viewer between all the tabs.
+         # In 0.17 we can instead scope each viewer to its own tab to clean up these examples further.
+         with gr.Row():
+             viewer = Rerun(
+                 streaming=True,
+             )
+
+     warp_img_btn.click(
+         gradio_warped_image,
+         inputs=[img, num_iters, cam_direction, degrees_per_frame],
+         outputs=[viewer, video_output, image_files_output, iteration_num],
+     )
+
+     gr.Examples(
+         [
+             [
+                 "/home/pablo/0Dev/docker/.per/repos/NVS_Solver/example_imgs/single/000001.jpg",
+             ],
+         ],
+         fn=gradio_warped_image,
+         inputs=[img, num_iters, cam_direction, degrees_per_frame],
+         outputs=[viewer, video_output, image_files_output, iteration_num],
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch()