dongzerui000 committed
Commit 800f395 · 1 Parent(s): 350a724

Update apply.py

Files changed (1)
  1. apply.py +296 -0
apply.py CHANGED
@@ -0,0 +1,296 @@
+ %cd /content
+ !git clone -b dev https://github.com/camenduru/generative-models
+ !pip install -q -r https://github.com/camenduru/stable-video-diffusion-colab/raw/main/requirements.txt
+ !pip install -q -e generative-models
+ !pip install -q -e git+https://github.com/Stability-AI/datapipelines@main#egg=sdata
+
+ !apt -y install -qq aria2
+ !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/vdo/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors?download=true -d /content/checkpoints -o svd_xt.safetensors
+
+ !mkdir -p /content/scripts/util/detection
+ !ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz
+ !ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz
+
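+ # Make the cloned generative-models repo importable, then pull in the sampling dependencies.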
+ import sys
+ sys.path.append("generative-models")
+
+ import os, math, torch, cv2
+ from omegaconf import OmegaConf
+ from glob import glob
+ from pathlib import Path
+ from typing import Optional
+ import numpy as np
+ from einops import rearrange, repeat
+
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+ from torchvision.transforms import functional as TF
+ from sgm.util import instantiate_from_config
+
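+ # Build the SVD-XT model from its YAML config, overriding the sampler step count and the guider's frame count.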
+ def load_model(config: str, device: str, num_frames: int, num_steps: int):
+     config = OmegaConf.load(config)
+     config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = device
+     config.model.params.sampler_config.params.num_steps = num_steps
+     config.model.params.sampler_config.params.guider_config.params.num_frames = num_frames
+     with torch.device(device):
+         model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)
+     return model
+
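+ # Instantiate the model once at startup; keep the conditioner and first-stage autoencoder on CPU and cast the UNet to fp16 to reduce VRAM use.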
+ num_frames = 25
+ num_steps = 30
+ model_config = "generative-models/scripts/sampling/configs/svd_xt.yaml"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = load_model(model_config, device, num_frames, num_steps)
+ model.conditioner.cpu()
+ model.first_stage_model.cpu()
+ model.model.to(dtype=torch.float16)
+ torch.cuda.empty_cache()
+ model = model.requires_grad_(False)
+
+ def get_unique_embedder_keys_from_conditioner(conditioner):
+     return list(set([x.input_key for x in conditioner.embedders]))
+
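+ # Assemble the conditioning batch (plus an unconditional copy) keyed the way the SGM conditioner expects.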
+ def get_batch(keys, value_dict, N, T, device, dtype=None):
+     batch = {}
+     batch_uc = {}
+     for key in keys:
+         if key == "fps_id":
+             batch[key] = (
+                 torch.tensor([value_dict["fps_id"]])
+                 .to(device, dtype=dtype)
+                 .repeat(int(math.prod(N)))
+             )
+         elif key == "motion_bucket_id":
+             batch[key] = (
+                 torch.tensor([value_dict["motion_bucket_id"]])
+                 .to(device, dtype=dtype)
+                 .repeat(int(math.prod(N)))
+             )
+         elif key == "cond_aug":
+             batch[key] = repeat(
+                 torch.tensor([value_dict["cond_aug"]]).to(device, dtype=dtype),
+                 "1 -> b",
+                 b=math.prod(N),
+             )
+         elif key == "cond_frames":
+             batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+         elif key == "cond_frames_without_noise":
+             batch[key] = repeat(
+                 value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+             )
+         else:
+             batch[key] = value_dict[key]
+     if T is not None:
+         batch["num_video_frames"] = T
+     for key in batch.keys():
+         if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+             batch_uc[key] = torch.clone(batch[key])
+     return batch, batch_uc
+
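+ # Full image-to-video pipeline: preprocess the conditioning image, run the diffusion sampler, decode the latents, and write each result to an .mp4.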
+ def sample(
+     input_path: str = "/content/test_image.png",
+     resize_image: bool = False,
+     num_frames: Optional[int] = None,
+     num_steps: Optional[int] = None,
+     fps_id: int = 6,
+     motion_bucket_id: int = 127,
+     cond_aug: float = 0.02,
+     seed: int = 23,
+     decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+     device: str = "cuda",
+     output_folder: Optional[str] = "/content/outputs",
+ ):
+     """
+     Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
+     image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
+     """
+     torch.manual_seed(seed)
+
+     path = Path(input_path)
+     all_img_paths = []
+     if path.is_file():
+         if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
+             all_img_paths = [input_path]
+         else:
+             raise ValueError("Path is not valid image file.")
+     elif path.is_dir():
+         all_img_paths = sorted(
+             [
+                 f
+                 for f in path.iterdir()
+                 if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
+             ]
+         )
+         if len(all_img_paths) == 0:
+             raise ValueError("Folder does not contain any images.")
+     else:
+         raise ValueError
+     all_out_paths = []
+     for input_img_path in all_img_paths:
+         with Image.open(input_img_path) as image:
+             if image.mode == "RGBA":
+                 image = image.convert("RGB")
+             if resize_image and image.size != (1024, 576):
+                 print(f"Resizing {image.size} to (1024, 576)")
+                 image = TF.resize(TF.resize(image, 1024), (576, 1024))
+             w, h = image.size
+             if h % 64 != 0 or w % 64 != 0:
+                 width, height = map(lambda x: x - x % 64, (w, h))
+                 image = image.resize((width, height))
+                 print(
+                     f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+                 )
+             image = ToTensor()(image)
+             image = image * 2.0 - 1.0
+
+         image = image.unsqueeze(0).to(device)
+         H, W = image.shape[2:]
+         assert image.shape[1] == 3
+         F = 8
+         C = 4
+         shape = (num_frames, C, H // F, W // F)
+         if (H, W) != (576, 1024):
+             print(
+                 "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
+             )
+         if motion_bucket_id > 255:
+             print(
+                 "WARNING: High motion bucket! This may lead to suboptimal performance."
+             )
+         if fps_id < 5:
+             print("WARNING: Small fps value! This may lead to suboptimal performance.")
+         if fps_id > 30:
+             print("WARNING: Large fps value! This may lead to suboptimal performance.")
+
+         value_dict = {}
+         value_dict["motion_bucket_id"] = motion_bucket_id
+         value_dict["fps_id"] = fps_id
+         value_dict["cond_aug"] = cond_aug
+         value_dict["cond_frames_without_noise"] = image
+         value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
+         value_dict["cond_aug"] = cond_aug
+         # low vram mode
+         model.conditioner.cpu()
+         model.first_stage_model.cpu()
+         torch.cuda.empty_cache()
+         model.sampler.verbose = True
+
+         with torch.no_grad():
+             with torch.autocast(device):
+                 model.conditioner.to(device)
+                 batch, batch_uc = get_batch(
+                     get_unique_embedder_keys_from_conditioner(model.conditioner),
+                     value_dict,
+                     [1, num_frames],
+                     T=num_frames,
+                     device=device,
+                 )
+                 c, uc = model.conditioner.get_unconditional_conditioning(
+                     batch,
+                     batch_uc=batch_uc,
+                     force_uc_zero_embeddings=[
+                         "cond_frames",
+                         "cond_frames_without_noise",
+                     ],
+                 )
+                 model.conditioner.cpu()
+                 torch.cuda.empty_cache()
+
+                 # from here, dtype is fp16
+                 for k in ["crossattn", "concat"]:
+                     uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                     uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                     c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                     c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+                 for k in uc.keys():
+                     uc[k] = uc[k].to(dtype=torch.float16)
+                     c[k] = c[k].to(dtype=torch.float16)
+
+                 randn = torch.randn(shape, device=device, dtype=torch.float16)
+                 additional_model_inputs = {}
+                 additional_model_inputs["image_only_indicator"] = torch.zeros(2, num_frames).to(device)
+                 additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+                 for k in additional_model_inputs:
+                     if isinstance(additional_model_inputs[k], torch.Tensor):
+                         additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)
+
+                 def denoiser(input, sigma, c):
+                     return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
+
+                 samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+                 samples_z = samples_z.to(dtype=model.first_stage_model.dtype)  # Tensor.to() is not in-place; reassign before decoding
+                 model.en_and_decode_n_samples_a_time = decoding_t
+                 model.first_stage_model.to(device)
+                 samples_x = model.decode_first_stage(samples_z)
+                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                 model.first_stage_model.cpu()
+                 torch.cuda.empty_cache()
+
+                 os.makedirs(output_folder, exist_ok=True)
+                 base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+                 video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+                 writer = cv2.VideoWriter(
+                     video_path,
+                     cv2.VideoWriter_fourcc(*"MP4V"),
+                     fps_id + 1,
+                     (samples.shape[-1], samples.shape[-2]),
+                 )
+                 vid = (
+                     (rearrange(samples, "t c h w -> t h w c") * 255)
+                     .cpu()
+                     .numpy()
+                     .astype(np.uint8)
+                 )
+                 for frame in vid:
+                     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                     writer.write(frame)
+                 writer.release()
+                 all_out_paths.append(video_path)
+     return all_out_paths
+
+ import gradio as gr
+ import random
+
+ def url2imge(input_path: str) -> str:
+     return input_path
+
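+ # Gradio callback: resolve the seed, call sample() with the UI settings, and return the first generated video path.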
+ def infer(input_path: str, resize_image: bool, n_frames: int, n_steps: int, seed: str, decoding_t: int) -> str:
+     if seed == "random":
+         seed = random.randint(0, 2**32)
+     seed = int(seed)
+     output_paths = sample(
+         input_path=input_path,
+         resize_image=resize_image,
+         num_frames=n_frames,
+         num_steps=n_steps,
+         fps_id=6,
+         motion_bucket_id=127,
+         cond_aug=0.02,
+         seed=seed,
+         decoding_t=decoding_t,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+         device=device,
+     )
+     return output_paths[0]
+
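+ # Minimal Gradio UI: an image (or image URL) input, a resize toggle, advanced sampling options, and a video output.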
+ with gr.Blocks() as demo:
+     with gr.Column():
+         text = gr.Textbox(label="input image url")
+         btn2 = gr.Button("url to image")
+         image = gr.Image(label="input image", type="filepath")
+         resize_image = gr.Checkbox(label="resize to optimal size", value=True)
+         btn = gr.Button("Run")
+         with gr.Accordion(label="Advanced options", open=False):
+             n_frames = gr.Number(precision=0, label="number of frames", value=num_frames)
+             n_steps = gr.Number(precision=0, label="number of steps", value=num_steps)
+             seed = gr.Text(value="random", label="seed (integer or 'random')")
+             decoding_t = gr.Number(precision=0, label="number of frames decoded at a time", value=2)
+     with gr.Column():
+         video_out = gr.Video(label="generated video")
+     examples = [["https://img.technews.tw/wp-content/uploads/2023/08/17150937/zac-durant-_6HzPU9Hyfg-unsplash-800x533.jpg"]]
+     inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]
+     outputs = [video_out]
+     btn.click(infer, inputs=inputs, outputs=outputs)
+     btn2.click(url2imge, inputs=text, outputs=image)
+     gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)
+ demo.queue().launch(debug=True, share=True, inline=False, show_error=True)