salomonsky commited on
Commit
041cb9e
·
verified ·
1 Parent(s): 0299673

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +19 -11
inference.py CHANGED
@@ -8,6 +8,10 @@ import torch, face_detection
8
  from models import Wav2Lip
9
  import platform
10
 
 
 
 
 
11
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
 
13
  parser.add_argument('--checkpoint_path', type=str,
@@ -166,17 +170,21 @@ def _load(checkpoint_path):
166
  return checkpoint
167
 
168
  def load_model(path):
169
- model = Wav2Lip()
170
- print("Load checkpoint from: {}".format(path))
171
- checkpoint = _load(path)
172
- s = checkpoint["state_dict"]
173
- new_s = {}
174
- for k, v in s.items():
175
- new_s[k.replace('module.', '')] = v
176
- model.load_state_dict(new_s)
177
-
178
- model = model.to(device)
179
- return model.eval()
 
 
 
 
180
 
181
  def main():
182
  if not os.path.isfile(args.face):
 
8
  from models import Wav2Lip
9
  import platform
10
 
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ print('Using {} for inference.'.format(device))
13
+
14
+
15
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
16
 
17
  parser.add_argument('--checkpoint_path', type=str,
 
170
  return checkpoint
171
 
172
  def load_model(path):
173
+ model = Wav2Lip()
174
+ print("Load checkpoint from: {}".format(path))
175
+
176
+ if device == 'cuda':
177
+ checkpoint = torch.load(path)
178
+ else:
179
+ checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
180
+
181
+ s = checkpoint["state_dict"]
182
+ new_s = {k.replace('module.', ''): v for k, v in s.items()}
183
+
184
+ model.load_state_dict(new_s)
185
+ model = model.to(device)
186
+
187
+ return model.eval()
188
 
189
  def main():
190
  if not os.path.isfile(args.face):