hlydecker commited on
Commit
9049007
1 Parent(s): d73a79d

content: make into gandalf

Browse files
Files changed (1) hide show
  1. app.py +17 -94
app.py CHANGED
@@ -74,24 +74,9 @@ ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
 
77
- generatorjojo = deepcopy(original_generator)
78
-
79
- generatordisney = deepcopy(original_generator)
80
-
81
- generatorjinx = deepcopy(original_generator)
82
-
83
- generatorcaitlyn = deepcopy(original_generator)
84
-
85
- generatoryasuho = deepcopy(original_generator)
86
-
87
- generatorarcanemulti = deepcopy(original_generator)
88
-
89
- generatorart = deepcopy(original_generator)
90
-
91
- generatorspider = deepcopy(original_generator)
92
-
93
- generatorsketch = deepcopy(original_generator)
94
 
 
95
 
96
  transform = transforms.Compose(
97
  [
@@ -104,101 +89,39 @@ transform = transforms.Compose(
104
 
105
 
106
 
107
- modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
108
-
109
-
110
- ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
111
- generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
112
-
113
-
114
- modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
115
-
116
- ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
117
- generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
-
119
-
120
- modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
-
122
- ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
- generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
-
125
 
126
- modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
 
128
- ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
- generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
 
131
 
132
- modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
 
134
- ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
- generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
 
137
 
138
- model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
-
140
- ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
- generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
-
143
-
144
- modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
-
146
- ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
- generatorart.load_state_dict(ckptart["g"], strict=False)
148
-
149
-
150
- modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
-
152
- ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
- generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
-
155
- modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
-
157
- ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
- generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
-
160
  def inference(img, model):
161
  img.save('out.jpg')
162
  aligned_face = align_face('out.jpg')
163
 
164
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
165
- if model == 'JoJo':
166
- with torch.no_grad():
167
- my_sample = generatorjojo(my_w, input_is_latent=True)
168
- elif model == 'Disney':
169
- with torch.no_grad():
170
- my_sample = generatordisney(my_w, input_is_latent=True)
171
- elif model == 'Jinx':
172
- with torch.no_grad():
173
- my_sample = generatorjinx(my_w, input_is_latent=True)
174
- elif model == 'Caitlyn':
175
- with torch.no_grad():
176
- my_sample = generatorcaitlyn(my_w, input_is_latent=True)
177
- elif model == 'Yasuho':
178
- with torch.no_grad():
179
- my_sample = generatoryasuho(my_w, input_is_latent=True)
180
- elif model == 'Arcane Multi':
181
- with torch.no_grad():
182
- my_sample = generatorarcanemulti(my_w, input_is_latent=True)
183
- elif model == 'Art':
184
- with torch.no_grad():
185
- my_sample = generatorart(my_w, input_is_latent=True)
186
- elif model == 'Spider-Verse':
187
  with torch.no_grad():
188
- my_sample = generatorspider(my_w, input_is_latent=True)
189
- else:
190
  with torch.no_grad():
191
- my_sample = generatorsketch(my_w, input_is_latent=True)
192
-
193
 
194
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
195
  imageio.imwrite('filename.jpeg', npimage)
196
  return 'filename.jpeg'
197
 
198
- title = "JoJoGAN"
199
- description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
200
 
201
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
202
 
203
- examples=[['mona.png','Jinx']]
204
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
 
74
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
 
77
+ generatorgollum_mod = deepcopy(original_generator)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ generatorgollum_ex = deepcopy(original_generator)
80
 
81
  transform = transforms.Compose(
82
  [
 
89
 
90
 
91
 
92
+ modelgollum_mod = hf_hub_download(repo_id="hlydecker/gandalf-gollum-moderate", filename="gollum_moderate.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
94
 
95
+ ckptgollum_mod = torch.load(modelgollum_mod, map_location=lambda storage, loc: storage)
96
+ generatorjojo.load_state_dict(ckptgollum_mod["g"], strict=False)
97
 
98
 
99
+ modelgollum_ex = hf_hub_download(repo_id="hlydecker/gandalf-gollum-extreme", filename="gollum_extreme.pt")
100
 
101
+ ckptgollum_ex = torch.load(modeldisney, map_location=lambda storage, loc: storage)
102
+ generatorgollum_ex.load_state_dict(ckptgollum_ex["g"], strict=False)
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def inference(img, model):
106
  img.save('out.jpg')
107
  aligned_face = align_face('out.jpg')
108
 
109
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
110
+ if model == 'Gollum Moderate':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  with torch.no_grad():
112
+ my_sample = generatorgollum_mod(my_w, input_is_latent=True)
113
+ elif model == 'Gollum Extreme':
114
  with torch.no_grad():
115
+ my_sample = generatorgollum_ex(my_w, input_is_latent=True)
 
116
 
117
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
118
  imageio.imwrite('filename.jpeg', npimage)
119
  return 'filename.jpeg'
120
 
121
+ title = "Gollumizer"
122
+ description = "Gradio Demo for GANdalf: One Shot Face Tolekization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
123
 
124
+ article = "<p style='text-align: center'>GANdalf: One Shot Face Tolkeinization</a>| <a href='https://github.com/hlydecker/GANDalf' target='_blank'>Github Repo Pytorch</a></p> <center></center>"
125
 
126
+ examples=[['mona.png','Gollum Moderate']]
127
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Gollum Moderate', 'Gollum Extreme'], type="value", default='Gollum Moderate', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()