hysts HF staff commited on
Commit
3744a88
·
1 Parent(s): b4a6915
Files changed (5) hide show
  1. .pre-commit-config.yaml +59 -34
  2. .style.yapf +0 -5
  3. README.md +1 -1
  4. app.py +28 -44
  5. requirements.txt +2 -2
.pre-commit-config.yaml CHANGED
@@ -1,36 +1,61 @@
1
  exclude: ^patch
2
  repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.12.0
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- - repo: https://github.com/google/yapf
33
- rev: v0.32.0
34
- hooks:
35
- - id: yapf
36
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  exclude: ^patch
2
  repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.6.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ["--fix=lf"]
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.7.5
19
+ hooks:
20
+ - id: docformatter
21
+ args: ["--in-place"]
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.13.2
24
+ hooks:
25
+ - id: isort
26
+ args: ["--profile", "black"]
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v1.10.0
29
+ hooks:
30
+ - id: mypy
31
+ args: ["--ignore-missing-imports"]
32
+ additional_dependencies:
33
+ [
34
+ "types-python-slugify",
35
+ "types-requests",
36
+ "types-PyYAML",
37
+ "types-pytz",
38
+ ]
39
+ - repo: https://github.com/psf/black
40
+ rev: 24.4.2
41
+ hooks:
42
+ - id: black
43
+ language_version: python3.10
44
+ args: ["--line-length", "119"]
45
+ - repo: https://github.com/kynan/nbstripout
46
+ rev: 0.7.1
47
+ hooks:
48
+ - id: nbstripout
49
+ args:
50
+ [
51
+ "--extra-keys",
52
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
53
+ ]
54
+ - repo: https://github.com/nbQA-dev/nbQA
55
+ rev: 1.8.5
56
+ hooks:
57
+ - id: nbqa-black
58
+ - id: nbqa-pyupgrade
59
+ args: ["--py37-plus"]
60
+ - id: nbqa-isort
61
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
 
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
app.py CHANGED
@@ -15,36 +15,32 @@ import torch
15
  import torch.nn as nn
16
  from huggingface_hub import hf_hub_download
17
 
18
- if os.environ.get('SYSTEM') == 'spaces':
19
- with open('patch') as f:
20
- subprocess.run(shlex.split('patch -p1'),
21
- cwd='stylegan2-pytorch',
22
- stdin=f)
23
  if not torch.cuda.is_available():
24
- with open('patch-cpu') as f:
25
- subprocess.run(shlex.split('patch -p1'),
26
- cwd='stylegan2-pytorch',
27
- stdin=f)
28
 
29
- sys.path.insert(0, 'stylegan2-pytorch')
30
 
31
  from model import Generator
32
 
33
- DESCRIPTION = '''# [TADNE](https://thisanimedoesnotexist.ai/) (This Anime Does Not Exist)
34
 
35
  Related Apps:
36
  - [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
37
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
38
  - [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
39
  - [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
40
- '''
41
- SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/TADNE/resolve/main/samples'
42
- ARTICLE = f'''## Generated images
43
  - size: 512x512
44
  - truncation: 0.7
45
  - seed: 0-99
46
  ![samples]({SAMPLE_IMAGE_DIR}/sample.jpg)
47
- '''
48
 
49
  MAX_SEED = np.iinfo(np.int32).max
50
 
@@ -57,13 +53,12 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
57
 
58
  def load_model(device: torch.device) -> nn.Module:
59
  model = Generator(512, 1024, 4, channel_multiplier=2)
60
- path = hf_hub_download('public-data/TADNE',
61
- 'models/aydao-anime-danbooru2019s-512-5268480.pt')
62
  checkpoint = torch.load(path)
63
- model.load_state_dict(checkpoint['g_ema'])
64
  model.eval()
65
  model.to(device)
66
- model.latent_avg = checkpoint['latent_avg'].to(device)
67
  with torch.inference_mode():
68
  z = torch.zeros((1, model.style_dim)).to(device)
69
  model([z], truncation=0.7, truncation_latent=model.latent_avg)
@@ -71,47 +66,36 @@ def load_model(device: torch.device) -> nn.Module:
71
 
72
 
73
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
74
- return torch.from_numpy(np.random.RandomState(seed).randn(
75
- 1, z_dim)).to(device).float()
76
 
77
 
78
  @torch.inference_mode()
79
- def generate_image(seed: int, truncation_psi: float, randomize_noise: bool,
80
- model: nn.Module, device: torch.device) -> np.ndarray:
 
81
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
82
 
83
  z = generate_z(model.style_dim, seed, device)
84
- out, _ = model([z],
85
- truncation=truncation_psi,
86
- truncation_latent=model.latent_avg,
87
- randomize_noise=randomize_noise)
88
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
89
  return out[0].cpu().numpy()
90
 
91
 
92
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
93
  model = load_model(device)
94
  fn = functools.partial(generate_image, model=model, device=device)
95
 
96
- with gr.Blocks(css='style.css') as demo:
97
  gr.Markdown(DESCRIPTION)
98
  with gr.Row():
99
  with gr.Column():
100
- seed = gr.Slider(label='Seed',
101
- minimum=0,
102
- maximum=MAX_SEED,
103
- step=1,
104
- value=0)
105
- randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
106
- psi = gr.Slider(label='Truncation psi',
107
- minimum=0,
108
- maximum=2,
109
- step=0.05,
110
- value=0.7)
111
- randomize_noise = gr.Checkbox(label='Randomize Noise', value=False)
112
- run_button = gr.Button('Run')
113
  with gr.Column():
114
- result = gr.Image(label='Output')
115
  gr.Markdown(ARTICLE)
116
 
117
  run_button.click(
@@ -124,6 +108,6 @@ with gr.Blocks(css='style.css') as demo:
124
  fn=fn,
125
  inputs=[seed, psi, randomize_noise],
126
  outputs=result,
127
- api_name='run',
128
  )
129
  demo.queue(max_size=10).launch()
 
15
  import torch.nn as nn
16
  from huggingface_hub import hf_hub_download
17
 
18
+ if os.environ.get("SYSTEM") == "spaces":
19
+ with open("patch") as f:
20
+ subprocess.run(shlex.split("patch -p1"), cwd="stylegan2-pytorch", stdin=f)
 
 
21
  if not torch.cuda.is_available():
22
+ with open("patch-cpu") as f:
23
+ subprocess.run(shlex.split("patch -p1"), cwd="stylegan2-pytorch", stdin=f)
 
 
24
 
25
+ sys.path.insert(0, "stylegan2-pytorch")
26
 
27
  from model import Generator
28
 
29
+ DESCRIPTION = """# [TADNE](https://thisanimedoesnotexist.ai/) (This Anime Does Not Exist)
30
 
31
  Related Apps:
32
  - [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
33
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
34
  - [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
35
  - [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
36
+ """
37
+ SAMPLE_IMAGE_DIR = "https://huggingface.co/spaces/hysts/TADNE/resolve/main/samples"
38
+ ARTICLE = f"""## Generated images
39
  - size: 512x512
40
  - truncation: 0.7
41
  - seed: 0-99
42
  ![samples]({SAMPLE_IMAGE_DIR}/sample.jpg)
43
+ """
44
 
45
  MAX_SEED = np.iinfo(np.int32).max
46
 
 
53
 
54
  def load_model(device: torch.device) -> nn.Module:
55
  model = Generator(512, 1024, 4, channel_multiplier=2)
56
+ path = hf_hub_download("public-data/TADNE", "models/aydao-anime-danbooru2019s-512-5268480.pt")
 
57
  checkpoint = torch.load(path)
58
+ model.load_state_dict(checkpoint["g_ema"])
59
  model.eval()
60
  model.to(device)
61
+ model.latent_avg = checkpoint["latent_avg"].to(device)
62
  with torch.inference_mode():
63
  z = torch.zeros((1, model.style_dim)).to(device)
64
  model([z], truncation=0.7, truncation_latent=model.latent_avg)
 
66
 
67
 
68
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
69
+ return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
 
70
 
71
 
72
  @torch.inference_mode()
73
+ def generate_image(
74
+ seed: int, truncation_psi: float, randomize_noise: bool, model: nn.Module, device: torch.device
75
+ ) -> np.ndarray:
76
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
77
 
78
  z = generate_z(model.style_dim, seed, device)
79
+ out, _ = model([z], truncation=truncation_psi, truncation_latent=model.latent_avg, randomize_noise=randomize_noise)
 
 
 
80
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
81
  return out[0].cpu().numpy()
82
 
83
 
84
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85
  model = load_model(device)
86
  fn = functools.partial(generate_image, model=model, device=device)
87
 
88
+ with gr.Blocks(css="style.css") as demo:
89
  gr.Markdown(DESCRIPTION)
90
  with gr.Row():
91
  with gr.Column():
92
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
93
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
+ psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
95
+ randomize_noise = gr.Checkbox(label="Randomize Noise", value=False)
96
+ run_button = gr.Button("Run")
 
 
 
 
 
 
 
 
97
  with gr.Column():
98
+ result = gr.Image(label="Output")
99
  gr.Markdown(ARTICLE)
100
 
101
  run_button.click(
 
108
  fn=fn,
109
  inputs=[seed, psi, randomize_noise],
110
  outputs=result,
111
+ api_name="run",
112
  )
113
  demo.queue(max_size=10).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- numpy==1.23.5
2
- Pillow==10.0.0
3
  torch==2.0.1
4
  torchvision==0.15.2
 
1
+ numpy==1.26.4
2
+ Pillow==10.3.0
3
  torch==2.0.1
4
  torchvision==0.15.2