hysts HF staff commited on
Commit
1a2d639
·
1 Parent(s): d6a0aaf
Files changed (4) hide show
  1. .pre-commit-config.yaml +60 -0
  2. README.md +1 -1
  3. app.py +55 -114
  4. requirements.txt +2 -2
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.0
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌖
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,10 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
- import functools
7
  import io
8
- import os
9
  import pathlib
10
  import tarfile
11
 
@@ -17,8 +14,8 @@ import PIL.Image
17
  import tensorflow as tf
18
  from huggingface_hub import hf_hub_download
19
 
20
- TITLE = 'TADNE Image Search with DeepDanbooru'
21
- DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.
22
 
23
  This app shows images similar to the query image from images generated
24
  by the TADNE model with seed 0-99999.
@@ -38,59 +35,30 @@ Related Apps:
38
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
39
  - [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
40
  - [DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
41
- '''
42
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.tadne-image-search-with-deepdanbooru" alt="visitor badge"/></center>'
43
-
44
- TOKEN = os.environ['TOKEN']
45
-
46
-
47
- def parse_args() -> argparse.Namespace:
48
- parser = argparse.ArgumentParser()
49
- parser.add_argument('--theme', type=str)
50
- parser.add_argument('--live', action='store_true')
51
- parser.add_argument('--share', action='store_true')
52
- parser.add_argument('--port', type=int)
53
- parser.add_argument('--disable-queue',
54
- dest='enable_queue',
55
- action='store_false')
56
- parser.add_argument('--allow-flagging', type=str, default='never')
57
- return parser.parse_args()
58
-
59
-
60
- def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
61
- path = hf_hub_download('hysts/TADNE-sample-images',
62
- f'{size}/{dirname}.tar',
63
- repo_type='dataset',
64
- use_auth_token=TOKEN)
65
- return path
66
 
67
 
68
  def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
69
  path = hf_hub_download(
70
- 'hysts/TADNE-sample-images',
71
- f'prediction_results/deepdanbooru/intermediate_features/{dirname}.npy',
72
- repo_type='dataset',
73
- use_auth_token=TOKEN)
74
  return np.load(path)
75
 
76
 
77
  def load_sample_image_paths() -> list[pathlib.Path]:
78
- image_dir = pathlib.Path('images')
79
  if not image_dir.exists():
80
- dataset_repo = 'hysts/sample-images-TADNE'
81
- path = huggingface_hub.hf_hub_download(dataset_repo,
82
- 'images.tar.gz',
83
- repo_type='dataset',
84
- use_auth_token=TOKEN)
85
  with tarfile.open(path) as f:
86
  f.extractall()
87
- return sorted(image_dir.glob('*'))
88
 
89
 
90
  def create_model() -> tf.keras.Model:
91
- path = huggingface_hub.hf_hub_download('hysts/DeepDanbooru',
92
- 'model-resnet_custom_v3.h5',
93
- use_auth_token=TOKEN)
94
  model = tf.keras.models.load_model(path)
95
  model = tf.keras.Model(model.input, model.layers[-4].output)
96
  layer = tf.keras.layers.GlobalAveragePooling2D()
@@ -98,16 +66,21 @@ def create_model() -> tf.keras.Model:
98
  return model
99
 
100
 
101
- def predict(image: PIL.Image.Image, model: tf.keras.Model) -> np.ndarray:
 
 
 
 
 
 
 
 
102
  _, height, width, _ = model.input_shape
103
  image = np.asarray(image)
104
- image = tf.image.resize(image,
105
- size=(height, width),
106
- method=tf.image.ResizeMethod.AREA,
107
- preserve_aspect_ratio=True)
108
  image = image.numpy()
109
  image = dd.image.transform_and_pad_image(image, width, height)
110
- image = image / 255.
111
  features = model.predict(image[None, ...])[0]
112
  features = features.astype(float)
113
  return features
@@ -117,14 +90,9 @@ def run(
117
  image: PIL.Image.Image,
118
  nrows: int,
119
  ncols: int,
120
- image_size: int,
121
- dirname: str,
122
- tarball_path: pathlib.Path,
123
- deepdanbooru_predictions: np.ndarray,
124
- model: tf.keras.Model,
125
  ) -> tuple[np.ndarray, np.ndarray]:
126
- features = predict(image, model)
127
- distances = ((deepdanbooru_predictions - features)**2).sum(axis=1)
128
 
129
  image_indices = np.argsort(distances)
130
 
@@ -134,69 +102,42 @@ def run(
134
  for index in range(nrows * ncols):
135
  image_index = image_indices[index]
136
  seeds.append(image_index)
137
- member = tar_file.getmember(f'{dirname}/{image_index:07d}.jpg')
138
- with tar_file.extractfile(member) as f:
139
  data = io.BytesIO(f.read())
140
  image = PIL.Image.open(data)
141
  image = np.asarray(image)
142
  images.append(image)
143
- res = np.asarray(images).reshape(nrows, ncols, image_size, image_size,
144
- 3).transpose(0, 2, 1, 3, 4).reshape(
145
- nrows * image_size,
146
- ncols * image_size, 3)
 
 
147
  seeds = np.asarray(seeds).reshape(nrows, ncols)
148
 
149
  return res, seeds
150
 
151
 
152
- def main():
153
- args = parse_args()
154
-
155
- image_size = 128
156
- dirname = '0-99999'
157
- tarball_path = download_image_tarball(image_size, dirname)
158
- deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)
159
-
160
- model = create_model()
161
-
162
- image_paths = load_sample_image_paths()
163
- examples = [[path.as_posix(), 2, 5] for path in image_paths]
164
-
165
- func = functools.partial(
166
- run,
167
- image_size=image_size,
168
- dirname=dirname,
169
- tarball_path=tarball_path,
170
- deepdanbooru_predictions=deepdanbooru_predictions,
171
- model=model,
172
- )
173
- func = functools.update_wrapper(func, run)
174
-
175
- gr.Interface(
176
- func,
177
- [
178
- gr.inputs.Image(type='pil', label='Input'),
179
- gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'),
180
- gr.inputs.Slider(
181
- 1, 10, step=1, default=5, label='Number of Columns'),
182
- ],
183
- [
184
- gr.outputs.Image(type='numpy', label='Output'),
185
- gr.outputs.Dataframe(type='numpy', label='Seed'),
186
- ],
187
- examples=examples,
188
- title=TITLE,
189
- description=DESCRIPTION,
190
- article=ARTICLE,
191
- theme=args.theme,
192
- allow_flagging=args.allow_flagging,
193
- live=args.live,
194
- ).launch(
195
- enable_queue=args.enable_queue,
196
- server_port=args.port,
197
- share=args.share,
198
- )
199
-
200
-
201
- if __name__ == '__main__':
202
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import io
 
6
  import pathlib
7
  import tarfile
8
 
 
14
  import tensorflow as tf
15
  from huggingface_hub import hf_hub_download
16
 
17
+ TITLE = "TADNE Image Search with DeepDanbooru"
18
+ DESCRIPTION = """The original TADNE site is https://thisanimedoesnotexist.ai/.
19
 
20
  This app shows images similar to the query image from images generated
21
  by the TADNE model with seed 0-99999.
 
35
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
36
  - [TADNE Interpolation](https://huggingface.co/spaces/hysts/TADNE-interpolation)
37
  - [DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
38
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
42
  path = hf_hub_download(
43
+ "hysts/TADNE-sample-images",
44
+ f"prediction_results/deepdanbooru/intermediate_features/{dirname}.npy",
45
+ repo_type="dataset",
46
+ )
47
  return np.load(path)
48
 
49
 
50
  def load_sample_image_paths() -> list[pathlib.Path]:
51
+ image_dir = pathlib.Path("images")
52
  if not image_dir.exists():
53
+ dataset_repo = "hysts/sample-images-TADNE"
54
+ path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
 
 
 
55
  with tarfile.open(path) as f:
56
  f.extractall()
57
+ return sorted(image_dir.glob("*"))
58
 
59
 
60
  def create_model() -> tf.keras.Model:
61
+ path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
 
 
62
  model = tf.keras.models.load_model(path)
63
  model = tf.keras.Model(model.input, model.layers[-4].output)
64
  layer = tf.keras.layers.GlobalAveragePooling2D()
 
66
  return model
67
 
68
 
69
+ image_size = 128
70
+ dirname = "0-99999"
71
+ tarball_path = hf_hub_download("hysts/TADNE-sample-images", f"{image_size}/{dirname}.tar", repo_type="dataset")
72
+ deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)
73
+
74
+ model = create_model()
75
+
76
+
77
+ def predict(image: PIL.Image.Image) -> np.ndarray:
78
  _, height, width, _ = model.input_shape
79
  image = np.asarray(image)
80
+ image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
 
 
 
81
  image = image.numpy()
82
  image = dd.image.transform_and_pad_image(image, width, height)
83
+ image = image / 255.0
84
  features = model.predict(image[None, ...])[0]
85
  features = features.astype(float)
86
  return features
 
90
  image: PIL.Image.Image,
91
  nrows: int,
92
  ncols: int,
 
 
 
 
 
93
  ) -> tuple[np.ndarray, np.ndarray]:
94
+ features = predict(image)
95
+ distances = ((deepdanbooru_predictions - features) ** 2).sum(axis=1)
96
 
97
  image_indices = np.argsort(distances)
98
 
 
102
  for index in range(nrows * ncols):
103
  image_index = image_indices[index]
104
  seeds.append(image_index)
105
+ member = tar_file.getmember(f"{dirname}/{image_index:07d}.jpg")
106
+ with tar_file.extractfile(member) as f: # type: ignore
107
  data = io.BytesIO(f.read())
108
  image = PIL.Image.open(data)
109
  image = np.asarray(image)
110
  images.append(image)
111
+ res = (
112
+ np.asarray(images)
113
+ .reshape(nrows, ncols, image_size, image_size, 3)
114
+ .transpose(0, 2, 1, 3, 4)
115
+ .reshape(nrows * image_size, ncols * image_size, 3)
116
+ )
117
  seeds = np.asarray(seeds).reshape(nrows, ncols)
118
 
119
  return res, seeds
120
 
121
 
122
+ image_paths = load_sample_image_paths()
123
+ examples = [[path.as_posix(), 2, 5] for path in image_paths]
124
+
125
+ demo = gr.Interface(
126
+ fn=run,
127
+ inputs=[
128
+ gr.Image(label="Input", type="pil"),
129
+ gr.Slider(label="Number of Rows", minimum=1, maximum=10, step=1, value=2),
130
+ gr.Slider(label="Number of Columns", minimum=1, maximum=10, step=1, value=2),
131
+ ],
132
+ outputs=[
133
+ gr.Image(label="Output"),
134
+ gr.Dataframe(label="Seed"),
135
+ ],
136
+ examples=examples,
137
+ title=TITLE,
138
+ description=DESCRIPTION,
139
+ )
140
+
141
+
142
+ if __name__ == "__main__":
143
+ demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- pillow==9.1.0
2
- tensorflow==2.8.0
3
  git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru
 
 
 
 
 
1
  git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru
2
+ pillow==10.3.0
3
+ tensorflow==2.8.0