liuhuohuo commited on
Commit
ab8ff7d
1 Parent(s): 2339c38
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -7,7 +7,7 @@ import time
7
  from omegaconf import OmegaConf
8
  import torch
9
  import torchvision
10
- from pytorch_lightning import seed_everything
11
  from huggingface_hub import hf_hub_download
12
  from einops import repeat
13
  import torchvision.transforms as transforms
@@ -21,6 +21,16 @@ sys.path.insert(0, "scripts/evaluation")
21
  from lvdm.models.samplers.ddim import DDIMSampler, DDIMStyleSampler
22
 
23
 
 
 
 
 
 
 
 
 
 
 
24
  def load_model_checkpoint(model, ckpt):
25
  state_dict = torch.load(ckpt, map_location="cpu")
26
  if "state_dict" in list(state_dict.keys()):
@@ -177,9 +187,9 @@ demo_exaples_video = [
177
  ]
178
  css = """
179
  #input_img {max-height: 320px;}
180
- #input_img [data-testid="image"], #input_img [data-testid="image"] > div{max-height: 320px; max-width: 512px;}
181
  #output_img {max-height: 400px;}
182
- #output_vid {max-height: 320px;}
183
  """
184
 
185
  with gr.Blocks(analytics_enabled=False, css=css) as demo_iface:
 
7
  from omegaconf import OmegaConf
8
  import torch
9
  import torchvision
10
+ import numpy as np
11
  from huggingface_hub import hf_hub_download
12
  from einops import repeat
13
  import torchvision.transforms as transforms
 
21
  from lvdm.models.samplers.ddim import DDIMSampler, DDIMStyleSampler
22
 
23
 
24
+ def seed_everything(seed):
25
+ torch.manual_seed(seed)
26
+ torch.cuda.manual_seed(seed)
27
+ np.random.seed(seed)
28
+ random.seed(seed)
29
+ torch.backends.cudnn.benchmark = False
30
+ torch.backends.cudnn.deterministic = True
31
+ torch.cuda.manual_seed_all(seed)
32
+
33
+
34
  def load_model_checkpoint(model, ckpt):
35
  state_dict = torch.load(ckpt, map_location="cpu")
36
  if "state_dict" in list(state_dict.keys()):
 
187
  ]
188
  css = """
189
  #input_img {max-height: 320px;}
190
+ #input_img [data-testid="image"], #input_img [data-testid="image"] > div{max-height: 320px;}
191
  #output_img {max-height: 400px;}
192
+ #output_vid {max-height: 320px; max-width: 512px;}
193
  """
194
 
195
  with gr.Blocks(analytics_enabled=False, css=css) as demo_iface: