Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -59,12 +59,12 @@ def generate_image(seed: int, truncation_psi: float, model: nn.Module,
|
|
59 |
return out[0].cpu().numpy()
|
60 |
|
61 |
|
62 |
-
def load_model(
|
63 |
-
path = hf_hub_download(
|
64 |
-
f'{file_name}',
|
65 |
use_auth_token=TOKEN)
|
66 |
with open(path, 'rb') as f:
|
67 |
-
model =
|
68 |
model.eval()
|
69 |
model.to(device)
|
70 |
with torch.inference_mode():
|
@@ -78,7 +78,7 @@ def main():
|
|
78 |
args = parse_args()
|
79 |
device = torch.device(args.device)
|
80 |
|
81 |
-
model = load_model('
|
82 |
|
83 |
func = functools.partial(generate_image, model=model, device=device)
|
84 |
func = functools.update_wrapper(func, generate_image)
|
|
|
59 |
return out[0].cpu().numpy()
|
60 |
|
61 |
|
62 |
+
def load_model(file_name: str, device: torch.device) -> nn.Module:
|
63 |
+
path = hf_hub_download('hysts/StyleGAN-Human',
|
64 |
+
f'models/{file_name}',
|
65 |
use_auth_token=TOKEN)
|
66 |
with open(path, 'rb') as f:
|
67 |
+
model = pickle.load(f)['G_ema']
|
68 |
model.eval()
|
69 |
model.to(device)
|
70 |
with torch.inference_mode():
|
|
|
78 |
args = parse_args()
|
79 |
device = torch.device(args.device)
|
80 |
|
81 |
+
model = load_model('stylegan_human_v2_1024.pkl', device)
|
82 |
|
83 |
func = functools.partial(generate_image, model=model, device=device)
|
84 |
func = functools.update_wrapper(func, generate_image)
|