feng2022 commited on
Commit
b703853
1 Parent(s): 7b671bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -74
app.py CHANGED
@@ -17,93 +17,22 @@ from huggingface_hub import hf_hub_download
17
 
18
  sys.path.insert(0, 'StyleGAN-Human')
19
 
20
- TITLE = 'StyleGAN-Human'
21
- DESCRIPTION = '''This is an unofficial demo for https://github.com/stylegan-human/StyleGAN-Human.
22
- Expected execution time on Hugging Face Spaces: 0.8s
23
- Related App: [StyleGAN-Human (Interpolation)](https://huggingface.co/spaces/hysts/StyleGAN-Human-Interpolation)
24
  '''
25
  ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan-human" alt="visitor badge"/></center>'
26
 
27
  TOKEN = "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"
28
 
29
 
30
- def parse_args() -> argparse.Namespace:
31
- parser = argparse.ArgumentParser()
32
- parser.add_argument('--device', type=str, default='cpu')
33
- parser.add_argument('--theme', type=str)
34
- parser.add_argument('--live', action='store_true')
35
- parser.add_argument('--share', action='store_true')
36
- parser.add_argument('--port', type=int)
37
- parser.add_argument('--disable-queue',
38
- dest='enable_queue',
39
- action='store_false')
40
- parser.add_argument('--allow-flagging', type=str, default='never')
41
- return parser.parse_args()
42
-
43
-
44
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
45
- return torch.from_numpy(np.random.RandomState(seed).randn(
46
- 1, z_dim)).to(device).float()
47
-
48
-
49
- @torch.inference_mode()
50
- def generate_image(seed: int, truncation_psi: float, model: nn.Module,
51
- device: torch.device) -> np.ndarray:
52
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
53
-
54
- z = generate_z(model.z_dim, seed, device)
55
- label = torch.zeros([1, model.c_dim], device=device)
56
-
57
- out = model(z, label, truncation_psi=truncation_psi, force_fp32=True)
58
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
59
- return out[0].cpu().numpy()
60
-
61
-
62
- def load_model(file_name: str, path:str, device: torch.device) -> nn.Module:
63
- path = hf_hub_download(f'{path}',
64
- f'{file_name}',
65
- use_auth_token=TOKEN)
66
- with open(path, 'rb') as f:
67
- model = torch.load(f)
68
- model.eval()
69
- model.to(device)
70
- with torch.inference_mode():
71
- z = torch.zeros((1, model.z_dim)).to(device)
72
- label = torch.zeros([1, model.c_dim], device=device)
73
- model(z, label, force_fp32=True)
74
- return model
75
-
76
-
77
  def main():
78
- args = parse_args()
79
- device = torch.device(args.device)
80
-
81
- model_e4e = load_model('e4e_ffhq_encode.pt',"feng2022/Time-TravelRephotography_e4e_ffhq_encode", device)
82
-
83
- func = functools.partial(generate_image, model=model, device=device)
84
- func = functools.update_wrapper(func, generate_image)
85
 
86
  gr.Interface(
87
- func,
88
- [
89
- gr.inputs.Number(default=0, label='Seed'),
90
- gr.inputs.Slider(
91
- 0, 2, step=0.05, default=0.7, label='Truncation psi'),
92
- ],
93
- gr.outputs.Image(type='numpy', label='Output'),
94
  title=TITLE,
95
  description=DESCRIPTION,
96
  article=ARTICLE,
97
- theme=args.theme,
98
- allow_flagging=args.allow_flagging,
99
- live=args.live,
100
- ).launch(
101
- enable_queue=args.enable_queue,
102
- server_port=args.port,
103
- share=args.share,
104
  )
105
-
106
-
107
  if __name__ == '__main__':
108
  main()
109
 
 
17
 
18
  sys.path.insert(0, 'StyleGAN-Human')
19
 
20
+ TITLE = 'Time-TravelRephotography'
21
+ DESCRIPTION = '''This is an unofficial demo for https://github.com/Time-Travel-Rephotography.
 
 
22
  '''
23
  ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan-human" alt="visitor badge"/></center>'
24
 
25
  TOKEN = "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"
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def main():
 
 
 
 
 
 
 
29
 
30
  gr.Interface(
 
 
 
 
 
 
 
31
  title=TITLE,
32
  description=DESCRIPTION,
33
  article=ARTICLE,
 
 
 
 
 
 
 
34
  )
35
+
 
36
  if __name__ == '__main__':
37
  main()
38