Commit: "Update inference.py" — diff of inference.py (+19 lines, −11 lines).
@@ -8,6 +8,10 @@ import torch, face_detection
|
|
8 |
from models import Wav2Lip
|
9 |
import platform
|
10 |
|
|
|
|
|
|
|
|
|
11 |
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
|
12 |
|
13 |
parser.add_argument('--checkpoint_path', type=str,
|
@@ -166,17 +170,21 @@ def _load(checkpoint_path):
|
|
166 |
return checkpoint
|
167 |
|
168 |
def load_model(path):
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
180 |
|
181 |
def main():
|
182 |
if not os.path.isfile(args.face):
|
|
|
8 |
from models import Wav2Lip
|
9 |
import platform
|
10 |
|
11 |
+
# Select the inference device once at module import: prefer a CUDA GPU
# when one is visible to this process, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {} for inference.'.format(device))
|
13 |
+
|
14 |
+
|
15 |
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
|
16 |
|
17 |
parser.add_argument('--checkpoint_path', type=str,
|
|
|
170 |
return checkpoint
|
171 |
|
172 |
def load_model(path):
    """Build a Wav2Lip model, load the checkpoint at *path*, and return it in eval mode.

    The returned model lives on the module-level ``device`` ('cuda' or 'cpu').

    Note: ``torch.load`` unpickles arbitrary objects — only load checkpoints
    from trusted sources.
    """
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))

    # map_location replaces the original cuda/cpu branch: it loads the
    # checkpoint tensors directly onto the active device in one call.
    # This is equivalent because the tensors are only consumed by
    # load_state_dict below, which copies values into the model's own
    # parameters regardless of where they were deserialized.
    checkpoint = torch.load(path, map_location=torch.device(device))

    # Checkpoints saved under nn.DataParallel prefix every key with
    # 'module.'; strip that so the keys match a bare Wav2Lip module.
    state_dict = {k.replace('module.', ''): v
                  for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(state_dict)

    return model.to(device).eval()
|
188 |
|
189 |
def main():
|
190 |
if not os.path.isfile(args.face):
|