multimodalart HF staff commited on
Commit
aeacc98
·
1 Parent(s): a3595a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -106
app.py CHANGED
@@ -75,116 +75,137 @@ def sample(
75
  ):
76
  output_folder = str(uuid.uuid4())
77
  torch.manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- all_img_paths = [image]
80
  for input_img_path in all_img_paths:
81
- if image.mode == "RGBA":
82
- image = image.convert("RGB")
83
- w, h = image.size
84
-
85
- if h % 64 != 0 or w % 64 != 0:
86
- width, height = map(lambda x: x - x % 64, (w, h))
87
- image = image.resize((width, height))
88
- print(
89
- f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
90
- )
91
-
92
- image = ToTensor()(image)
93
- image = image * 2.0 - 1.0
94
-
95
- image = image.unsqueeze(0).to(device)
96
- H, W = image.shape[2:]
97
- assert image.shape[1] == 3
98
- F = 8
99
- C = 4
100
- shape = (num_frames, C, H // F, W // F)
101
- if (H, W) != (576, 1024):
102
- print(
103
- "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
104
- )
105
- if motion_bucket_id > 255:
106
- print(
107
- "WARNING: High motion bucket! This may lead to suboptimal performance."
108
- )
109
-
110
- if fps_id < 5:
111
- print("WARNING: Small fps value! This may lead to suboptimal performance.")
112
-
113
- if fps_id > 30:
114
- print("WARNING: Large fps value! This may lead to suboptimal performance.")
115
-
116
- value_dict = {}
117
- value_dict["motion_bucket_id"] = motion_bucket_id
118
- value_dict["fps_id"] = fps_id
119
- value_dict["cond_aug"] = cond_aug
120
- value_dict["cond_frames_without_noise"] = image
121
- value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
122
- value_dict["cond_aug"] = cond_aug
123
-
124
- with torch.no_grad():
125
- with torch.autocast(device):
126
- batch, batch_uc = get_batch(
127
- get_unique_embedder_keys_from_conditioner(model.conditioner),
128
- value_dict,
129
- [1, num_frames],
130
- T=num_frames,
131
- device=device,
132
  )
133
- c, uc = model.conditioner.get_unconditional_conditioning(
134
- batch,
135
- batch_uc=batch_uc,
136
- force_uc_zero_embeddings=[
137
- "cond_frames",
138
- "cond_frames_without_noise",
139
- ],
 
 
 
 
 
 
140
  )
141
-
142
- for k in ["crossattn", "concat"]:
143
- uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
144
- uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
145
- c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
146
- c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
147
-
148
- randn = torch.randn(shape, device=device)
149
-
150
- additional_model_inputs = {}
151
- additional_model_inputs["image_only_indicator"] = torch.zeros(
152
- 2, num_frames
153
- ).to(device)
154
- additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
155
-
156
- def denoiser(input, sigma, c):
157
- return model.denoiser(
158
- model.model, input, sigma, c, **additional_model_inputs
159
- )
160
-
161
- samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
162
- model.en_and_decode_n_samples_a_time = decoding_t
163
- samples_x = model.decode_first_stage(samples_z)
164
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
165
-
166
- os.makedirs(output_folder, exist_ok=True)
167
- base_count = len(glob(os.path.join(output_folder, "*.mp4")))
168
- video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
169
- writer = cv2.VideoWriter(
170
- video_path,
171
- cv2.VideoWriter_fourcc(*'mp4v'),
172
- fps_id + 1,
173
- (samples.shape[-1], samples.shape[-2]),
174
  )
175
-
176
- samples = embed_watermark(samples)
177
- samples = filter(samples)
178
- vid = (
179
- (rearrange(samples, "t c h w -> t h w c") * 255)
180
- .cpu()
181
- .numpy()
182
- .astype(np.uint8)
183
- )
184
- for frame in vid:
185
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
186
- writer.write(frame)
187
- writer.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  return video_path
189
 
190
  def get_unique_embedder_keys_from_conditioner(conditioner):
@@ -272,7 +293,7 @@ with gr.Blocks(css=css) as demo:
272
  Generate 25 frames of video from a single image with SDV-XT. [Join the waitlist](https://stability.ai/contact) for the text-to-video web experience
273
  ''')
274
  with gr.Column():
275
- image = gr.Image(label="Upload your image (it will be center cropped to 1024x576)", type="pil")
276
  generate_btn = gr.Button("Generate")
277
  #with gr.Accordion("Advanced options", open=False):
278
  # cond_aug = gr.Slider(label="Conditioning augmentation", value=0.02, minimum=0.0)
 
75
  ):
76
  output_folder = str(uuid.uuid4())
77
  torch.manual_seed(seed)
78
+ path = Path(input_path)
79
+ all_img_paths = []
80
+ if path.is_file():
81
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
82
+ all_img_paths = [input_path]
83
+ else:
84
+ raise ValueError("Path is not valid image file.")
85
+ elif path.is_dir():
86
+ all_img_paths = sorted(
87
+ [
88
+ f
89
+ for f in path.iterdir()
90
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
91
+ ]
92
+ )
93
+ if len(all_img_paths) == 0:
94
+ raise ValueError("Folder does not contain any images.")
95
+ else:
96
+ raise ValueError
97
 
 
98
  for input_img_path in all_img_paths:
99
+ with Image.open(input_img_path) as image:
100
+ if image.mode == "RGBA":
101
+ image = image.convert("RGB")
102
+ w, h = image.size
103
+
104
+ if h % 64 != 0 or w % 64 != 0:
105
+ width, height = map(lambda x: x - x % 64, (w, h))
106
+ image = image.resize((width, height))
107
+ print(
108
+ f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
110
+
111
+ image = ToTensor()(image)
112
+ image = image * 2.0 - 1.0
113
+
114
+ image = image.unsqueeze(0).to(device)
115
+ H, W = image.shape[2:]
116
+ assert image.shape[1] == 3
117
+ F = 8
118
+ C = 4
119
+ shape = (num_frames, C, H // F, W // F)
120
+ if (H, W) != (576, 1024):
121
+ print(
122
+ "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
123
  )
124
+ if motion_bucket_id > 255:
125
+ print(
126
+ "WARNING: High motion bucket! This may lead to suboptimal performance."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
+
129
+ if fps_id < 5:
130
+ print("WARNING: Small fps value! This may lead to suboptimal performance.")
131
+
132
+ if fps_id > 30:
133
+ print("WARNING: Large fps value! This may lead to suboptimal performance.")
134
+
135
+ value_dict = {}
136
+ value_dict["motion_bucket_id"] = motion_bucket_id
137
+ value_dict["fps_id"] = fps_id
138
+ value_dict["cond_aug"] = cond_aug
139
+ value_dict["cond_frames_without_noise"] = image
140
+ value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
141
+ value_dict["cond_aug"] = cond_aug
142
+
143
+
144
+
145
+ with torch.no_grad():
146
+ with torch.autocast(device):
147
+ batch, batch_uc = get_batch(
148
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
149
+ value_dict,
150
+ [1, num_frames],
151
+ T=num_frames,
152
+ device=device,
153
+ )
154
+ c, uc = model.conditioner.get_unconditional_conditioning(
155
+ batch,
156
+ batch_uc=batch_uc,
157
+ force_uc_zero_embeddings=[
158
+ "cond_frames",
159
+ "cond_frames_without_noise",
160
+ ],
161
+ )
162
+
163
+ for k in ["crossattn", "concat"]:
164
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
165
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
166
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
167
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
168
+
169
+ randn = torch.randn(shape, device=device)
170
+
171
+ additional_model_inputs = {}
172
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
173
+ 2, num_frames
174
+ ).to(device)
175
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
176
+
177
+ def denoiser(input, sigma, c):
178
+ return model.denoiser(
179
+ model.model, input, sigma, c, **additional_model_inputs
180
+ )
181
+
182
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
183
+ model.en_and_decode_n_samples_a_time = decoding_t
184
+ samples_x = model.decode_first_stage(samples_z)
185
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
186
+
187
+ os.makedirs(output_folder, exist_ok=True)
188
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
189
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
190
+ writer = cv2.VideoWriter(
191
+ video_path,
192
+ cv2.VideoWriter_fourcc(*"MP4V"),
193
+ fps_id + 1,
194
+ (samples.shape[-1], samples.shape[-2]),
195
+ )
196
+
197
+ samples = embed_watermark(samples)
198
+ samples = filter(samples)
199
+ vid = (
200
+ (rearrange(samples, "t c h w -> t h w c") * 255)
201
+ .cpu()
202
+ .numpy()
203
+ .astype(np.uint8)
204
+ )
205
+ for frame in vid:
206
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
207
+ writer.write(frame)
208
+ writer.release()
209
  return video_path
210
 
211
  def get_unique_embedder_keys_from_conditioner(conditioner):
 
293
  Generate 25 frames of video from a single image with SDV-XT. [Join the waitlist](https://stability.ai/contact) for the text-to-video web experience
294
  ''')
295
  with gr.Column():
296
+ image = gr.Image(label="Upload your image (it will be center cropped to 1024x576)", type="filepath")
297
  generate_btn = gr.Button("Generate")
298
  #with gr.Accordion("Advanced options", open=False):
299
  # cond_aug = gr.Slider(label="Conditioning augmentation", value=0.02, minimum=0.0)