salomonsky committed on
Commit
b8097c8
1 Parent(s): 4952875

Update inference.py

Files changed (1)
  1. inference.py +224 -239
inference.py CHANGED
@@ -1,295 +1,280 @@
  import numpy as np
- import cv2
- import os
- import sys
- import argparse
- import audio
  from tqdm import tqdm
  from glob import glob
- import torch
- import face_detection
  from models import Wav2Lip
  import platform

- try:
-     last_face = cv2.imread("last_face.jpg")
- except Exception as e:
-     last_face = None

- if last_face is None:
-     last_face = cv2.imread("imagen_por_defecto.jpg")

- parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

- parser.add_argument('--checkpoint_path', type=str, required=True)
- parser.add_argument('--face', type=str, required=True)
- parser.add_argument('--audio', type=str, required=True)
- parser.add_argument('--outfile', type=str, default='results/result_voice.mp4')
- parser.add_argument('--static', type=bool, default=False)
- parser.add_argument('--fps', type=float, default=25., required=False)
- parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0])
- parser.add_argument('--face_det_batch_size', type=int, default=16)
- parser.add_argument('--wav2lip_batch_size', type=int, default=128)
- parser.add_argument('--resize_factor', default=1, type=int)
- parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1])
- parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1])
- parser.add_argument('--rotate', default=False, action='store_true')
- parser.add_argument('--nosmooth', default=False, action='store_true')

  args = parser.parse_args()
  args.img_size = 96

  if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-     args.static = True

  def get_smoothened_boxes(boxes, T):
-     for i in range(len(boxes)):
-         if i + T > len(boxes):
-             window = boxes[len(boxes) - T:]
-         else:
-             window = boxes[i : i + T]
-         boxes[i] = np.mean(window, axis=0)
-     return boxes

  def face_detect(images):
-     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
-     batch_size = args.face_det_batch_size
-     last_face = None
-
-     while 1:
-         predictions = []
-         try:
-             for i in tqdm(range(0, len(images), batch_size)):
-                 predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
-         except RuntimeError:
-             if batch_size == 1:
-                 raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
-             batch_size //= 2
-             continue
-         break
-     head_exist = []
-     results = []
-     pady1, pady2, padx1, padx2 = args.pads
-
-     first_head_rect = None
-     first_head_image = None
-     for rect, image in zip(predictions, images):
-         if rect is not None:
-             first_head_rect = rect
-             first_head_image = image
-             break
-     for rect, image in zip(predictions, images):
-         if rect is None:
-             head_exist.append(False)
-             if len(results) == 0:
-                 y1 = max(0, first_head_rect[1] - pady1)
-                 y2 = min(first_head_image.shape[0], first_head_rect[3] + pady2)
-                 x1 = max(0, first_head_rect[0] - padx1)
-                 x2 = min(first_head_image.shape[1], first_head_rect[2] + padx2)
-                 results.append([x1, y1, x2, y2])
-             else:
-                 results.append(results[-1])
-         else:
-             head_exist.append(True)
-             y1 = max(0, rect[1] - pady1)
-             y2 = min(image.shape[0], rect[3] + pady2)
-             x1 = max(0, rect[0] - padx1)
-             x2 = min(image.shape[1], rect[2] + padx2)
-             results.append([x1, y1, x2, y2])
-             last_face = image[y1: y2, x1:x2]
-             cv2.imwrite("last_face.jpg", last_face)
-
-     boxes = np.array(results)
-     if not args.nosmooth:
-         boxes = get_smoothened_boxes(boxes, T=5)
-     results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
-
-     del detector
-     return results, head_exist

  def datagen(frames, mels):
-     img_batch, head_exist_batch, mel_batch, frame_batch, coords_batch = [], [], [], [], []
-
-     if args.box[0] == -1:
-         if not args.static:
-             face_det_results, head_exist = face_detect(frames)
-         else:
-             face_det_results, head_exist = face_detect([frames[0]])
-     else:
-         y1, y2, x1, x2 = args.box
-         face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
-         head_exist = [True] * len(frames)
-
-     for i, m in enumerate(mels):
-         idx = 0 if args.static else i % len(frames)
-         frame_to_save = frames[idx].copy()
-         face, coords = face_det_results[idx].copy()
-
-         face = cv2.resize(face, (args.img_size, args.img_size))
-         head_exist_batch.append(head_exist[idx])
-         img_batch.append(face)
-         melspec = m
-         mel_batch.append(melspec)
-         frame_batch.append(frame_to_save)
-         coords_batch.append(coords)
-
-         if len(img_batch) >= args.wav2lip_batch_size:
-             img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
-
-             img_masked = img_batch.copy()
-             img_masked[:, args.img_size // 2:] = 0
-
-             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-             yield img_batch, head_exist_batch, mel_batch, frame_batch, coords_batch
-             img_batch, head_exist_batch, mel_batch, frame_batch, coords_batch = [], [], [], [], []
-
-             last_face = cv2.imread("last_face.jpg")
-             last_face = cv2.resize(last_face, (args.img_size, args.img_size))
-             img_batch.append(last_face)
-             melspec = mels[-1]
-             mel_batch.append(melspec)
-             frame_batch.append(frames[-1])
-             coords_batch.append(face_det_results[-1][1])
-             head_exist_batch.append(head_exist[-1])
-
-     if len(img_batch) > 0:
-         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
-
-         img_masked = img_batch.copy()
-         img_masked[:, args.img_size // 2:] = 0
-
-         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-         yield img_batch, head_exist_batch, mel_batch, frame_batch, coords_batch

  mel_step_size = 16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  def _load(checkpoint_path):
-     if device == 'cuda':
-         checkpoint = torch.load(checkpoint_path)
-     else:
-         checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
-     return checkpoint

  def load_model(path):
-     model = Wav2Lip()
-     print("Load checkpoint from: {}".format(path))
-     checkpoint = _load(path)
-     s = checkpoint["state_dict"]
-     new_s = {}
-     for k, v in s.items():
-         new_s[k.replace('module.', '')] = v
-     model.load_state_dict(new_s)
-     model = model.to(device)
-     return model.eval()

  def main():
-     if not os.path.isfile(args.face):
-         raise ValueError('--face argument must be a valid path to video/image file')

-     elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
-         full_frames = [cv2.imread(args.face)]
-         fps = args.fps

-     else:
-         video_stream = cv2.VideoCapture(args.face)
-         fps = video_stream.get(cv2.CAP_PROP_FPS)

-         print('Reading video frames...')

-         full_frames = []
-         while 1:
-             still_reading, frame = video_stream.read()
-             if not still_reading:
-                 video_stream.release()
-                 break
-             if args.resize_factor > 1:
-                 frame = cv2.resize(frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor))

-             if args.rotate:
-                 frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

-             y1, y2, x1, x2 = args.crop
-             if x2 == -1: x2 = frame.shape[1]
-             if y2 == -1: y2 = frame.shape[0]

-             frame = frame[y1:y2, x1:x2]

-             full_frames.append(frame)

-     print ("Number of frames available for inference: "+str(len(full_frames)))

-     if not args.audio.endswith('.wav'):
-         print('Extracting raw audio...')
-         command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

-         subprocess.call(command, shell=True)
-         args.audio = 'temp/temp.wav'

-     wav = audio.load_wav(args.audio, 16000)
-     mel = audio.melspectrogram(wav)
-     print(mel.shape)

-     if np.isnan(mel.reshape(-1)).sum() > 0:
-         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

-     mel_chunks = []
-     mel_idx_multiplier = 80. / fps
-     i = 0
-     while 1:
-         start_idx = int(i * mel_idx_multiplier)
-         if start_idx + mel_step_size > len(mel[0]):
-             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
-             break
-         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
-         i += 1

-     print("Length of mel chunks: {}".format(len(mel_chunks)))

-     full_frames = full_frames[:len(mel_chunks)]

-     batch_size = args.wav2lip_batch_size
-     gen = datagen(full_frames.copy(), mel_chunks)

-     for i, (img_batch, exist_head_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
-             total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
-         if i == 0:
-             model = load_model(args.checkpoint_path)
-             print("Model loaded")

-             frame_h, frame_w = full_frames[0].shape[:-1]
-             out = cv2.VideoWriter('temp/result.avi',
-                                   cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

-         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
-         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

-         with torch.no_grad():
-             pred = model(mel_batch, img_batch)

-         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

-         i = 0
-         for p, f, c, exist in zip(pred, frames, coords, exist_head_batch):
-             i += 1
-             if not exist:
-                 out.write(f)
-             else:
-                 y1, y2, x1, x2 = c
-                 p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
-                 head_high, head_width, _ = p.shape
-                 width_cut = int(head_width * 0.2)
-                 f[y1:y2, x1+width_cut:x2-width_cut] = p[:, width_cut:head_width-width_cut]
-                 out.write(f)

-     out.release()

-     command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
-     subprocess.call(command, shell=platform.system() != 'Windows')

  if __name__ == '__main__':
-     main()
+ from os import listdir, path
  import numpy as np
+ import scipy, cv2, os, sys, argparse, audio
+ import json, subprocess, random, string
  from tqdm import tqdm
  from glob import glob
+ import torch, face_detection
  from models import Wav2Lip
  import platform

+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

+ parser.add_argument('--checkpoint_path', type=str,
+     help='Name of saved checkpoint to load weights from', required=True)

+ parser.add_argument('--face', type=str,
+     help='Filepath of video/image that contains faces to use', required=True)
+ parser.add_argument('--audio', type=str,
+     help='Filepath of video/audio file to use as raw audio source', required=True)
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
+     default='results/result_voice.mp4')
+
+ parser.add_argument('--static', type=bool,
+     help='If True, then use only first video frame for inference', default=False)
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
+     default=25., required=False)
+
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
+     help='Padding (top, bottom, left, right). Please adjust to include chin at least')
+
+ parser.add_argument('--face_det_batch_size', type=int,
+     help='Batch size for face detection', default=16)
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)

+ parser.add_argument('--resize_factor', default=1, type=int,
+     help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
+
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
+     help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
+     'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
+
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
+     help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
+     'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
+
+ parser.add_argument('--rotate', default=False, action='store_true',
+     help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
+     'Use if you get a flipped result, despite feeding a normal looking video')
+
+ parser.add_argument('--nosmooth', default=False, action='store_true',
+     help='Prevent smoothing face detections over a short temporal window')

  args = parser.parse_args()
  args.img_size = 96

  if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+     args.static = True

  def get_smoothened_boxes(boxes, T):
+     for i in range(len(boxes)):
+         if i + T > len(boxes):
+             window = boxes[len(boxes) - T:]
+         else:
+             window = boxes[i : i + T]
+         boxes[i] = np.mean(window, axis=0)
+     return boxes

  def face_detect(images):
+     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                             flip_input=False, device=device)
+
+     batch_size = args.face_det_batch_size
+
+     while 1:
+         predictions = []
+         try:
+             for i in tqdm(range(0, len(images), batch_size)):
+                 predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+         except RuntimeError:
+             if batch_size == 1:
+                 raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
+             batch_size //= 2
+             print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+             continue
+         break
+
+     results = []
+     pady1, pady2, padx1, padx2 = args.pads
+     for rect, image in zip(predictions, images):
+         if rect is None:
+             cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+             raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+         y1 = max(0, rect[1] - pady1)
+         y2 = min(image.shape[0], rect[3] + pady2)
+         x1 = max(0, rect[0] - padx1)
+         x2 = min(image.shape[1], rect[2] + padx2)
+
+         results.append([x1, y1, x2, y2])
+
+     boxes = np.array(results)
+     if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+     results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+     del detector
+     return results

  def datagen(frames, mels):
+     img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if args.box[0] == -1:
+         if not args.static:
+             face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
+         else:
+             face_det_results = face_detect([frames[0]])
+     else:
+         print('Using the specified bounding box instead of face detection...')
+         y1, y2, x1, x2 = args.box
+         face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+     for i, m in enumerate(mels):
+         idx = 0 if args.static else i%len(frames)
+         frame_to_save = frames[idx].copy()
+         face, coords = face_det_results[idx].copy()
+
+         face = cv2.resize(face, (args.img_size, args.img_size))
+
+         img_batch.append(face)
+         mel_batch.append(m)
+         frame_batch.append(frame_to_save)
+         coords_batch.append(coords)
+
+         if len(img_batch) >= args.wav2lip_batch_size:
+             img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+             img_masked = img_batch.copy()
+             img_masked[:, args.img_size//2:] = 0
+
+             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+             yield img_batch, mel_batch, frame_batch, coords_batch
+             img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if len(img_batch) > 0:
+         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+         img_masked = img_batch.copy()
+         img_masked[:, args.img_size//2:] = 0
+
+         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+         yield img_batch, mel_batch, frame_batch, coords_batch

  mel_step_size = 16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print('Using {} for inference.'.format(device))

  def _load(checkpoint_path):
+     if device == 'cuda':
+         checkpoint = torch.load(checkpoint_path)
+     else:
+         checkpoint = torch.load(checkpoint_path,
+                                 map_location=lambda storage, loc: storage)
+     return checkpoint

  def load_model(path):
+     model = Wav2Lip()
+     print("Load checkpoint from: {}".format(path))
+     checkpoint = _load(path)
+     s = checkpoint["state_dict"]
+     new_s = {}
+     for k, v in s.items():
+         new_s[k.replace('module.', '')] = v
+     model.load_state_dict(new_s)
+
+     model = model.to(device)
+     return model.eval()

  def main():
+     if not os.path.isfile(args.face):
+         raise ValueError('--face argument must be a valid path to video/image file')

+     elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
+         full_frames = [cv2.imread(args.face)]
+         fps = args.fps

+     else:
+         video_stream = cv2.VideoCapture(args.face)
+         fps = video_stream.get(cv2.CAP_PROP_FPS)

+         print('Reading video frames...')

+         full_frames = []
+         while 1:
+             still_reading, frame = video_stream.read()
+             if not still_reading:
+                 video_stream.release()
+                 break
+             if args.resize_factor > 1:
+                 frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

+             if args.rotate:
+                 frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)

+             y1, y2, x1, x2 = args.crop
+             if x2 == -1: x2 = frame.shape[1]
+             if y2 == -1: y2 = frame.shape[0]

+             frame = frame[y1:y2, x1:x2]

+             full_frames.append(frame)

+     print ("Number of frames available for inference: "+str(len(full_frames)))

+     if not args.audio.endswith('.wav'):
+         print('Extracting raw audio...')
+         command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

+         subprocess.call(command, shell=True)
+         args.audio = 'temp/temp.wav'

+     wav = audio.load_wav(args.audio, 16000)
+     mel = audio.melspectrogram(wav)
+     print(mel.shape)

+     if np.isnan(mel.reshape(-1)).sum() > 0:
+         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

+     mel_chunks = []
+     mel_idx_multiplier = 80./fps
+     i = 0
+     while 1:
+         start_idx = int(i * mel_idx_multiplier)
+         if start_idx + mel_step_size > len(mel[0]):
+             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+             break
+         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+         i += 1

+     print("Length of mel chunks: {}".format(len(mel_chunks)))

+     full_frames = full_frames[:len(mel_chunks)]

+     batch_size = args.wav2lip_batch_size
+     gen = datagen(full_frames.copy(), mel_chunks)

+     for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
+             total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
+         if i == 0:
+             model = load_model(args.checkpoint_path)
+             print ("Model loaded")

+             frame_h, frame_w = full_frames[0].shape[:-1]
+             out = cv2.VideoWriter('temp/result.avi',
+                                   cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

+         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

+         with torch.no_grad():
+             pred = model(mel_batch, img_batch)

+         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+         for p, f, c in zip(pred, frames, coords):
+             y1, y2, x1, x2 = c
+             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

+             f[y1:y2, x1:x2] = p
+             out.write(f)

+     out.release()

+     command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
+     subprocess.call(command, shell=platform.system() != 'Windows')

  if __name__ == '__main__':
+     main()
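
For reference, a typical invocation of the updated script, based on the arguments defined in the diff above (the checkpoint path and input filenames here are placeholders, not part of this commit):

    python inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face input_video.mp4 --audio input_audio.wav --outfile results/result_voice.mp4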