hysts HF staff commited on
Commit
99243b4
·
1 Parent(s): 8bfafae
Files changed (5) hide show
  1. .gitmodules +3 -0
  2. DualStyleGAN +1 -0
  3. app.py +253 -0
  4. packages.txt +2 -0
  5. requirements.txt +7 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "DualStyleGAN"]
2
+ path = DualStyleGAN
3
+ url = https://github.com/williamyang1991/DualStyleGAN
DualStyleGAN ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit d9c52c2313913352cd2e35707f72fd450bf16630
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import sys
10
+ import tarfile
11
+ from typing import Callable
12
+
13
+ if os.environ['SYSTEM'] == 'spaces':
14
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
15
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
16
+
17
+ sys.path.insert(0, 'DualStyleGAN')
18
+
19
+ import dlib
20
+ import gradio as gr
21
+ import huggingface_hub
22
+ import numpy as np
23
+ import PIL.Image
24
+ import torch
25
+ import torch.nn as nn
26
+ import torchvision.transforms as T
27
+ from model.dualstylegan import DualStyleGAN
28
+ from model.encoder.align_all_parallel import align_face
29
+ from model.encoder.psp import pSp
30
+ from util import load_image, visualize
31
+
32
+ TOKEN = os.environ['TOKEN']
33
+
34
+ MODEL_REPO = 'hysts/DualStyleGAN'
35
+
36
+
37
+ def parse_args() -> argparse.Namespace:
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument('--device', type=str, default='cpu')
40
+ parser.add_argument('--theme', type=str)
41
+ parser.add_argument('--live', action='store_true')
42
+ parser.add_argument('--share', action='store_true')
43
+ parser.add_argument('--port', type=int)
44
+ parser.add_argument('--disable-queue',
45
+ dest='enable_queue',
46
+ action='store_false')
47
+ parser.add_argument('--allow-flagging', type=str, default='never')
48
+ parser.add_argument('--allow-screenshot', action='store_true')
49
+ return parser.parse_args()
50
+
51
+
52
+ def download_cartoon_images() -> None:
53
+ image_dir = pathlib.Path('cartoon')
54
+ if not image_dir.exists():
55
+ path = huggingface_hub.hf_hub_download('hysts/DualStyleGAN-Cartoon',
56
+ 'cartoon.tar.gz',
57
+ repo_type='dataset',
58
+ use_auth_token=TOKEN)
59
+ with tarfile.open(path) as f:
60
+ f.extractall()
61
+
62
+
63
+ def load_encoder(device: torch.device) -> nn.Module:
64
+ ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
65
+ 'models/encoder.pt',
66
+ use_auth_token=TOKEN)
67
+ ckpt = torch.load(ckpt_path, map_location='cpu')
68
+ opts = ckpt['opts']
69
+ opts['device'] = 'cpu'
70
+ opts['checkpoint_path'] = ckpt_path
71
+ opts = argparse.Namespace(**opts)
72
+ model = pSp(opts)
73
+ model.to(device)
74
+ model.eval()
75
+ return model
76
+
77
+
78
+ def load_generator(style_type: str, device: torch.device) -> nn.Module:
79
+ model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
80
+ ckpt_path = huggingface_hub.hf_hub_download(
81
+ MODEL_REPO, f'models/{style_type}/generator.pt', use_auth_token=TOKEN)
82
+ ckpt = torch.load(ckpt_path, map_location='cpu')
83
+ model.load_state_dict(ckpt['g_ema'])
84
+ model.to(device)
85
+ model.eval()
86
+ return model
87
+
88
+
89
+ def load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
90
+ if style_type in ['cartoon', 'caricature', 'anime']:
91
+ filename = 'refined_exstyle_code.npy'
92
+ else:
93
+ filename = 'exstyle_code.npy'
94
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
95
+ f'models/{style_type}/{filename}',
96
+ use_auth_token=TOKEN)
97
+ exstyles = np.load(path, allow_pickle=True).item()
98
+ return exstyles
99
+
100
+
101
+ def create_transform() -> Callable:
102
+ transform = T.Compose([
103
+ T.Resize(256),
104
+ T.CenterCrop(256),
105
+ T.ToTensor(),
106
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
107
+ ])
108
+ return transform
109
+
110
+
111
+ def create_dlib_landmark_model():
112
+ path = huggingface_hub.hf_hub_download(
113
+ 'hysts/dlib_face_landmark_model',
114
+ 'shape_predictor_68_face_landmarks.dat',
115
+ use_auth_token=TOKEN)
116
+ return dlib.shape_predictor(path)
117
+
118
+
119
+ def denormalize(tensor: torch.Tensor) -> torch.Tensor:
120
+ return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
121
+
122
+
123
+ def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
124
+ tensor = denormalize(tensor)
125
+ image = tensor.cpu().numpy().transpose(1, 2, 0)
126
+ return PIL.Image.fromarray(image)
127
+
128
+
129
+ @torch.inference_mode()
130
+ def run(
131
+ image,
132
+ style_id: int,
133
+ dlib_landmark_model,
134
+ encoder: nn.Module,
135
+ generator: nn.Module,
136
+ exstyles: dict[str, np.ndarray],
137
+ transform: Callable,
138
+ device: torch.device,
139
+ style_image_dir: pathlib.Path,
140
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
141
+ stylename = list(exstyles.keys())[style_id]
142
+
143
+ image = align_face(filepath=image.name, predictor=dlib_landmark_model)
144
+ input_data = transform(image).unsqueeze(0).to(device)
145
+
146
+ img_rec, instyle = encoder(input_data,
147
+ randomize_noise=False,
148
+ return_latents=True,
149
+ z_plus_latent=True,
150
+ return_z_plus_latent=True,
151
+ resize=False)
152
+ img_rec = torch.clamp(img_rec.detach(), -1, 1)
153
+
154
+ latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1)
155
+ # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer
156
+ latent[1, 7:18] = instyle[0, 7:18]
157
+ exstyle = generator.generator.style(
158
+ latent.reshape(latent.shape[0] * latent.shape[1],
159
+ latent.shape[2])).reshape(latent.shape)
160
+
161
+ img_gen, _ = generator([instyle.repeat(2, 1, 1)],
162
+ exstyle,
163
+ z_plus_latent=True,
164
+ truncation=0.7,
165
+ truncation_latent=0,
166
+ use_res=True,
167
+ interp_weights=[0.6] * 7 + [1] * 11)
168
+ img_gen = torch.clamp(img_gen.detach(), -1, 1)
169
+ # deactivate color-related layers by setting w_c = 0
170
+ img_gen2, _ = generator([instyle],
171
+ exstyle[0:1],
172
+ z_plus_latent=True,
173
+ truncation=0.7,
174
+ truncation_latent=0,
175
+ use_res=True,
176
+ interp_weights=[0.6] * 7 + [0] * 11)
177
+ img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
178
+
179
+ img_rec = postprocess(img_rec[0])
180
+ img_gen0 = postprocess(img_gen[0])
181
+ img_gen1 = postprocess(img_gen[1])
182
+ img_gen2 = postprocess(img_gen2[0])
183
+
184
+ style_image = PIL.Image.open(style_image_dir / stylename)
185
+
186
+ return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
187
+
188
+
189
+ def main():
190
+ gr.close_all()
191
+
192
+ args = parse_args()
193
+ device = torch.device(args.device)
194
+
195
+ style_type = 'cartoon'
196
+ style_image_dir = pathlib.Path(style_type)
197
+
198
+ download_cartoon_images()
199
+ dlib_landmark_model = create_dlib_landmark_model()
200
+ encoder = load_encoder(device)
201
+ generator = load_generator(style_type, device)
202
+ exstyles = load_exstylecode(style_type)
203
+ transform = create_transform()
204
+
205
+ func = functools.partial(run,
206
+ dlib_landmark_model=dlib_landmark_model,
207
+ encoder=encoder,
208
+ generator=generator,
209
+ exstyles=exstyles,
210
+ transform=transform,
211
+ device=device,
212
+ style_image_dir=style_image_dir)
213
+ func = functools.update_wrapper(func, run)
214
+
215
+ repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
216
+ title = 'williamyang1991/DualStyleGAN'
217
+ description = f'A demo for {repo_url}'
218
+ article = None
219
+
220
+ image_paths = sorted(pathlib.Path('images').glob('*'))
221
+ examples = [[path.as_posix(), 26] for path in image_paths]
222
+
223
+ gr.Interface(
224
+ func,
225
+ [
226
+ gr.inputs.Image(type='file', label='Image'),
227
+ gr.inputs.Slider(0, 316, step=1, default=26, label='Style'),
228
+ ],
229
+ [
230
+ gr.outputs.Image(type='pil', label='Aligned face'),
231
+ gr.outputs.Image(type='pil', label='Style'),
232
+ gr.outputs.Image(type='pil', label='Reconstructed'),
233
+ gr.outputs.Image(type='pil', label='Gen 1'),
234
+ gr.outputs.Image(type='pil', label='Gen 2'),
235
+ gr.outputs.Image(type='pil', label='Gen 3'),
236
+ ],
237
+ examples=examples,
238
+ theme=args.theme,
239
+ title=title,
240
+ description=description,
241
+ article=article,
242
+ allow_screenshot=args.allow_screenshot,
243
+ allow_flagging=args.allow_flagging,
244
+ live=args.live,
245
+ ).launch(
246
+ enable_queue=args.enable_queue,
247
+ server_port=args.port,
248
+ share=args.share,
249
+ )
250
+
251
+
252
+ if __name__ == '__main__':
253
+ main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cmake
2
+ ninja-build
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dlib==19.23.0
2
+ numpy==1.22.3
3
+ opencv-python-headless==4.5.5.62
4
+ Pillow==9.0.1
5
+ scipy==1.8.0
6
+ torch==1.11.0
7
+ torchvision==0.12.0