Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -52,22 +52,18 @@ def parse_args() -> argparse.Namespace:
|
|
52 |
action='store_false')
|
53 |
parser.add_argument('--allow-flagging', type=str, default='never')
|
54 |
return parser.parse_args()
|
55 |
-
|
56 |
-
def load_model(file_name: str, path:str,device: torch.device) -> nn.Module:
|
57 |
-
path = hf_hub_download(f'{path}',
|
58 |
-
f'{file_name}',
|
59 |
-
use_auth_token=TOKEN)
|
60 |
-
with open(path, 'rb') as f:
|
61 |
-
model = torch.load(f)
|
62 |
-
model.eval()
|
63 |
-
model.to(device)
|
64 |
-
with torch.inference_mode():
|
65 |
-
z = torch.zeros((1, model.z_dim)).to(device)
|
66 |
-
label = torch.zeros([1, model.c_dim], device=device)
|
67 |
-
model(z, label, force_fp32=True)
|
68 |
-
return model
|
69 |
|
70 |
def image_create(input_img):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
device = th.device()
|
72 |
generator = create_generator("stylegan2-ffhq-config-f.pt","feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",args, device)
|
73 |
latent = torch.randn((1, 512), device=device)
|
@@ -82,33 +78,29 @@ def main():
|
|
82 |
#else:
|
83 |
# ini = "False1"
|
84 |
#result = subprocess.check_output(['nvidia-smi'])
|
85 |
-
|
86 |
-
args = ProjectorArguments().parse(
|
87 |
-
args=[str(input_path)],
|
88 |
-
namespace=Namespace(
|
89 |
-
spectral_sensitivity=spectral_sensitivity,
|
90 |
-
encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
|
91 |
-
encoder_name=spectral_sensitivity,
|
92 |
-
#gaussian=gaussian_radius,
|
93 |
-
log_visual_freq=1000,
|
94 |
-
input='text',
|
95 |
-
))
|
96 |
iface = gr.Interface(
|
97 |
fn=image_create,
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
100 |
title=TITLE,
|
101 |
description=DESCRIPTION,
|
102 |
article=ARTICLE,
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
)
|
107 |
|
108 |
-
iface.launch(
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
if __name__ == '__main__':
|
113 |
main()
|
114 |
|
|
|
52 |
action='store_false')
|
53 |
parser.add_argument('--allow-flagging', type=str, default='never')
|
54 |
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def image_create(input_img):
|
57 |
+
args = ProjectorArguments().parse(
|
58 |
+
args=[str(input_path)],
|
59 |
+
namespace=Namespace(
|
60 |
+
spectral_sensitivity=spectral_sensitivity,
|
61 |
+
encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
|
62 |
+
encoder_name=spectral_sensitivity,
|
63 |
+
#gaussian=gaussian_radius,
|
64 |
+
log_visual_freq=1000,
|
65 |
+
input='text',
|
66 |
+
))
|
67 |
device = th.device()
|
68 |
generator = create_generator("stylegan2-ffhq-config-f.pt","feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",args, device)
|
69 |
latent = torch.randn((1, 512), device=device)
|
|
|
78 |
#else:
|
79 |
# ini = "False1"
|
80 |
#result = subprocess.check_output(['nvidia-smi'])
|
81 |
+
args = parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
iface = gr.Interface(
|
83 |
fn=image_create,
|
84 |
+
[
|
85 |
+
gr.inputs.Number(default=0, label='Seed'),
|
86 |
+
gr.inputs.Slider(
|
87 |
+
0, 2, step=0.05, default=0.7, label='Truncation psi'),
|
88 |
+
],
|
89 |
+
gr.outputs.Image(type='numpy', label='Output'),
|
90 |
title=TITLE,
|
91 |
description=DESCRIPTION,
|
92 |
article=ARTICLE,
|
93 |
+
theme=args.theme,
|
94 |
+
allow_flagging=args.allow_flagging,
|
95 |
+
live=args.live,
|
96 |
)
|
97 |
|
98 |
+
iface.launch(
|
99 |
+
enable_queue=args.enable_queue,
|
100 |
+
server_port=args.port,
|
101 |
+
share=args.share,
|
102 |
+
)
|
103 |
+
|
104 |
if __name__ == '__main__':
|
105 |
main()
|
106 |
|