hysts HF staff commited on
Commit
cfde09c
·
1 Parent(s): 7023e95
Files changed (5) hide show
  1. .gitmodules +3 -0
  2. app.py +159 -0
  3. patch +85 -0
  4. requirements.txt +4 -0
  5. stylegan2-pytorch +1 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "stylegan2-pytorch"]
2
+ path = stylegan2-pytorch
3
+ url = https://github.com/rosinality/stylegan2-pytorch
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ if os.environ.get('SYSTEM') == 'spaces':
18
+ subprocess.call('git apply ../patch'.split(), cwd='stylegan2-pytorch')
19
+
20
+ sys.path.insert(0, 'stylegan2-pytorch')
21
+
22
+ from model import Generator
23
+
24
+ TITLE = 'TADNE (This Anime Does Not Exist) Interpolation'
25
+ DESCRIPTION = 'The original TADNE site is https://thisanimedoesnotexist.ai/.'
26
+ ARTICLE = None
27
+
28
+ TOKEN = os.environ['TOKEN']
29
+
30
+
31
+ def parse_args() -> argparse.Namespace:
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--device', type=str, default='cpu')
34
+ parser.add_argument('--theme', type=str)
35
+ parser.add_argument('--live', action='store_true')
36
+ parser.add_argument('--share', action='store_true')
37
+ parser.add_argument('--port', type=int)
38
+ parser.add_argument('--disable-queue',
39
+ dest='enable_queue',
40
+ action='store_false')
41
+ parser.add_argument('--allow-flagging', type=str, default='never')
42
+ parser.add_argument('--allow-screenshot', action='store_true')
43
+ return parser.parse_args()
44
+
45
+
46
+ def load_model(device: torch.device) -> nn.Module:
47
+ model = Generator(512, 1024, 4, channel_multiplier=2)
48
+ path = hf_hub_download('hysts/TADNE',
49
+ 'models/aydao-anime-danbooru2019s-512-5268480.pt',
50
+ use_auth_token=TOKEN)
51
+ checkpoint = torch.load(path)
52
+ model.load_state_dict(checkpoint['g_ema'])
53
+ model.eval()
54
+ model.to(device)
55
+ model.latent_avg = checkpoint['latent_avg'].to(device)
56
+ with torch.inference_mode():
57
+ z = torch.zeros((1, model.style_dim)).to(device)
58
+ model([z], truncation=0.7, truncation_latent=model.latent_avg)
59
+ return model
60
+
61
+
62
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
63
+ return torch.from_numpy(np.random.RandomState(seed).randn(
64
+ 1, z_dim)).to(device).float()
65
+
66
+
67
+ @torch.inference_mode()
68
+ def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float,
69
+ randomize_noise: bool) -> np.ndarray:
70
+ out, _ = model([z],
71
+ truncation=truncation_psi,
72
+ truncation_latent=model.latent_avg,
73
+ randomize_noise=randomize_noise)
74
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
75
+ return out[0].cpu().numpy()
76
+
77
+
78
+ @torch.inference_mode()
79
+ def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
80
+ psi0: float, psi1: float,
81
+ randomize_noise: bool, model: nn.Module,
82
+ device: torch.device) -> np.ndarray:
83
+ seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
84
+ seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
85
+
86
+ z0 = generate_z(model.style_dim, seed0, device)
87
+ if num_intermediate == -1:
88
+ out = generate_image(model, z0, psi0, randomize_noise)
89
+ return out
90
+
91
+ z1 = generate_z(model.style_dim, seed1, device)
92
+ vec = z1 - z0
93
+ dvec = vec / (num_intermediate + 1)
94
+ zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
95
+ dpsi = (psi1 - psi0) / (num_intermediate + 1)
96
+ psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
97
+ res = []
98
+ for z, psi in zip(zs, psis):
99
+ out = generate_image(model, z, psi, randomize_noise)
100
+ res.append(out)
101
+ res = np.hstack(res)
102
+ return res
103
+
104
+
105
+ def main():
106
+ gr.close_all()
107
+
108
+ args = parse_args()
109
+ device = torch.device(args.device)
110
+
111
+ model = load_model(device)
112
+
113
+ func = functools.partial(generate_interpolated_images,
114
+ model=model,
115
+ device=device)
116
+ func = functools.update_wrapper(func, generate_interpolated_images)
117
+
118
+ examples = [
119
+ [29703, 55376, 3, 0.7, 0.7, False],
120
+ [34141, 36864, 5, 0.7, 0.7, False],
121
+ [74650, 88322, 7, 0.7, 0.7, False],
122
+ [84314, 70317410, 9, 0.7, 0.7, False],
123
+ [55376, 55376, 5, 0.3, 1.3, False],
124
+ ]
125
+
126
+ gr.Interface(
127
+ func,
128
+ [
129
+ gr.inputs.Number(default=29703, label='Seed 1'),
130
+ gr.inputs.Number(default=55376, label='Seed 2'),
131
+ gr.inputs.Slider(-1,
132
+ 11,
133
+ step=1,
134
+ default=3,
135
+ label='Number of Intermediate Frames'),
136
+ gr.inputs.Slider(
137
+ 0, 2, step=0.05, default=0.7, label='Truncation psi 1'),
138
+ gr.inputs.Slider(
139
+ 0, 2, step=0.05, default=0.7, label='Truncation psi 2'),
140
+ gr.inputs.Checkbox(default=False, label='Randomize Noise'),
141
+ ],
142
+ gr.outputs.Image(type='numpy', label='Output'),
143
+ examples=examples,
144
+ title=TITLE,
145
+ description=DESCRIPTION,
146
+ article=ARTICLE,
147
+ theme=args.theme,
148
+ allow_screenshot=args.allow_screenshot,
149
+ allow_flagging=args.allow_flagging,
150
+ live=args.live,
151
+ ).launch(
152
+ enable_queue=args.enable_queue,
153
+ server_port=args.port,
154
+ share=args.share,
155
+ )
156
+
157
+
158
+ if __name__ == '__main__':
159
+ main()
patch ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/model.py b/model.py
2
+ index 0134c39..3a7826c 100755
3
+ --- a/model.py
4
+ +++ b/model.py
5
+ @@ -395,6 +395,7 @@ class Generator(nn.Module):
6
+ style_dim,
7
+ n_mlp,
8
+ channel_multiplier=2,
9
+ + additional_multiplier=2,
10
+ blur_kernel=[1, 3, 3, 1],
11
+ lr_mlp=0.01,
12
+ ):
13
+ @@ -426,6 +427,9 @@ class Generator(nn.Module):
14
+ 512: 32 * channel_multiplier,
15
+ 1024: 16 * channel_multiplier,
16
+ }
17
+ + if additional_multiplier > 1:
18
+ + for k in list(self.channels.keys()):
19
+ + self.channels[k] *= additional_multiplier
20
+
21
+ self.input = ConstantInput(self.channels[4])
22
+ self.conv1 = StyledConv(
23
+ @@ -518,7 +522,7 @@ class Generator(nn.Module):
24
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
25
+ ]
26
+
27
+ - if truncation < 1:
28
+ + if truncation_latent is not None:
29
+ style_t = []
30
+
31
+ for style in styles:
32
+ diff --git a/op/fused_act.py b/op/fused_act.py
33
+ index 5d46e10..bc522ed 100755
34
+ --- a/op/fused_act.py
35
+ +++ b/op/fused_act.py
36
+ @@ -1,5 +1,3 @@
37
+ -import os
38
+ -
39
+ import torch
40
+ from torch import nn
41
+ from torch.nn import functional as F
42
+ @@ -7,16 +5,6 @@ from torch.autograd import Function
43
+ from torch.utils.cpp_extension import load
44
+
45
+
46
+ -module_path = os.path.dirname(__file__)
47
+ -fused = load(
48
+ - "fused",
49
+ - sources=[
50
+ - os.path.join(module_path, "fused_bias_act.cpp"),
51
+ - os.path.join(module_path, "fused_bias_act_kernel.cu"),
52
+ - ],
53
+ -)
54
+ -
55
+ -
56
+ class FusedLeakyReLUFunctionBackward(Function):
57
+ @staticmethod
58
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
59
+ diff --git a/op/upfirdn2d.py b/op/upfirdn2d.py
60
+ index 67e0375..6c5840e 100755
61
+ --- a/op/upfirdn2d.py
62
+ +++ b/op/upfirdn2d.py
63
+ @@ -1,5 +1,4 @@
64
+ from collections import abc
65
+ -import os
66
+
67
+ import torch
68
+ from torch.nn import functional as F
69
+ @@ -7,16 +6,6 @@ from torch.autograd import Function
70
+ from torch.utils.cpp_extension import load
71
+
72
+
73
+ -module_path = os.path.dirname(__file__)
74
+ -upfirdn2d_op = load(
75
+ - "upfirdn2d",
76
+ - sources=[
77
+ - os.path.join(module_path, "upfirdn2d.cpp"),
78
+ - os.path.join(module_path, "upfirdn2d_kernel.cu"),
79
+ - ],
80
+ -)
81
+ -
82
+ -
83
+ class UpFirDn2dBackward(Function):
84
+ @staticmethod
85
+ def forward(
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ torch==1.11.0
4
+ torchvision==0.12.0
stylegan2-pytorch ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit bef283a1c24087da704d16c30abc8e36e63efa0e