salomonsky commited on
Commit
a16a4f0
1 Parent(s): b8097c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -280
app.py CHANGED
@@ -1,280 +1,69 @@
1
- from os import listdir, path
2
- import numpy as np
3
- import scipy, cv2, os, sys, argparse, audio
4
- import json, subprocess, random, string
5
- from tqdm import tqdm
6
- from glob import glob
7
- import torch, face_detection
8
- from models import Wav2Lip
9
- import platform
10
-
11
- parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
-
13
- parser.add_argument('--checkpoint_path', type=str,
14
- help='Name of saved checkpoint to load weights from', required=True)
15
-
16
- parser.add_argument('--face', type=str,
17
- help='Filepath of video/image that contains faces to use', required=True)
18
- parser.add_argument('--audio', type=str,
19
- help='Filepath of video/audio file to use as raw audio source', required=True)
20
- parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
21
- default='results/result_voice.mp4')
22
-
23
- parser.add_argument('--static', type=bool,
24
- help='If True, then use only first video frame for inference', default=False)
25
- parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
26
- default=25., required=False)
27
-
28
- parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
29
- help='Padding (top, bottom, left, right). Please adjust to include chin at least')
30
-
31
- parser.add_argument('--face_det_batch_size', type=int,
32
- help='Batch size for face detection', default=16)
33
- parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
34
-
35
- parser.add_argument('--resize_factor', default=1, type=int,
36
- help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
37
-
38
- parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
39
- help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
40
- 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
41
-
42
- parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
43
- help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
44
- 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
45
-
46
- parser.add_argument('--rotate', default=False, action='store_true',
47
- help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
48
- 'Use if you get a flipped result, despite feeding a normal looking video')
49
-
50
- parser.add_argument('--nosmooth', default=False, action='store_true',
51
- help='Prevent smoothing face detections over a short temporal window')
52
-
53
- args = parser.parse_args()
54
- args.img_size = 96
55
-
56
- if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
57
- args.static = True
58
-
59
- def get_smoothened_boxes(boxes, T):
60
- for i in range(len(boxes)):
61
- if i + T > len(boxes):
62
- window = boxes[len(boxes) - T:]
63
- else:
64
- window = boxes[i : i + T]
65
- boxes[i] = np.mean(window, axis=0)
66
- return boxes
67
-
68
- def face_detect(images):
69
- detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
70
- flip_input=False, device=device)
71
-
72
- batch_size = args.face_det_batch_size
73
-
74
- while 1:
75
- predictions = []
76
- try:
77
- for i in tqdm(range(0, len(images), batch_size)):
78
- predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
79
- except RuntimeError:
80
- if batch_size == 1:
81
- raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
82
- batch_size //= 2
83
- print('Recovering from OOM error; New batch size: {}'.format(batch_size))
84
- continue
85
- break
86
-
87
- results = []
88
- pady1, pady2, padx1, padx2 = args.pads
89
- for rect, image in zip(predictions, images):
90
- if rect is None:
91
- cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
92
- raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
93
-
94
- y1 = max(0, rect[1] - pady1)
95
- y2 = min(image.shape[0], rect[3] + pady2)
96
- x1 = max(0, rect[0] - padx1)
97
- x2 = min(image.shape[1], rect[2] + padx2)
98
-
99
- results.append([x1, y1, x2, y2])
100
-
101
- boxes = np.array(results)
102
- if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
103
- results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
104
-
105
- del detector
106
- return results
107
-
108
- def datagen(frames, mels):
109
- img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
110
-
111
- if args.box[0] == -1:
112
- if not args.static:
113
- face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
114
- else:
115
- face_det_results = face_detect([frames[0]])
116
- else:
117
- print('Using the specified bounding box instead of face detection...')
118
- y1, y2, x1, x2 = args.box
119
- face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
120
-
121
- for i, m in enumerate(mels):
122
- idx = 0 if args.static else i%len(frames)
123
- frame_to_save = frames[idx].copy()
124
- face, coords = face_det_results[idx].copy()
125
-
126
- face = cv2.resize(face, (args.img_size, args.img_size))
127
-
128
- img_batch.append(face)
129
- mel_batch.append(m)
130
- frame_batch.append(frame_to_save)
131
- coords_batch.append(coords)
132
-
133
- if len(img_batch) >= args.wav2lip_batch_size:
134
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
135
-
136
- img_masked = img_batch.copy()
137
- img_masked[:, args.img_size//2:] = 0
138
-
139
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
140
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
141
-
142
- yield img_batch, mel_batch, frame_batch, coords_batch
143
- img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
144
-
145
- if len(img_batch) > 0:
146
- img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
147
-
148
- img_masked = img_batch.copy()
149
- img_masked[:, args.img_size//2:] = 0
150
-
151
- img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
152
- mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
153
-
154
- yield img_batch, mel_batch, frame_batch, coords_batch
155
-
156
- mel_step_size = 16
157
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
- print('Using {} for inference.'.format(device))
159
-
160
- def _load(checkpoint_path):
161
- if device == 'cuda':
162
- checkpoint = torch.load(checkpoint_path)
163
- else:
164
- checkpoint = torch.load(checkpoint_path,
165
- map_location=lambda storage, loc: storage)
166
- return checkpoint
167
-
168
- def load_model(path):
169
- model = Wav2Lip()
170
- print("Load checkpoint from: {}".format(path))
171
- checkpoint = _load(path)
172
- s = checkpoint["state_dict"]
173
- new_s = {}
174
- for k, v in s.items():
175
- new_s[k.replace('module.', '')] = v
176
- model.load_state_dict(new_s)
177
-
178
- model = model.to(device)
179
- return model.eval()
180
-
181
- def main():
182
- if not os.path.isfile(args.face):
183
- raise ValueError('--face argument must be a valid path to video/image file')
184
-
185
- elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
186
- full_frames = [cv2.imread(args.face)]
187
- fps = args.fps
188
-
189
- else:
190
- video_stream = cv2.VideoCapture(args.face)
191
- fps = video_stream.get(cv2.CAP_PROP_FPS)
192
-
193
- print('Reading video frames...')
194
-
195
- full_frames = []
196
- while 1:
197
- still_reading, frame = video_stream.read()
198
- if not still_reading:
199
- video_stream.release()
200
- break
201
- if args.resize_factor > 1:
202
- frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
203
-
204
- if args.rotate:
205
- frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
206
-
207
- y1, y2, x1, x2 = args.crop
208
- if x2 == -1: x2 = frame.shape[1]
209
- if y2 == -1: y2 = frame.shape[0]
210
-
211
- frame = frame[y1:y2, x1:x2]
212
-
213
- full_frames.append(frame)
214
-
215
- print ("Number of frames available for inference: "+str(len(full_frames)))
216
-
217
- if not args.audio.endswith('.wav'):
218
- print('Extracting raw audio...')
219
- command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
220
-
221
- subprocess.call(command, shell=True)
222
- args.audio = 'temp/temp.wav'
223
-
224
- wav = audio.load_wav(args.audio, 16000)
225
- mel = audio.melspectrogram(wav)
226
- print(mel.shape)
227
-
228
- if np.isnan(mel.reshape(-1)).sum() > 0:
229
- raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
230
-
231
- mel_chunks = []
232
- mel_idx_multiplier = 80./fps
233
- i = 0
234
- while 1:
235
- start_idx = int(i * mel_idx_multiplier)
236
- if start_idx + mel_step_size > len(mel[0]):
237
- mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
238
- break
239
- mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
240
- i += 1
241
-
242
- print("Length of mel chunks: {}".format(len(mel_chunks)))
243
-
244
- full_frames = full_frames[:len(mel_chunks)]
245
-
246
- batch_size = args.wav2lip_batch_size
247
- gen = datagen(full_frames.copy(), mel_chunks)
248
-
249
- for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
250
- total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
251
- if i == 0:
252
- model = load_model(args.checkpoint_path)
253
- print ("Model loaded")
254
-
255
- frame_h, frame_w = full_frames[0].shape[:-1]
256
- out = cv2.VideoWriter('temp/result.avi',
257
- cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
258
-
259
- img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
260
- mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
261
-
262
- with torch.no_grad():
263
- pred = model(mel_batch, img_batch)
264
-
265
- pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
266
-
267
- for p, f, c in zip(pred, frames, coords):
268
- y1, y2, x1, x2 = c
269
- p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
270
-
271
- f[y1:y2, x1:x2] = p
272
- out.write(f)
273
-
274
- out.release()
275
-
276
- command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
277
- subprocess.call(command, shell=platform.system() != 'Windows')
278
-
279
- if __name__ == '__main__':
280
- main()
 
1
+ import os
2
+ import openai
3
+ import gradio as gr
4
+ import subprocess
5
+ from gtts import gTTS
6
+
7
+ openai.api_key = os.environ.get("openai_api_key")
8
+
9
+ def generate_output(name, birth_date):
10
+ if not birth_date:
11
+ return None, "El campo de fecha de nacimiento es obligatorio."
12
+
13
+ prompt = f"T煤 hor贸scopo de hoy, si naciste el {birth_date}, es:"
14
+ response = openai.Completion.create(
15
+ engine="text-davinci-003",
16
+ prompt=prompt,
17
+ max_tokens=180,
18
+ temperature=0.6,
19
+ n=1,
20
+ stop=None,
21
+ )
22
+ gpt3_output = response.choices[0].text.strip()
23
+ personalized_response = f"Tu hor贸scopo {name} nacido el {birth_date} es: {gpt3_output}"
24
+
25
+ if len(response.choices) == 0 or 'text' not in response.choices[0]:
26
+ return None, "No se pudo generar el texto."
27
+
28
+ try:
29
+ tts = gTTS(personalized_response, lang='es')
30
+ audio_path = "audio.mp3"
31
+ tts.save(audio_path)
32
+ except Exception as e:
33
+ return None, f"No se pudo generar el audio: {str(e)}"
34
+
35
+ video_path = "video.mp4"
36
+ command = f"python3 inference.py --checkpoint_path checkpoints/wav2lip_gan.pth --face face.jpg --audio {audio_path} --outfile {video_path} --nosmooth"
37
+ process = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
38
+ if process.returncode != 0:
39
+ error_message = process.stderr
40
+ return None, f"No se pudo generar el video: {error_message}"
41
+
42
+ if os.path.isfile(video_path):
43
+ return video_path, None
44
+ return None, "No se pudo generar el video"
45
+
46
+ name_input = gr.inputs.Textbox(lines=1, placeholder="Escribe tu Nombre Completo", label="Nombre")
47
+ birth_date_input = gr.inputs.Textbox(lines=1, placeholder="Fecha Nacimiento - DD/MM/AAAA", label="Cumplea帽os")
48
+ output = gr.outputs.Video(label="Resultado", type="mp4").style(width=350)
49
+ error_output = gr.outputs.Textbox(label="Errores")
50
+
51
+ def generate_and_display_output(name, birth_date):
52
+ video_path, error_message = generate_output(name, birth_date)
53
+ if error_message:
54
+ print(f"Error: {error_message}")
55
+ return None, error_message
56
+ else:
57
+ return video_path, None
58
+
59
+ outputs = [output, error_output]
60
+
61
+ iface = gr.Interface(
62
+ fn=generate_and_display_output,
63
+ inputs=[name_input, birth_date_input],
64
+ outputs=outputs,
65
+ layout="vertical",
66
+ theme="darkdefault"
67
+ )
68
+
69
+ iface.launch(share=True)