xunsong.li commited on
Commit
ae534d1
·
1 Parent(s): 7ccc423

add demo for hf space

Browse files
Files changed (2) hide show
  1. app.py +10 -0
  2. local_app.py +263 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import torch
8
  from diffusers import AutoencoderKL, DDIMScheduler
9
  from einops import repeat
 
10
  from omegaconf import OmegaConf
11
  from PIL import Image
12
  from torchvision import transforms
@@ -18,6 +19,15 @@ from src.models.unet_3d import UNet3DConditionModel
18
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
19
  from src.utils.util import get_fps, read_frames, save_videos_grid
20
 
 
 
 
 
 
 
 
 
 
21
 
22
  class AnimateController:
23
  def __init__(
 
7
  import torch
8
  from diffusers import AutoencoderKL, DDIMScheduler
9
  from einops import repeat
10
+ from huggingface_hub import snapshot_download
11
  from omegaconf import OmegaConf
12
  from PIL import Image
13
  from torchvision import transforms
 
19
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
20
  from src.utils.util import get_fps, read_frames, save_videos_grid
21
 
22
+ snapshot_download(
23
+ repo_id="runwayml/stable-diffusion-v1-5",
24
+ local_dir="./pretrained_weights/stable-diffusion-v1-5",
25
+ )
26
+ snapshot_download(
27
+ repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse"
28
+ )
29
+ snapshot_download(repo_id="patrolli/AnimateAnyone", local_dir="./pretrained_weights")
30
+
31
 
32
  class AnimateController:
33
  def __init__(
local_app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from datetime import datetime
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import AutoencoderKL, DDIMScheduler
9
+ from einops import repeat
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from transformers import CLIPVisionModelWithProjection
14
+
15
+ from src.models.pose_guider import PoseGuider
16
+ from src.models.unet_2d_condition import UNet2DConditionModel
17
+ from src.models.unet_3d import UNet3DConditionModel
18
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
19
+ from src.utils.util import get_fps, read_frames, save_videos_grid
20
+
21
+
22
+ class AnimateController:
23
+ def __init__(
24
+ self,
25
+ config_path="./configs/prompts/animation.yaml",
26
+ weight_dtype=torch.float16,
27
+ ):
28
+ # Read pretrained weights path from config
29
+ self.config = OmegaConf.load(config_path)
30
+ self.pipeline = None
31
+ self.weight_dtype = weight_dtype
32
+
33
+ def animate(
34
+ self,
35
+ ref_image,
36
+ pose_video_path,
37
+ width=512,
38
+ height=768,
39
+ length=24,
40
+ num_inference_steps=25,
41
+ cfg=3.5,
42
+ seed=123,
43
+ ):
44
+ generator = torch.manual_seed(seed)
45
+ if isinstance(ref_image, np.ndarray):
46
+ ref_image = Image.fromarray(ref_image)
47
+ if self.pipeline is None:
48
+ vae = AutoencoderKL.from_pretrained(
49
+ self.config.pretrained_vae_path,
50
+ ).to("cuda", dtype=self.weight_dtype)
51
+
52
+ reference_unet = UNet2DConditionModel.from_pretrained(
53
+ self.config.pretrained_base_model_path,
54
+ subfolder="unet",
55
+ ).to(dtype=self.weight_dtype, device="cuda")
56
+
57
+ inference_config_path = self.config.inference_config
58
+ infer_config = OmegaConf.load(inference_config_path)
59
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
60
+ self.config.pretrained_base_model_path,
61
+ self.config.motion_module_path,
62
+ subfolder="unet",
63
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
64
+ ).to(dtype=self.weight_dtype, device="cuda")
65
+
66
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
67
+ dtype=self.weight_dtype, device="cuda"
68
+ )
69
+
70
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
71
+ self.config.image_encoder_path
72
+ ).to(dtype=self.weight_dtype, device="cuda")
73
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
74
+ scheduler = DDIMScheduler(**sched_kwargs)
75
+
76
+ # load pretrained weights
77
+ denoising_unet.load_state_dict(
78
+ torch.load(self.config.denoising_unet_path, map_location="cpu"),
79
+ strict=False,
80
+ )
81
+ reference_unet.load_state_dict(
82
+ torch.load(self.config.reference_unet_path, map_location="cpu"),
83
+ )
84
+ pose_guider.load_state_dict(
85
+ torch.load(self.config.pose_guider_path, map_location="cpu"),
86
+ )
87
+
88
+ pipe = Pose2VideoPipeline(
89
+ vae=vae,
90
+ image_encoder=image_enc,
91
+ reference_unet=reference_unet,
92
+ denoising_unet=denoising_unet,
93
+ pose_guider=pose_guider,
94
+ scheduler=scheduler,
95
+ )
96
+ pipe = pipe.to("cuda", dtype=self.weight_dtype)
97
+ self.pipeline = pipe
98
+
99
+ pose_images = read_frames(pose_video_path)
100
+ src_fps = get_fps(pose_video_path)
101
+
102
+ pose_list = []
103
+ pose_tensor_list = []
104
+ pose_transform = transforms.Compose(
105
+ [transforms.Resize((height, width)), transforms.ToTensor()]
106
+ )
107
+ for pose_image_pil in pose_images[:length]:
108
+ pose_list.append(pose_image_pil)
109
+ pose_tensor_list.append(pose_transform(pose_image_pil))
110
+
111
+ video = self.pipeline(
112
+ ref_image,
113
+ pose_list,
114
+ width=width,
115
+ height=height,
116
+ video_length=length,
117
+ num_inference_steps=num_inference_steps,
118
+ guidance_scale=cfg,
119
+ generator=generator,
120
+ ).videos
121
+
122
+ ref_image_tensor = pose_transform(ref_image) # (c, h, w)
123
+ ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
124
+ ref_image_tensor = repeat(
125
+ ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
126
+ )
127
+ pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
128
+ pose_tensor = pose_tensor.transpose(0, 1)
129
+ pose_tensor = pose_tensor.unsqueeze(0)
130
+ video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
131
+
132
+ save_dir = f"./output/gradio"
133
+ if not os.path.exists(save_dir):
134
+ os.makedirs(save_dir, exist_ok=True)
135
+ date_str = datetime.now().strftime("%Y%m%d")
136
+ time_str = datetime.now().strftime("%H%M")
137
+ out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
138
+ save_videos_grid(
139
+ video,
140
+ out_path,
141
+ n_rows=3,
142
+ fps=src_fps,
143
+ )
144
+
145
+ torch.cuda.empty_cache()
146
+
147
+ return out_path
148
+
149
+
150
+ controller = AnimateController()
151
+
152
+
153
+ def ui():
154
+ with gr.Blocks() as demo:
155
+ gr.Markdown(
156
+ """
157
+ # Moore-AnimateAnyone Demo
158
+ """
159
+ )
160
+ animation = gr.Video(
161
+ format="mp4",
162
+ label="Animation Results",
163
+ height=448,
164
+ autoplay=True,
165
+ )
166
+
167
+ with gr.Row():
168
+ reference_image = gr.Image(label="Reference Image")
169
+ motion_sequence = gr.Video(
170
+ format="mp4", label="Motion Sequence", height=512
171
+ )
172
+
173
+ with gr.Column():
174
+ width_slider = gr.Slider(
175
+ label="Width", minimum=448, maximum=768, value=512, step=64
176
+ )
177
+ height_slider = gr.Slider(
178
+ label="Height", minimum=512, maximum=1024, value=768, step=64
179
+ )
180
+ length_slider = gr.Slider(
181
+ label="Video Length", minimum=24, maximum=128, value=24, step=24
182
+ )
183
+ with gr.Row():
184
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
185
+ seed_button = gr.Button(
186
+ value="\U0001F3B2", elem_classes="toolbutton"
187
+ )
188
+ seed_button.click(
189
+ fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)),
190
+ inputs=[],
191
+ outputs=[seed_textbox],
192
+ )
193
+ with gr.Row():
194
+ sampling_steps = gr.Slider(
195
+ label="Sampling steps",
196
+ value=25,
197
+ info="default: 25",
198
+ step=5,
199
+ maximum=30,
200
+ minimum=10,
201
+ )
202
+ guidance_scale = gr.Slider(
203
+ label="Guidance scale",
204
+ value=3.5,
205
+ info="default: 3.5",
206
+ step=0.5,
207
+ maximum=10,
208
+ minimum=2.0,
209
+ )
210
+ submit = gr.Button("Animate")
211
+
212
+ def read_video(video):
213
+ return video
214
+
215
+ def read_image(image):
216
+ return Image.fromarray(image)
217
+
218
+ # when user uploads a new video
219
+ motion_sequence.upload(read_video, motion_sequence, motion_sequence)
220
+ # when `first_frame` is updated
221
+ reference_image.upload(read_image, reference_image, reference_image)
222
+ # when the `submit` button is clicked
223
+ submit.click(
224
+ controller.animate,
225
+ [
226
+ reference_image,
227
+ motion_sequence,
228
+ width_slider,
229
+ height_slider,
230
+ length_slider,
231
+ sampling_steps,
232
+ guidance_scale,
233
+ seed_textbox,
234
+ ],
235
+ animation,
236
+ )
237
+
238
+ # Examples
239
+ gr.Markdown("## Examples")
240
+ gr.Examples(
241
+ examples=[
242
+ [
243
+ "./configs/inference/ref_images/anyone-5.png",
244
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
245
+ ],
246
+ [
247
+ "./configs/inference/ref_images/anyone-10.png",
248
+ "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
249
+ ],
250
+ [
251
+ "./configs/inference/ref_images/anyone-2.png",
252
+ "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
253
+ ],
254
+ ],
255
+ inputs=[reference_image, motion_sequence],
256
+ outputs=animation,
257
+ )
258
+
259
+ return demo
260
+
261
+
262
+ demo = ui()
263
+ demo.launch(share=True)