mcding committed
Commit ad552d8 · 0 Parent(s)

published version

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +164 -0
  3. README.md +10 -0
  4. app.py +426 -0
  5. icon.jpg +0 -0
  6. image.png +0 -0
  7. kit/__init__.py +121 -0
  8. kit/metrics/__init__.py +27 -0
  9. kit/metrics/aesthetics.py +38 -0
  10. kit/metrics/aesthetics_scorer/__init__.py +4 -0
  11. kit/metrics/aesthetics_scorer/model.py +104 -0
  12. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.config +8 -0
  13. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth +3 -0
  14. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.config +8 -0
  15. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth +3 -0
  16. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.config +8 -0
  17. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth +3 -0
  18. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.config +8 -0
  19. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth +3 -0
  20. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.config +8 -0
  21. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth +3 -0
  22. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.config +8 -0
  23. kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth +3 -0
  24. kit/metrics/clean_fid/__init__.py +3 -0
  25. kit/metrics/clean_fid/clip_features.py +40 -0
  26. kit/metrics/clean_fid/downloads_helper.py +75 -0
  27. kit/metrics/clean_fid/features.py +117 -0
  28. kit/metrics/clean_fid/fid.py +836 -0
  29. kit/metrics/clean_fid/inception_pytorch.py +329 -0
  30. kit/metrics/clean_fid/inception_torchscript.py +59 -0
  31. kit/metrics/clean_fid/leaderboard.py +58 -0
  32. kit/metrics/clean_fid/resize.py +108 -0
  33. kit/metrics/clean_fid/utils.py +75 -0
  34. kit/metrics/clean_fid/wrappers.py +111 -0
  35. kit/metrics/clip.py +32 -0
  36. kit/metrics/distributional.py +104 -0
  37. kit/metrics/image.py +112 -0
  38. kit/metrics/lpips/__init__.py +4 -0
  39. kit/metrics/lpips/lpips.py +338 -0
  40. kit/metrics/lpips/pretrained_networks.py +188 -0
  41. kit/metrics/lpips/trainer.py +314 -0
  42. kit/metrics/lpips/utils.py +137 -0
  43. kit/metrics/lpips/weights/v0.0/alex.pth +3 -0
  44. kit/metrics/lpips/weights/v0.0/squeeze.pth +3 -0
  45. kit/metrics/lpips/weights/v0.0/vgg.pth +3 -0
  46. kit/metrics/lpips/weights/v0.1/alex.pth +3 -0
  47. kit/metrics/lpips/weights/v0.1/squeeze.pth +3 -0
  48. kit/metrics/lpips/weights/v0.1/vgg.pth +3 -0
  49. kit/metrics/perceptual.py +93 -0
  50. kit/metrics/prompt.py +39 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ .env
README.md ADDED
@@ -0,0 +1,10 @@
1
+ ---
2
+ title: Erasing the Invisible
3
+ emoji: 🌊
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: true
10
+ ---
app.py ADDED
@@ -0,0 +1,426 @@
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import json
5
+ import redis
6
+ import plotly.graph_objects as go
7
+ from datetime import datetime
8
+ from PIL import Image
9
+ from kit import compute_performance, compute_quality
10
+ import dotenv
11
+ import pandas as pd
12
+
13
+ dotenv.load_dotenv()
14
+
15
+ CSS = """
16
+ .tabs button{
17
+ font-size: 20px;
18
+ }
19
+ #download_btn {
20
+ height: 91.6px;
21
+ }
22
+ #submit_btn {
23
+ height: 91.6px;
24
+ }
25
+ #original_image {
26
+ display: block;
27
+ margin-left: auto;
28
+ margin-right: auto;
29
+ }
30
+ #uploaded_image {
31
+ display: block;
32
+ margin-left: auto;
33
+ margin-right: auto;
34
+ }
35
+ #leaderboard_plot {
36
+ display: block;
37
+ margin-left: auto;
38
+ margin-right: auto;
39
+ width: 640px; /* Adjust width as needed */
40
+ height: 640px; /* Adjust height as needed */ }
41
+ #leaderboard_table {
42
+ display: block;
43
+ margin-left: auto;
44
+ margin-right: auto;
45
+ }
46
+ """
47
+
48
+ JS = """
49
+ function refresh() {
50
+ const url = new URL(window.location);
51
+
52
+ if (url.searchParams.get('__theme') !== 'dark') {
53
+ url.searchParams.set('__theme', 'dark');
54
+ window.location.href = url.href;
55
+ }
56
+ }
57
+ """
58
+
59
+ QUALITY_POST_FUNC = lambda x: x / 4 * 8
60
+ PERFORMANCE_POST_FUNC = lambda x: abs(x - 0.5) * 2
61
+
62
+
63
+ # Connect to Redis
64
+ redis_client = redis.Redis(
65
+ host=os.getenv("REDIS_HOST"),
66
+ port=os.getenv("REDIS_PORT"),
67
+ username=os.getenv("REDIS_USERNAME"),
68
+ password=os.getenv("REDIS_PASSWORD"),
69
+ decode_responses=True,
70
+ )
71
+
72
+
73
+ def save_to_redis(name, performance, quality):
74
+ submission = {
75
+ "name": name,
76
+ "performance": performance,
77
+ "quality": quality,
78
+ "timestamp": datetime.now().isoformat(),
79
+ }
80
+ redis_client.lpush("submissions", json.dumps(submission))
81
+
82
+
83
+ def get_submissions_from_redis():
84
+ submissions = redis_client.lrange("submissions", 0, -1)
85
+ submissions = [json.loads(submission) for submission in submissions]
86
+ for s in submissions:
87
+ s["quality"] = QUALITY_POST_FUNC(s["quality"])
88
+ s["performance"] = PERFORMANCE_POST_FUNC(s["performance"])
89
+ s["score"] = np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2)
90
+ return submissions
91
+
92
+
93
+ def update_plot(
94
+ submissions,
95
+ current_name=None,
96
+ ):
97
+ names = [sub["name"] for sub in submissions]
98
+ performances = [float(sub["performance"]) for sub in submissions]
99
+ qualities = [float(sub["quality"]) for sub in submissions]
100
+
101
+ # Create scatter plot
102
+ fig = go.Figure()
103
+
104
+ for name, quality, performance in zip(names, qualities, performances):
105
+ if name == current_name:
106
+ marker = dict(symbol="star", size=15, color="orange")
107
+ elif name.startswith("Baseline: "):
108
+ marker = dict(symbol="square", size=8, color="blue")
109
+ else:
110
+ marker = dict(symbol="circle", size=10, color="green")
111
+
112
+ fig.add_trace(
113
+ go.Scatter(
114
+ x=[quality],
115
+ y=[performance],
116
+ mode="markers+text",
117
+ text=[name if not name.startswith("Baseline: ") else ""],
118
+ textposition="top center",
119
+ name=name,
120
+ marker=marker,
121
+ customdata=[
122
+ name if name.startswith("Baseline: ") else f"User: {name}",
123
+ ],
124
+ hovertemplate="<b>%{customdata}</b><br>"
125
+ + "Performance: %{y:.3f}<br>"
126
+ + "Quality: %{x:.3f}<br>"
127
+ + "<extra></extra>",
128
+ )
129
+ )
130
+
131
+ # Add circles
132
+ circle_radii = np.linspace(0, 1, 5)
133
+ for radius in circle_radii:
134
+ theta = np.linspace(0, 2 * np.pi, 100)
135
+ x = radius * np.cos(theta)
136
+ y = radius * np.sin(theta)
137
+ fig.add_trace(
138
+ go.Scatter(
139
+ x=x,
140
+ y=y,
141
+ mode="lines",
142
+ line=dict(color="gray", dash="dash"),
143
+ showlegend=False,
144
+ hovertemplate="Performance: %{x:.3f}<br>"
145
+ + "Quality: %{y:.3f}<br>"
146
+ + "<extra></extra>",
147
+ )
148
+ )
149
+
150
+ # Update layout
151
+ fig.update_layout(
152
+ xaxis_title="Image Quality Degradation",
153
+ yaxis_title="Watermark Detection Performance",
154
+ xaxis=dict(
155
+ range=[0, 1.1], titlefont=dict(size=16) # Adjust this value as needed
156
+ ),
157
+ yaxis=dict(
158
+ range=[0, 1.1], titlefont=dict(size=16) # Adjust this value as needed
159
+ ),
160
+ width=640,
161
+ height=640,
162
+ showlegend=False, # Remove legend
163
+ modebar=dict(remove=["all"]),
164
+ )
165
+ fig.update_xaxes(title_font_size=20)
166
+ fig.update_yaxes(title_font_size=20)
167
+
168
+ return fig
169
+
170
+
171
+ def update_table(
172
+ submissions,
173
+ current_name=None,
174
+ ):
175
+ def tp(timestamp):
176
+ return timestamp.replace("T", " ").split(".")[0]
177
+
178
+ names = [
179
+ (
180
+ sub["name"][len("Baseline: ") :]
181
+ if sub["name"].startswith("Baseline: ")
182
+ else sub["name"]
183
+ )
184
+ for sub in submissions
185
+ ]
186
+ times = [
187
+ (
188
+ ""
189
+ if sub["name"].startswith("Baseline: ")
190
+ else (
191
+ tp(sub["timestamp"]) + " (Current)"
192
+ if sub["name"] == current_name
193
+ else tp(sub["timestamp"])
194
+ )
195
+ )
196
+ for sub in submissions
197
+ ]
198
+ performances = ["%.4f" % (float(sub["performance"])) for sub in submissions]
199
+ qualities = ["%.4f" % (float(sub["quality"])) for sub in submissions]
200
+ scores = ["%.4f" % (float(sub["score"])) for sub in submissions]
201
+ df = pd.DataFrame(
202
+ {
203
+ "Name": names,
204
+ "Submission Time": times,
205
+ "Performance": performances,
206
+ "Quality": qualities,
207
+ "Score": scores,
208
+ }
209
+ ).sort_values(by=["Score"])
210
+ df.insert(0, "Rank #", list(np.arange(len(names)) + 1), True)
211
+
212
+ def highlight_null(s):
213
+ con = s.copy()
214
+ con[:] = None
215
+ if s["Submission Time"] == "":
216
+ con[:] = "background-color: darkgrey"
217
+ return con
218
+
219
+ return df.style.apply(highlight_null, axis=1)
220
+
221
+
222
+ def process_submission(name, image):
223
+ original_image = Image.open("./image.png")
224
+ progress = gr.Progress()
225
+ progress(0, desc="Detecting Watermark")
226
+ performance = compute_performance(image)
227
+ progress(0.4, desc="Evaluating Image Quality")
228
+ quality = compute_quality(image, original_image)
229
+ progress(1.0, desc="Uploading Results")
230
+
231
+ # Save unprocessed values but display processed values
232
+ save_to_redis(name, performance, quality)
233
+ quality = QUALITY_POST_FUNC(quality)
234
+ performance = PERFORMANCE_POST_FUNC(performance)
235
+
236
+ submissions = get_submissions_from_redis()
237
+ leaderboard_table = update_table(submissions, current_name=name)
238
+ leaderboard_plot = update_plot(submissions, current_name=name)
239
+
240
+ # Calculate rank
241
+ distances = [
242
+ np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2)
243
+ for s in submissions
244
+ ]
245
+ rank = sorted(distances).index(np.sqrt(quality**2 + performance**2)) + 1
246
+ gr.Info(f"You ranked {rank} out of {len(submissions)}!")
247
+ return (
248
+ leaderboard_plot,
249
+ leaderboard_table,
250
+ f"{rank} out of {len(submissions)}",
251
+ name,
252
+ f"{performance:.3f}",
253
+ f"{quality:.3f}",
254
+ f"{np.sqrt(quality**2 + performance**2):.3f}",
255
+ )
256
+
257
+
258
+ def upload_and_evaluate(name, image):
259
+ if name == "":
260
+ raise gr.Error("Please enter your name before submitting.")
261
+ if image is None:
262
+ raise gr.Error("Please upload an image before submitting.")
263
+ return process_submission(name, image)
264
+
265
+
266
+ def create_interface():
267
+ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, js=JS) as demo:
268
+ gr.Markdown(
269
+ """
270
+ # Erasing the Invisible (Demo of NeurIPS'24 competition)
271
+
272
+ Welcome to the demo of the NeurIPS'24 competition [Erasing the Invisible: A Stress-Test Challenge for Image Watermarks](https://erasinginvisible.github.io/).
273
+
274
+ You can use this demo to better understand the competition pipeline, or just for fun! 🎮
275
+
276
+ Here, we provide an image embedded with an invisible watermark. You only need to:
277
+
278
+ 1. **Download** the original watermarked image. 🌊
279
+ 2. **Remove** the invisible watermark using your preferred attack. 🧼
280
+ 3. **Upload** your image. We will evaluate and rank your attack. 📊
281
+
282
+ That's it! 🚀
283
+
284
+ *Note: This is just a demo. The watermark used here is not necessarily representative of those used for the competition. To officially participate in the competition, please follow the guidelines [here](https://erasinginvisible.github.io/).*
285
+ """
286
+ )
287
+
288
+ with gr.Tabs(elem_classes=["tabs"]) as tabs:
289
+ with gr.Tab("Original Watermarked Image", id="download"):
290
+ with gr.Column():
291
+ original_image = gr.Image(
292
+ value="./image.png",
293
+ format="png",
294
+ label="Original Watermarked Image",
295
+ show_label=True,
296
+ height=512,
297
+ width=512,
298
+ type="filepath",
299
+ show_download_button=False,
300
+ show_share_button=False,
301
+ show_fullscreen_button=False,
302
+ container=True,
303
+ elem_id="original_image",
304
+ )
305
+ with gr.Row():
306
+ download_btn = gr.DownloadButton(
307
+ "Download Watermarked Image",
308
+ value="./image.png",
309
+ elem_id="download_btn",
310
+ )
311
+ submit_btn = gr.Button(
312
+ "Submit Your Removal", elem_id="submit_btn"
313
+ )
314
+
315
+ with gr.Tab(
316
+ "Submit Watermark Removed Image",
317
+ id="submit",
318
+ elem_classes="gr-tab-header",
319
+ ):
320
+
321
+ with gr.Column():
322
+ uploaded_image = gr.Image(
323
+ label="Your Watermark Removed Image",
324
+ format="png",
325
+ show_label=True,
326
+ height=512,
327
+ width=512,
328
+ sources=["upload"],
329
+ type="pil",
330
+ show_download_button=False,
331
+ show_share_button=False,
332
+ show_fullscreen_button=False,
333
+ container=True,
334
+ placeholder="Upload your watermark removed image",
335
+ elem_id="uploaded_image",
336
+ )
337
+ with gr.Row():
338
+ name_input = gr.Textbox(
339
+ label="Your Name", placeholder="Anonymous"
340
+ )
341
+ upload_btn = gr.Button("Upload and Evaluate")
342
+
343
+ with gr.Tab(
344
+ "Evaluation Results",
345
+ id="plot",
346
+ elem_classes="gr-tab-header",
347
+ ):
348
+ gr.Markdown(
349
+ "The evaluation is based on two metrics, watermark performance ($$A$$) and image quality degradation ($$Q$$).",
350
+ latex_delimiters=[{"left": "$$", "right": "$$", "display": False}],
351
+ )
352
+ gr.Markdown(
353
+ "The lower the watermark performance and less quality degradation, the more effective the attack is. The overall score is $$\sqrt{Q^2+A^2}$$, the smaller the better.",
354
+ latex_delimiters=[{"left": "$$", "right": "$$", "display": False}],
355
+ )
356
+ gr.Markdown(
357
+ """
358
+ <p>
359
+ <span style="display: inline-block; width: 20px;"></span>🟦: Baseline attacks
360
+ <span style="display: inline-block; width: 20px;"></span>🟢: Users' submissions
361
+ <span style="display: inline-block; width: 20px;"></span>⭐: Your current submission
362
+ </p>
363
+ <p><em>Note: The performance and quality metrics differ from those in the competition (as only one image is used here), but they still give you an idea of how effective your attack is.</em></p>
364
+ """
365
+ )
366
+ with gr.Column():
367
+ leaderboard_plot = gr.Plot(
368
+ value=update_plot(get_submissions_from_redis()),
369
+ show_label=False,
370
+ elem_id="leaderboard_plot",
371
+ )
372
+ with gr.Row():
373
+ rank_output = gr.Textbox(label="Your Ranking")
374
+ name_output = gr.Textbox(label="Your Name")
375
+ performance_output = gr.Textbox(label="Watermark Performance")
376
+ quality_output = gr.Textbox(label="Quality Degradation")
377
+ overall_output = gr.Textbox(label="Overall Score")
378
+ with gr.Tab(
379
+ "Leaderboard",
380
+ id="leaderboard",
381
+ elem_classes="gr-tab-header",
382
+ ):
383
+ gr.Markdown("Find your ranking on the leaderboard!")
384
+ gr.Markdown(
385
+ "Gray-shaded rows are baseline results provided by the organziers."
386
+ )
387
+ with gr.Column():
388
+ leaderboard_table = gr.Dataframe(
389
+ value=update_table(get_submissions_from_redis()),
390
+ show_label=False,
391
+ elem_id="leaderboard_table",
392
+ )
393
+ submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
394
+
395
+ upload_btn.click(lambda: gr.Tabs(selected="plot"), None, tabs).then(
396
+ upload_and_evaluate,
397
+ inputs=[name_input, uploaded_image],
398
+ outputs=[
399
+ leaderboard_plot,
400
+ leaderboard_table,
401
+ rank_output,
402
+ name_output,
403
+ performance_output,
404
+ quality_output,
405
+ overall_output,
406
+ ],
407
+ )
408
+
409
+ demo.load(
410
+ lambda: [
411
+ gr.Image(value="./image.png", height=512, width=512),
412
+ gr.Plot(update_plot(get_submissions_from_redis())),
413
+ gr.Dataframe(update_table(get_submissions_from_redis())),
414
+ ],
415
+ outputs=[original_image, leaderboard_plot, leaderboard_table],
416
+ )
417
+
418
+ return demo
419
+
420
+
421
+ # Create the demo object
422
+ demo = create_interface()
423
+
424
+ # Launch the app
425
+ if __name__ == "__main__":
426
+ demo.launch(share=False)
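For readers tracing the scoring logic above, the post-processing in app.py can be restated as a small standalone sketch. It only mirrors formulas already present in the file (the QUALITY_POST_FUNC / PERFORMANCE_POST_FUNC lambdas and the sqrt(Q^2 + A^2) combination); the example input values are hypothetical.

import numpy as np

def overall_score(raw_performance: float, raw_quality: float) -> float:
    """Illustrative mirror of the post-processing applied in app.py."""
    # Raw performance is a bit accuracy in [0, 1]; 0.5 is chance level,
    # so it is mapped to a distance from random guessing in [0, 1].
    performance = abs(raw_performance - 0.5) * 2
    # Raw quality is the normalized degradation value; app.py rescales it by 2 (x / 4 * 8).
    quality = raw_quality / 4 * 8
    # Overall score: Euclidean distance from the origin; smaller is better.
    return float(np.sqrt(quality**2 + performance**2))

# Hypothetical example, not taken from any real submission.
print(overall_score(raw_performance=0.55, raw_quality=0.05))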
icon.jpg ADDED
image.png ADDED
kit/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ import io
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from PIL import Image
6
+ import dotenv
7
+
8
+ dotenv.load_dotenv()
9
+
10
+ GT_MESSAGE = os.environ["GT_MESSAGE"]
11
+
12
+
13
+ QUALITY_COEFFICIENTS = {
14
+ "psnr": -0.0022186489180419534,
15
+ "ssim": -0.11337077856710862,
16
+ "nmi": -0.09878221979274945,
17
+ "lpips": 0.3412626374646173,
18
+ }
19
+
20
+ QUALITY_OFFSETS = {
21
+ "psnr": 43.54757854447622,
22
+ "ssim": 0.984229018845295,
23
+ "nmi": 1.7536553655336136,
24
+ "lpips": 0.014247652621287854,
25
+ }
26
+
27
+
28
+ def compute_performance(image):
29
+ session_options = ort.SessionOptions()
30
+ session_options.intra_op_num_threads = 1
31
+ session_options.inter_op_num_threads = 1
32
+ session_options.log_severity_level = 3
33
+ model = ort.InferenceSession(
34
+ "./kit/models/stable_signature.onnx",
35
+ sess_options=session_options,
36
+ )
37
+ inputs = np.stack(
38
+ [
39
+ (
40
+ (
41
+ np.array(
42
+ image,
43
+ dtype=np.float32,
44
+ )
45
+ / 255.0
46
+ - [0.485, 0.456, 0.406]
47
+ )
48
+ / [0.229, 0.224, 0.225]
49
+ )
50
+ .transpose((2, 0, 1))
51
+ .astype(np.float32)
52
+ ],
53
+ axis=0,
54
+ )
55
+
56
+ outputs = model.run(
57
+ None,
58
+ {
59
+ "image": inputs,
60
+ },
61
+ )
62
+ decoded = (outputs[0] > 0).astype(int)[0]
63
+ gt_message = np.array([int(bit) for bit in GT_MESSAGE])
64
+ return 1 - np.mean(gt_message != decoded)
65
+
66
+
67
+ from .metrics import (
68
+ compute_image_distance_repeated,
69
+ load_perceptual_models,
70
+ compute_perceptual_metric_repeated,
71
+ load_aesthetics_and_artifacts_models,
72
+ compute_aesthetics_and_artifacts_scores,
73
+ )
74
+
75
+
76
+ def compute_quality(attacked_image, clean_image, quiet=True):
77
+
78
+ # Compress the image
79
+ buffer = io.BytesIO()
80
+ attacked_image.save(buffer, format="JPEG", quality=95)
81
+ buffer.seek(0)
82
+
83
+ # Update attacked_image with the compressed version
84
+ attacked_image = Image.open(buffer)
85
+
86
+ modes = ["psnr", "ssim", "nmi", "lpips"]
87
+
88
+ results = {}
89
+ for mode in modes:
90
+ if mode in ["psnr", "ssim", "nmi"]:
91
+ metrics = compute_image_distance_repeated(
92
+ [clean_image],
93
+ [attacked_image],
94
+ metric_name=mode,
95
+ num_workers=1,
96
+ verbose=not quiet,
97
+ )
98
+ results[mode] = metrics
99
+
100
+ elif mode == "lpips":
101
+ model = load_perceptual_models(
102
+ mode,
103
+ mode="alex",
104
+ device="cpu",
105
+ )
106
+ metrics = compute_perceptual_metric_repeated(
107
+ [clean_image],
108
+ [attacked_image],
109
+ metric_name=mode,
110
+ mode="alex",
111
+ model=model,
112
+ device="cpu",
113
+ )
114
+ results[mode] = metrics
115
+
116
+ normalized_quality = 0
117
+ for key, value in results.items():
118
+ normalized_quality += (value[0] - QUALITY_OFFSETS[key]) * QUALITY_COEFFICIENTS[
119
+ key
120
+ ]
121
+ return normalized_quality
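A minimal usage sketch of the two entry points defined above, assuming the layout this module expects (the ONNX detector at ./kit/models/stable_signature.onnx and a GT_MESSAGE environment variable) and two local images; the file names are hypothetical, and the attacked image must have the resolution the bundled detector expects.

from PIL import Image
from kit import compute_performance, compute_quality

original = Image.open("image.png").convert("RGB")              # watermarked original
attacked = Image.open("my_attacked_image.png").convert("RGB")  # hypothetical attack output

performance = compute_performance(attacked)    # bit accuracy against GT_MESSAGE
quality = compute_quality(attacked, original)  # weighted PSNR/SSIM/NMI/LPIPS degradation
print(f"performance={performance:.4f}, quality={quality:.4f}")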
kit/metrics/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ from .distributional import compute_fid
2
+ from .image import (
3
+ compute_mse,
4
+ compute_psnr,
5
+ compute_ssim,
6
+ compute_nmi,
7
+ compute_mse_repeated,
8
+ compute_psnr_repeated,
9
+ compute_ssim_repeated,
10
+ compute_nmi_repeated,
11
+ compute_image_distance_repeated,
12
+ )
13
+ from .perceptual import (
14
+ load_perceptual_models,
15
+ compute_lpips,
16
+ compute_lpips_repeated,
17
+ compute_perceptual_metric_repeated,
18
+ )
19
+ from .aesthetics import (
20
+ load_aesthetics_and_artifacts_models,
21
+ compute_aesthetics_and_artifacts_scores,
22
+ )
23
+ from .clip import load_open_clip_model_preprocess_and_tokenizer, compute_clip_score
24
+ from .prompt import (
25
+ load_perplexity_model_and_tokenizer,
26
+ compute_prompt_perplexity,
27
+ )
kit/metrics/aesthetics.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import CLIPModel, CLIPProcessor
4
+ from .aesthetics_scorer import preprocess, load_model
5
+
6
+
7
+ def load_aesthetics_and_artifacts_models(device=torch.device("cuda")):
8
+ model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
9
+ vision_model = model.vision_model
10
+ vision_model.to(device)
11
+ del model
12
+ clip_processor = CLIPProcessor.from_pretrained(
13
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
14
+ )
15
+ rating_model = load_model("aesthetics_scorer_rating_openclip_vit_h_14").to(device)
16
+ artifacts_model = load_model("aesthetics_scorer_artifacts_openclip_vit_h_14").to(
17
+ device
18
+ )
19
+ return vision_model, clip_processor, rating_model, artifacts_model
20
+
21
+
22
+ def compute_aesthetics_and_artifacts_scores(
23
+ images, models, device=torch.device("cuda")
24
+ ):
25
+ vision_model, clip_processor, rating_model, artifacts_model = models
26
+
27
+ inputs = clip_processor(images=images, return_tensors="pt").to(device)
28
+ with torch.no_grad():
29
+ vision_output = vision_model(**inputs)
30
+ pooled_output = vision_output.pooler_output
31
+ embedding = preprocess(pooled_output)
32
+ with torch.no_grad():
33
+ rating = rating_model(embedding)
34
+ artifact = artifacts_model(embedding)
35
+ return (
36
+ rating.detach().cpu().numpy().flatten().tolist(),
37
+ artifact.detach().cpu().numpy().flatten().tolist(),
38
+ )
kit/metrics/aesthetics_scorer/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ From https://github.com/kenjiqq/aesthetics-scorer#validation-split-of-diffusiondb-dataset
3
+ """
4
+ from .model import preprocess, load_model
kit/metrics/aesthetics_scorer/model.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import json
4
+ import os
5
+ import inspect
6
+
7
+
8
+ class AestheticScorer(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_size=0,
12
+ use_activation=False,
13
+ dropout=0.2,
14
+ config=None,
15
+ hidden_dim=1024,
16
+ reduce_dims=False,
17
+ output_activation=None,
18
+ ):
19
+ super().__init__()
20
+ self.config = {
21
+ "input_size": input_size,
22
+ "use_activation": use_activation,
23
+ "dropout": dropout,
24
+ "hidden_dim": hidden_dim,
25
+ "reduce_dims": reduce_dims,
26
+ "output_activation": output_activation,
27
+ }
28
+ if config != None:
29
+ self.config.update(config)
30
+
31
+ layers = [
32
+ nn.Linear(self.config["input_size"], self.config["hidden_dim"]),
33
+ nn.ReLU() if self.config["use_activation"] else None,
34
+ nn.Dropout(self.config["dropout"]),
35
+ nn.Linear(
36
+ self.config["hidden_dim"],
37
+ round(self.config["hidden_dim"] / (2 if reduce_dims else 1)),
38
+ ),
39
+ nn.ReLU() if self.config["use_activation"] else None,
40
+ nn.Dropout(self.config["dropout"]),
41
+ nn.Linear(
42
+ round(self.config["hidden_dim"] / (2 if reduce_dims else 1)),
43
+ round(self.config["hidden_dim"] / (4 if reduce_dims else 1)),
44
+ ),
45
+ nn.ReLU() if self.config["use_activation"] else None,
46
+ nn.Dropout(self.config["dropout"]),
47
+ nn.Linear(
48
+ round(self.config["hidden_dim"] / (4 if reduce_dims else 1)),
49
+ round(self.config["hidden_dim"] / (8 if reduce_dims else 1)),
50
+ ),
51
+ nn.ReLU() if self.config["use_activation"] else None,
52
+ nn.Linear(round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 1),
53
+ ]
54
+ if self.config["output_activation"] == "sigmoid":
55
+ layers.append(nn.Sigmoid())
56
+ layers = [x for x in layers if x is not None]
57
+ self.layers = nn.Sequential(*layers)
58
+
59
+ def forward(self, x):
60
+ if self.config["output_activation"] == "sigmoid":
61
+ upper, lower = 10, 1
62
+ scale = upper - lower
63
+ return (self.layers(x) * scale) + lower
64
+ else:
65
+ return self.layers(x)
66
+
67
+ def save(self, save_name):
68
+ split_name = os.path.splitext(save_name)
69
+ with open(f"{split_name[0]}.config", "w") as outfile:
70
+ outfile.write(json.dumps(self.config, indent=4))
71
+
72
+ for i in range(
73
+ 6
74
+ ): # saving sometiles fails, so retry 5 times, might be windows issue
75
+ try:
76
+ torch.save(self.state_dict(), save_name)
77
+ break
78
+ except RuntimeError as e:
79
+ # check if error contains string "File"
80
+ if "cannot be opened" in str(e) and i < 5:
81
+ print("Model save failed, retrying...")
82
+ else:
83
+ raise e
84
+
85
+
86
+ def preprocess(embeddings):
87
+ return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
88
+
89
+
90
+ def load_model(weight_name, device="cuda" if torch.cuda.is_available() else "cpu"):
91
+ weight_folder = os.path.abspath(
92
+ os.path.join(
93
+ inspect.getfile(load_model),
94
+ "../weights",
95
+ )
96
+ )
97
+ weight_path = os.path.join(weight_folder, f"{weight_name}.pth")
98
+ config_path = os.path.join(weight_folder, f"{weight_name}.config")
99
+ with open(config_path, "r") as config_file:
100
+ config = json.load(config_file)
101
+ model = AestheticScorer(config=config)
102
+ model.load_state_dict(torch.load(weight_path, map_location=device))
103
+ model.eval()
104
+ return model
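A brief sketch of how load_model and preprocess fit together, assuming the bundled weight files are present; the random tensor merely stands in for a real CLIP ViT-H/14 pooled embedding (dimension 1280 per the matching .config file).

import torch
from kit.metrics.aesthetics_scorer import load_model, preprocess

# Hypothetical stand-in for a pooled CLIP ViT-H/14 embedding (input_size = 1280).
embedding = preprocess(torch.randn(1, 1280))
rating_model = load_model("aesthetics_scorer_rating_openclip_vit_h_14", device="cpu")
with torch.no_grad():
    rating = rating_model(embedding)
print(rating.item())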
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1664,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39a5d014670226d52c408e0dfec840b7626d80a73d003a6a144caafd5e02d031
3
+ size 19423219
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1280,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc48a8a2315cfdbc7bb8278be55f645e8a995e1a2fa234baec5eb41c4d33e070
3
+ size 17850319
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1024,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4a9481fdbce5ff02b252bcb25109b9f3b29841289fadf7e79e884d59f9357d5
3
+ size 16801743
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1664,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19b016304f54ae866e27f1eb498c0861f704958e7c37693adc5ce094e63904a8
3
+ size 19423099
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1280,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03603eee1864c2e5e97ef7079229609653db5b10594ca8b1de9e541d838cae9c
3
+ size 17850199
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.config ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "input_size": 1024,
3
+ "use_activation": false,
4
+ "dropout": 0.0,
5
+ "hidden_dim": 1024,
6
+ "reduce_dims": false,
7
+ "output_activation": null
8
+ }
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb7fe561369ab6c7dad34b9316a56d2c6070582f0323656148e1107a242cd666
3
+ size 16801623
kit/metrics/clean_fid/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ From https://github.com/GaParmar/clean-fid/tree/main
3
+ """
kit/metrics/clean_fid/clip_features.py ADDED
@@ -0,0 +1,40 @@
1
+ # pip install git+https://github.com/openai/CLIP.git
2
+ import pdb
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ import clip
8
+ from .fid import compute_fid
9
+
10
+
11
+ def img_preprocess_clip(img_np):
12
+ x = Image.fromarray(img_np.astype(np.uint8)).convert("RGB")
13
+ T = transforms.Compose(
14
+ [
15
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
16
+ transforms.CenterCrop(224),
17
+ ]
18
+ )
19
+ return np.asarray(T(x)).clip(0, 255).astype(np.uint8)
20
+
21
+
22
+ class CLIP_fx:
23
+ def __init__(self, name="ViT-B/32", device="cuda"):
24
+ self.model, _ = clip.load(name, device=device)
25
+ self.model.eval()
26
+ self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_")
27
+
28
+ def __call__(self, img_t):
29
+ img_x = img_t / 255.0
30
+ T_norm = transforms.Normalize(
31
+ (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
32
+ )
33
+ img_x = T_norm(img_x)
34
+ assert torch.is_tensor(img_x)
35
+ if len(img_x.shape) == 3:
36
+ img_x = img_x.unsqueeze(0)
37
+ B, C, H, W = img_x.shape
38
+ with torch.no_grad():
39
+ z = self.model.encode_image(img_x)
40
+ return z
kit/metrics/clean_fid/downloads_helper.py ADDED
@@ -0,0 +1,75 @@
1
+ import os
2
+ import urllib.request
3
+ import requests
4
+ import shutil
5
+
6
+
7
+ inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt"
8
+
9
+
10
+ """
11
+ Download the pretrained Inception weights if they do not exist
12
+ ARGS:
13
+ fpath - output folder path
14
+ """
15
+
16
+
17
+ def check_download_inception(fpath="./"):
18
+ inception_path = os.path.join(fpath, "inception-2015-12-05.pt")
19
+ if not os.path.exists(inception_path):
20
+ # download the file
21
+ with urllib.request.urlopen(inception_url) as response, open(
22
+ inception_path, "wb"
23
+ ) as f:
24
+ shutil.copyfileobj(response, f)
25
+ return inception_path
26
+
27
+
28
+ """
29
+ Download a URL if the file does not already exist locally
30
+ ARGS:
31
+ local_folder - output folder path
32
+ url - the weburl to download
33
+ """
34
+
35
+
36
+ def check_download_url(local_folder, url):
37
+ name = os.path.basename(url)
38
+ local_path = os.path.join(local_folder, name)
39
+ if not os.path.exists(local_path):
40
+ os.makedirs(local_folder, exist_ok=True)
41
+ print(f"downloading statistics to {local_path}")
42
+ with urllib.request.urlopen(url) as response, open(local_path, "wb") as f:
43
+ shutil.copyfileobj(response, f)
44
+ return local_path
45
+
46
+
47
+ """
48
+ Download a file from google drive
49
+ ARGS:
50
+ file_id - id of the google drive file
51
+ out_path - output folder path
52
+ """
53
+
54
+
55
+ def download_google_drive(file_id, out_path):
56
+ def get_confirm_token(response):
57
+ for key, value in response.cookies.items():
58
+ if key.startswith("download_warning"):
59
+ return value
60
+ return None
61
+
62
+ URL = "https://drive.google.com/uc?export=download"
63
+ session = requests.Session()
64
+ response = session.get(URL, params={"id": file_id}, stream=True)
65
+ token = get_confirm_token(response)
66
+
67
+ if token:
68
+ params = {"id": file_id, "confirm": token}
69
+ response = session.get(URL, params=params, stream=True)
70
+
71
+ CHUNK_SIZE = 32768
72
+ with open(out_path, "wb") as f:
73
+ for chunk in response.iter_content(CHUNK_SIZE):
74
+ if chunk:
75
+ f.write(chunk)
kit/metrics/clean_fid/features.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ helpers for extracting features from image
3
+ """
4
+ import os
5
+ import platform
6
+ import numpy as np
7
+ import torch
8
+ from torch.hub import get_dir
9
+ from .downloads_helper import check_download_url
10
+ from .inception_pytorch import InceptionV3
11
+ from .inception_torchscript import InceptionV3W
12
+
13
+
14
+ """
15
+ returns a function that takes an image in range [0,255]
16
+ and outputs a feature embedding vector
17
+ """
18
+
19
+
20
+ def feature_extractor(
21
+ name="torchscript_inception",
22
+ device=torch.device("cuda"),
23
+ resize_inside=False,
24
+ use_dataparallel=True,
25
+ ):
26
+ if name == "torchscript_inception":
27
+ path = "./" if platform.system() == "Windows" else "/tmp"
28
+ model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(
29
+ device
30
+ )
31
+ model.eval()
32
+ if use_dataparallel:
33
+ model = torch.nn.DataParallel(model)
34
+
35
+ def model_fn(x):
36
+ return model(x)
37
+
38
+ elif name == "pytorch_inception":
39
+ model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
40
+ model.eval()
41
+ if use_dataparallel:
42
+ model = torch.nn.DataParallel(model)
43
+
44
+ def model_fn(x):
45
+ return model(x / 255)[0].squeeze(-1).squeeze(-1)
46
+
47
+ else:
48
+ raise ValueError(f"{name} feature extractor not implemented")
49
+ return model_fn
50
+
51
+
52
+ """
53
+ Build a feature extractor for each of the modes
54
+ """
55
+
56
+
57
+ def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
58
+ if mode == "legacy_pytorch":
59
+ feat_model = feature_extractor(
60
+ name="pytorch_inception",
61
+ resize_inside=False,
62
+ device=device,
63
+ use_dataparallel=use_dataparallel,
64
+ )
65
+ elif mode == "legacy_tensorflow":
66
+ feat_model = feature_extractor(
67
+ name="torchscript_inception",
68
+ resize_inside=True,
69
+ device=device,
70
+ use_dataparallel=use_dataparallel,
71
+ )
72
+ elif mode == "clean":
73
+ feat_model = feature_extractor(
74
+ name="torchscript_inception",
75
+ resize_inside=False,
76
+ device=device,
77
+ use_dataparallel=use_dataparallel,
78
+ )
79
+ return feat_model
80
+
81
+
82
+ """
83
+ Load precomputed reference statistics for commonly used datasets
84
+ """
85
+
86
+
87
+ def get_reference_statistics(
88
+ name,
89
+ res,
90
+ mode="clean",
91
+ model_name="inception_v3",
92
+ seed=0,
93
+ split="test",
94
+ metric="FID",
95
+ ):
96
+ base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
97
+ if split == "custom":
98
+ res = "na"
99
+ if model_name == "inception_v3":
100
+ model_modifier = ""
101
+ else:
102
+ model_modifier = "_" + model_name
103
+ if metric == "FID":
104
+ rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz")
105
+ url = f"{base_url}/{rel_path}"
106
+ stats_folder = os.path.join(get_dir(), "fid_stats")
107
+ fpath = check_download_url(local_folder=stats_folder, url=url)
108
+ stats = np.load(fpath)
109
+ mu, sigma = stats["mu"], stats["sigma"]
110
+ return mu, sigma
111
+ elif metric == "KID":
112
+ rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz")
113
+ url = f"{base_url}/{rel_path}"
114
+ stats_folder = os.path.join(get_dir(), "fid_stats")
115
+ fpath = check_download_url(local_folder=stats_folder, url=url)
116
+ stats = np.load(fpath)
117
+ return stats["feats"]
kit/metrics/clean_fid/fid.py ADDED
@@ -0,0 +1,836 @@
1
+ import os
2
+ import random
3
+ from tqdm.auto import tqdm
4
+ from glob import glob
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from scipy import linalg
9
+ import zipfile
10
+ from torch.hub import get_dir
11
+ from .utils import *
12
+ from .features import build_feature_extractor, get_reference_statistics
13
+ from .resize import *
14
+
15
+
16
+ """
17
+ Numpy implementation of the Frechet Distance.
18
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
19
+ and X_2 ~ N(mu_2, C_2) is
20
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
21
+ Stable version by Danica J. Sutherland.
22
+ Params:
23
+ mu1 : Numpy array containing the activations of a layer of the
24
+ inception net (like returned by the function 'get_predictions')
25
+ for generated samples.
26
+ mu2 : The sample mean over activations, precalculated on an
27
+ representative data set.
28
+ sigma1: The covariance matrix over activations for generated samples.
29
+ sigma2: The covariance matrix over activations, precalculated on a
30
+ representative data set.
31
+ """
32
+
33
+
34
+ def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
35
+ mu1 = np.atleast_1d(mu1)
36
+ mu2 = np.atleast_1d(mu2)
37
+ sigma1 = np.atleast_2d(sigma1)
38
+ sigma2 = np.atleast_2d(sigma2)
39
+
40
+ assert (
41
+ mu1.shape == mu2.shape
42
+ ), "Training and test mean vectors have different lengths"
43
+ assert (
44
+ sigma1.shape == sigma2.shape
45
+ ), "Training and test covariances have different dimensions"
46
+
47
+ diff = mu1 - mu2
48
+
49
+ # Product might be almost singular
50
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
51
+ if not np.isfinite(covmean).all():
52
+ msg = (
53
+ "fid calculation produces singular product; "
54
+ "adding %s to diagonal of cov estimates"
55
+ ) % eps
56
+ print(msg)
57
+ offset = np.eye(sigma1.shape[0]) * eps
58
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
59
+
60
+ # Numerical error might give slight imaginary component
61
+ if np.iscomplexobj(covmean):
62
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
63
+ m = np.max(np.abs(covmean.imag))
64
+ raise ValueError("Imaginary component {}".format(m))
65
+ covmean = covmean.real
66
+
67
+ tr_covmean = np.trace(covmean)
68
+
69
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
70
+
71
+
72
+ """
73
+ Compute the KID score given the sets of features
74
+ """
75
+
76
+
77
+ def kernel_distance(feats1, feats2, num_subsets=100, max_subset_size=1000):
78
+ n = feats1.shape[1]
79
+ m = min(min(feats1.shape[0], feats2.shape[0]), max_subset_size)
80
+ t = 0
81
+ for _subset_idx in range(num_subsets):
82
+ x = feats2[np.random.choice(feats2.shape[0], m, replace=False)]
83
+ y = feats1[np.random.choice(feats1.shape[0], m, replace=False)]
84
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
85
+ b = (x @ y.T / n + 1) ** 3
86
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
87
+ kid = t / num_subsets / m
88
+ return float(kid)
89
+
90
+
91
+ """
92
+ Compute the inception features for a batch of images
93
+ """
94
+
95
+
96
+ def get_batch_features(batch, model, device):
97
+ with torch.no_grad():
98
+ feat = model(batch.to(device))
99
+ return feat.detach().cpu().numpy()
100
+
101
+
102
+ """
103
+ Compute the inception features for a list of files
104
+ """
105
+
106
+
107
+ def get_files_features(
108
+ l_files,
109
+ model=None,
110
+ num_workers=12,
111
+ batch_size=128,
112
+ device=torch.device("cuda"),
113
+ mode="clean",
114
+ custom_fn_resize=None,
115
+ description="",
116
+ fdir=None,
117
+ verbose=True,
118
+ custom_image_tranform=None,
119
+ ):
120
+ # wrap the images in a dataloader for parallelizing the resize operation
121
+ dataset = ResizeDataset(l_files, fdir=fdir, mode=mode)
122
+ if custom_image_tranform is not None:
123
+ dataset.custom_image_tranform = custom_image_tranform
124
+ if custom_fn_resize is not None:
125
+ dataset.fn_resize = custom_fn_resize
126
+
127
+ dataloader = torch.utils.data.DataLoader(
128
+ dataset,
129
+ batch_size=batch_size,
130
+ shuffle=False,
131
+ drop_last=False,
132
+ num_workers=num_workers,
133
+ )
134
+
135
+ # collect all inception features
136
+ l_feats = []
137
+ if verbose:
138
+ pbar = tqdm(dataloader, desc=description)
139
+ else:
140
+ pbar = dataloader
141
+
142
+ for batch in pbar:
143
+ l_feats.append(get_batch_features(batch, model, device))
144
+ np_feats = np.concatenate(l_feats)
145
+ return np_feats
146
+
147
+
148
+ """
149
+ Compute the inception features for a folder of image files
150
+ """
151
+
152
+
153
+ def get_folder_features(
154
+ fdir,
155
+ model=None,
156
+ num_workers=12,
157
+ num=None,
158
+ shuffle=False,
159
+ seed=0,
160
+ batch_size=128,
161
+ device=torch.device("cuda"),
162
+ mode="clean",
163
+ custom_fn_resize=None,
164
+ description="",
165
+ verbose=True,
166
+ custom_image_tranform=None,
167
+ ):
168
+ # get all relevant files in the dataset
169
+ if ".zip" in fdir:
170
+ files = list(set(zipfile.ZipFile(fdir).namelist()))
171
+ # remove the non-image files inside the zip
172
+ files = [x for x in files if os.path.splitext(x)[1].lower()[1:] in EXTENSIONS]
173
+ else:
174
+ files = sorted(
175
+ [
176
+ file
177
+ for ext in EXTENSIONS
178
+ for file in glob(os.path.join(fdir, f"**/*.{ext}"), recursive=True)
179
+ ]
180
+ )
181
+ # use a subset number of files if needed
182
+ if num is not None:
183
+ if shuffle:
184
+ random.seed(seed)
185
+ random.shuffle(files)
186
+ files = files[:num]
187
+ np_feats = get_files_features(
188
+ files,
189
+ model,
190
+ num_workers=num_workers,
191
+ batch_size=batch_size,
192
+ device=device,
193
+ mode=mode,
194
+ custom_fn_resize=custom_fn_resize,
195
+ custom_image_tranform=custom_image_tranform,
196
+ description=description,
197
+ fdir=fdir,
198
+ verbose=verbose,
199
+ )
200
+ return np_feats
201
+
202
+
203
+ """
204
+ Compute the FID score given the inception features stack
205
+ """
206
+
207
+
208
+ def fid_from_feats(feats1, feats2):
209
+ mu1, sig1 = np.mean(feats1, axis=0), np.cov(feats1, rowvar=False)
210
+ mu2, sig2 = np.mean(feats2, axis=0), np.cov(feats2, rowvar=False)
211
+ return frechet_distance(mu1, sig1, mu2, sig2)
212
+
213
+
214
+ """
215
+ Computes the FID score for a folder of images for a specific dataset
216
+ and a specific resolution
217
+ """
218
+
219
+
220
+ def fid_folder(
221
+ fdir,
222
+ dataset_name,
223
+ dataset_res,
224
+ dataset_split,
225
+ model=None,
226
+ mode="clean",
227
+ model_name="inception_v3",
228
+ num_workers=12,
229
+ batch_size=128,
230
+ device=torch.device("cuda"),
231
+ verbose=True,
232
+ custom_image_tranform=None,
233
+ custom_fn_resize=None,
234
+ ):
235
+ # Load reference FID statistics (download if needed)
236
+ ref_mu, ref_sigma = get_reference_statistics(
237
+ dataset_name,
238
+ dataset_res,
239
+ mode=mode,
240
+ model_name=model_name,
241
+ seed=0,
242
+ split=dataset_split,
243
+ )
244
+ fbname = os.path.basename(fdir)
245
+ # get all inception features for folder images
246
+ np_feats = get_folder_features(
247
+ fdir,
248
+ model,
249
+ num_workers=num_workers,
250
+ batch_size=batch_size,
251
+ device=device,
252
+ mode=mode,
253
+ description=f"FID {fbname} : ",
254
+ verbose=verbose,
255
+ custom_image_tranform=custom_image_tranform,
256
+ custom_fn_resize=custom_fn_resize,
257
+ )
258
+ mu = np.mean(np_feats, axis=0)
259
+ sigma = np.cov(np_feats, rowvar=False)
260
+ fid = frechet_distance(mu, sigma, ref_mu, ref_sigma)
261
+ return fid
262
+
263
+
264
+ """
265
+ Compute the FID stats from a generator model
266
+ """
267
+
268
+
269
+ def get_model_features(
270
+ G,
271
+ model,
272
+ mode="clean",
273
+ z_dim=512,
274
+ num_gen=50_000,
275
+ batch_size=128,
276
+ device=torch.device("cuda"),
277
+ desc="FID model: ",
278
+ verbose=True,
279
+ return_z=False,
280
+ custom_image_tranform=None,
281
+ custom_fn_resize=None,
282
+ ):
283
+ if custom_fn_resize is None:
284
+ fn_resize = build_resizer(mode)
285
+ else:
286
+ fn_resize = custom_fn_resize
287
+
288
+ # Generate test features
289
+ num_iters = int(np.ceil(num_gen / batch_size))
290
+ l_feats = []
291
+ latents = []
292
+ if verbose:
293
+ pbar = tqdm(range(num_iters), desc=desc)
294
+ else:
295
+ pbar = range(num_iters)
296
+ for idx in pbar:
297
+ with torch.no_grad():
298
+ z_batch = torch.randn((batch_size, z_dim)).to(device)
299
+ if return_z:
300
+ latents.append(z_batch)
301
+ # generated image is in range [0,255]
302
+ img_batch = G(z_batch)
303
+ # split into individual batches for resizing if needed
304
+ if mode != "legacy_tensorflow":
305
+ l_resized_batch = []
306
+ for idx in range(batch_size):
307
+ curr_img = img_batch[idx]
308
+ img_np = curr_img.cpu().numpy().transpose((1, 2, 0))
309
+ if custom_image_tranform is not None:
310
+ img_np = custom_image_tranform(img_np)
311
+ img_resize = fn_resize(img_np)
312
+ l_resized_batch.append(
313
+ torch.tensor(img_resize.transpose((2, 0, 1))).unsqueeze(0)
314
+ )
315
+ resized_batch = torch.cat(l_resized_batch, dim=0)
316
+ else:
317
+ resized_batch = img_batch
318
+ feat = get_batch_features(resized_batch, model, device)
319
+ l_feats.append(feat)
320
+ np_feats = np.concatenate(l_feats)[:num_gen]
321
+ if return_z:
322
+ latents = torch.cat(latents, 0)
323
+ return np_feats, latents
324
+ return np_feats
325
+
326
+
327
+ """
328
+ Computes the FID score for a generator model for a specific dataset
329
+ and a specific resolution
330
+ """
331
+
332
+
333
+ def fid_model(
334
+ G,
335
+ dataset_name,
336
+ dataset_res,
337
+ dataset_split,
338
+ model=None,
339
+ model_name="inception_v3",
340
+ z_dim=512,
341
+ num_gen=50_000,
342
+ mode="clean",
343
+ num_workers=0,
344
+ batch_size=128,
345
+ device=torch.device("cuda"),
346
+ verbose=True,
347
+ custom_image_tranform=None,
348
+ custom_fn_resize=None,
349
+ ):
350
+ # Load reference FID statistics (download if needed)
351
+ ref_mu, ref_sigma = get_reference_statistics(
352
+ dataset_name,
353
+ dataset_res,
354
+ mode=mode,
355
+ model_name=model_name,
356
+ seed=0,
357
+ split=dataset_split,
358
+ )
359
+ # Generate features of images generated by the model
360
+ np_feats = get_model_features(
361
+ G,
362
+ model,
363
+ mode=mode,
364
+ z_dim=z_dim,
365
+ num_gen=num_gen,
366
+ batch_size=batch_size,
367
+ device=device,
368
+ verbose=verbose,
369
+ custom_image_tranform=custom_image_tranform,
370
+ custom_fn_resize=custom_fn_resize,
371
+ )
372
+ mu = np.mean(np_feats, axis=0)
373
+ sigma = np.cov(np_feats, rowvar=False)
374
+ fid = frechet_distance(mu, sigma, ref_mu, ref_sigma)
375
+ return fid
376
+
377
+
378
+ """
379
+ Computes the FID score between the two given folders
380
+ """
381
+
382
+
383
+ def compare_folders(
384
+ fdir1,
385
+ fdir2,
386
+ feat_model,
387
+ mode,
388
+ num_workers=0,
389
+ batch_size=8,
390
+ device=torch.device("cuda"),
391
+ verbose=True,
392
+ custom_image_tranform=None,
393
+ custom_fn_resize=None,
394
+ ):
395
+ # get all inception features for the first folder
396
+ fbname1 = os.path.basename(fdir1)
397
+ np_feats1 = get_folder_features(
398
+ fdir1,
399
+ feat_model,
400
+ num_workers=num_workers,
401
+ batch_size=batch_size,
402
+ device=device,
403
+ mode=mode,
404
+ description=f"FID {fbname1} : ",
405
+ verbose=verbose,
406
+ custom_image_tranform=custom_image_tranform,
407
+ custom_fn_resize=custom_fn_resize,
408
+ )
409
+ mu1 = np.mean(np_feats1, axis=0)
410
+ sigma1 = np.cov(np_feats1, rowvar=False)
411
+ # get all inception features for the second folder
412
+ fbname2 = os.path.basename(fdir2)
413
+ np_feats2 = get_folder_features(
414
+ fdir2,
415
+ feat_model,
416
+ num_workers=num_workers,
417
+ batch_size=batch_size,
418
+ device=device,
419
+ mode=mode,
420
+ description=f"FID {fbname2} : ",
421
+ verbose=verbose,
422
+ custom_image_tranform=custom_image_tranform,
423
+ custom_fn_resize=custom_fn_resize,
424
+ )
425
+ mu2 = np.mean(np_feats2, axis=0)
426
+ sigma2 = np.cov(np_feats2, rowvar=False)
427
+ fid = frechet_distance(mu1, sigma1, mu2, sigma2)
428
+ return fid
429
+
430
+
431
+ """
432
+ Test if a custom statistic exists
433
+ """
434
+
435
+
436
+ def test_stats_exists(name, mode, model_name="inception_v3", metric="FID"):
437
+ stats_folder = os.path.join(get_dir(), "fid_stats")
438
+ split, res = "custom", "na"
439
+ if model_name == "inception_v3":
440
+ model_modifier = ""
441
+ else:
442
+ model_modifier = "_" + model_name
443
+ if metric == "FID":
444
+ fname = f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
445
+ elif metric == "KID":
446
+ fname = f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
447
+ fpath = os.path.join(stats_folder, fname)
448
+ return os.path.exists(fpath)
449
+
450
+
451
+ """
452
+ Remove the custom FID features from the stats folder
453
+ """
454
+
455
+
456
+ def remove_custom_stats(name, mode="clean", model_name="inception_v3"):
457
+ stats_folder = os.path.join(get_dir(), "fid_stats")
458
+ # remove the FID stats
459
+ split, res = "custom", "na"
460
+ if model_name == "inception_v3":
461
+ model_modifier = ""
462
+ else:
463
+ model_modifier = "_" + model_name
464
+ outf = os.path.join(
465
+ stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
466
+ )
467
+ if not os.path.exists(outf):
468
+ msg = f"The stats file {name} does not exist."
469
+ raise Exception(msg)
470
+ os.remove(outf)
471
+ # remove the KID stats
472
+ outf = os.path.join(
473
+ stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
474
+ )
475
+ if not os.path.exists(outf):
476
+ msg = f"The stats file {name} does not exist."
477
+ raise Exception(msg)
478
+ os.remove(outf)
479
+
480
+
481
+ """
482
+ Cache a custom dataset statistics file
483
+ """
484
+
485
+
486
+ def make_custom_stats(
487
+ name,
488
+ fdir,
489
+ num=None,
490
+ mode="clean",
491
+ model_name="inception_v3",
492
+ num_workers=0,
493
+ batch_size=64,
494
+ device=torch.device("cuda"),
495
+ verbose=True,
496
+ ):
497
+ stats_folder = os.path.join(get_dir(), "fid_stats")
498
+ os.makedirs(stats_folder, exist_ok=True)
499
+ split, res = "custom", "na"
500
+ if model_name == "inception_v3":
501
+ model_modifier = ""
502
+ else:
503
+ model_modifier = "_" + model_name
504
+ outf = os.path.join(
505
+ stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
506
+ )
507
+ # if the custom stat file already exists
508
+ if os.path.exists(outf):
509
+ msg = f"The statistics file {name} already exists. "
510
+ msg += "Use remove_custom_stats function to delete it first."
511
+ raise Exception(msg)
512
+ if model_name == "inception_v3":
513
+ feat_model = build_feature_extractor(mode, device)
514
+ custom_fn_resize = None
515
+ custom_image_tranform = None
516
+ elif model_name == "clip_vit_b_32":
517
+ from .clip_features import CLIP_fx, img_preprocess_clip
518
+
519
+ clip_fx = CLIP_fx("ViT-B/32")
520
+ feat_model = clip_fx
521
+ custom_fn_resize = img_preprocess_clip
522
+ custom_image_tranform = None
523
+ else:
524
+ raise ValueError(f"The entered model name - {model_name} was not recognized.")
525
+
526
+ # get all inception features for folder images
527
+ np_feats = get_folder_features(
528
+ fdir,
529
+ feat_model,
530
+ num_workers=num_workers,
531
+ num=num,
532
+ batch_size=batch_size,
533
+ device=device,
534
+ verbose=verbose,
535
+ mode=mode,
536
+ description=f"custom stats: {os.path.basename(fdir)} : ",
537
+ custom_image_tranform=custom_image_tranform,
538
+ custom_fn_resize=custom_fn_resize,
539
+ )
540
+
541
+ mu = np.mean(np_feats, axis=0)
542
+ sigma = np.cov(np_feats, rowvar=False)
543
+ # print(f"saving custom FID stats to {outf}")
544
+ np.savez_compressed(outf, mu=mu, sigma=sigma)
545
+
546
+ # KID stats
547
+ outf = os.path.join(
548
+ stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
549
+ )
550
+ # print(f"saving custom KID stats to {outf}")
551
+ np.savez_compressed(outf, feats=np_feats)
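
A minimal sketch of the custom-statistics workflow above, assuming a hypothetical folder of reference images and an arbitrary statistics name (a CUDA device is used by default):

# "my_refs" and the folder path below are placeholders.
if not test_stats_exists("my_refs", mode="clean"):
    make_custom_stats("my_refs", fdir="/path/to/reference_images", mode="clean")
# Cached statistics can later be deleted with:
# remove_custom_stats("my_refs", mode="clean")
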
552
+
553
+
554
+ def compute_kid(
555
+ fdir1=None,
556
+ fdir2=None,
557
+ gen=None,
558
+ mode="clean",
559
+ num_workers=12,
560
+ batch_size=32,
561
+ device=torch.device("cuda"),
562
+ dataset_name="FFHQ",
563
+ dataset_res=1024,
564
+ dataset_split="train",
565
+ num_gen=50_000,
566
+ z_dim=512,
567
+ verbose=True,
568
+ use_dataparallel=True,
569
+ ):
570
+ # build the feature extractor based on the mode
571
+ feat_model = build_feature_extractor(
572
+ mode, device, use_dataparallel=use_dataparallel
573
+ )
574
+
575
+ # if both dirs are specified, compute KID between folders
576
+ if fdir1 is not None and fdir2 is not None:
577
+ # get all inception features for the first folder
578
+ fbname1 = os.path.basename(fdir1)
579
+ np_feats1 = get_folder_features(
580
+ fdir1,
581
+ feat_model,
582
+ num_workers=num_workers,
583
+ batch_size=batch_size,
584
+ device=device,
585
+ mode=mode,
586
+ description=f"KID {fbname1} : ",
587
+ verbose=verbose,
588
+ )
589
+ # get all inception features for the second folder
590
+ fbname2 = os.path.basename(fdir2)
591
+ np_feats2 = get_folder_features(
592
+ fdir2,
593
+ feat_model,
594
+ num_workers=num_workers,
595
+ batch_size=batch_size,
596
+ device=device,
597
+ mode=mode,
598
+ description=f"KID {fbname2} : ",
599
+ verbose=verbose,
600
+ )
601
+ score = kernel_distance(np_feats1, np_feats2)
602
+ return score
603
+
604
+ # compute kid of a folder
605
+ elif fdir1 is not None and fdir2 is None:
606
+ if verbose:
607
+ print(f"compute KID of a folder with {dataset_name} statistics")
608
+ ref_feats = get_reference_statistics(
609
+ dataset_name,
610
+ dataset_res,
611
+ mode=mode,
612
+ seed=0,
613
+ split=dataset_split,
614
+ metric="KID",
615
+ )
616
+ fbname = os.path.basename(fdir1)
617
+ # get all inception features for folder images
618
+ np_feats = get_folder_features(
619
+ fdir1,
620
+ feat_model,
621
+ num_workers=num_workers,
622
+ batch_size=batch_size,
623
+ device=device,
624
+ mode=mode,
625
+ description=f"KID {fbname} : ",
626
+ verbose=verbose,
627
+ )
628
+ score = kernel_distance(ref_feats, np_feats)
629
+ return score
630
+
631
+ # compute kid for a generator, using images in fdir2
632
+ elif gen is not None and fdir2 is not None:
633
+ if verbose:
634
+ print(f"compute KID of a model, using references in fdir2")
635
+ # get all inception features for the second folder
636
+ fbname2 = os.path.basename(fdir2)
637
+ ref_feats = get_folder_features(
638
+ fdir2,
639
+ feat_model,
640
+ num_workers=num_workers,
641
+ batch_size=batch_size,
642
+ device=device,
643
+ mode=mode,
644
+ description=f"KID {fbname2} : ",
645
+ )
646
+ # Generate test features
647
+ np_feats = get_model_features(
648
+ gen,
649
+ feat_model,
650
+ mode=mode,
651
+ z_dim=z_dim,
652
+ num_gen=num_gen,
653
+ desc="KID model: ",
654
+ batch_size=batch_size,
655
+ device=device,
656
+ )
657
+ score = kernel_distance(ref_feats, np_feats)
658
+ return score
659
+
660
+ # compute fid for a generator, using reference statistics
661
+ elif gen is not None:
662
+ if verbose:
663
+ print(
664
+ f"compute KID of a model with {dataset_name}-{dataset_res} statistics"
665
+ )
666
+ ref_feats = get_reference_statistics(
667
+ dataset_name,
668
+ dataset_res,
669
+ mode=mode,
670
+ seed=0,
671
+ split=dataset_split,
672
+ metric="KID",
673
+ )
674
+ # Generate test features
675
+ np_feats = get_model_features(
676
+ gen,
677
+ feat_model,
678
+ mode=mode,
679
+ z_dim=z_dim,
680
+ num_gen=num_gen,
681
+ desc="KID model: ",
682
+ batch_size=batch_size,
683
+ device=device,
684
+ verbose=verbose,
685
+ )
686
+ score = kernel_distance(ref_feats, np_feats)
687
+ return score
688
+
689
+ else:
690
+ raise ValueError("invalid combination of directories and models entered")
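
A minimal usage sketch for compute_kid above, with placeholder folder paths (the default device is CUDA):

kid = compute_kid(
    fdir1="/path/to/real_images",       # placeholder
    fdir2="/path/to/generated_images",  # placeholder
    mode="clean",
)
print(f"KID: {kid:.5f}")
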
691
+
692
+
693
+ """
694
+ custom_image_tranform:
695
+ function that takes an np_array image as input [0,255] and
696
+ applies a custom transform such as cropping
697
+ """
698
+
699
+
700
+ def compute_fid(
701
+ fdir1=None,
702
+ fdir2=None,
703
+ gen=None,
704
+ mode="clean",
705
+ model_name="inception_v3",
706
+ num_workers=12,
707
+ batch_size=32,
708
+ device=torch.device("cuda"),
709
+ dataset_name="FFHQ",
710
+ dataset_res=1024,
711
+ dataset_split="train",
712
+ num_gen=50_000,
713
+ z_dim=512,
714
+ custom_feat_extractor=None,
715
+ verbose=True,
716
+ custom_image_tranform=None,
717
+ custom_fn_resize=None,
718
+ use_dataparallel=True,
719
+ ):
720
+ # build the feature extractor based on the mode and the model to be used
721
+ if custom_feat_extractor is None and model_name == "inception_v3":
722
+ feat_model = build_feature_extractor(
723
+ mode, device, use_dataparallel=use_dataparallel
724
+ )
725
+ elif custom_feat_extractor is None and model_name == "clip_vit_b_32":
726
+ from .clip_features import CLIP_fx, img_preprocess_clip
727
+
728
+ clip_fx = CLIP_fx("ViT-B/32", device=device)
729
+ feat_model = clip_fx
730
+ custom_fn_resize = img_preprocess_clip
731
+ else:
732
+ feat_model = custom_feat_extractor
733
+
734
+ # if both dirs are specified, compute FID between folders
735
+ if fdir1 is not None and fdir2 is not None:
736
+ score = compare_folders(
737
+ fdir1,
738
+ fdir2,
739
+ feat_model,
740
+ mode=mode,
741
+ batch_size=batch_size,
742
+ num_workers=num_workers,
743
+ device=device,
744
+ custom_image_tranform=custom_image_tranform,
745
+ custom_fn_resize=custom_fn_resize,
746
+ verbose=verbose,
747
+ )
748
+ return score
749
+
750
+ # compute fid of a folder
751
+ elif fdir1 is not None and fdir2 is None:
752
+ if verbose:
753
+ print(f"compute FID of a folder with {dataset_name} statistics")
754
+ score = fid_folder(
755
+ fdir1,
756
+ dataset_name,
757
+ dataset_res,
758
+ dataset_split,
759
+ model=feat_model,
760
+ mode=mode,
761
+ model_name=model_name,
762
+ custom_fn_resize=custom_fn_resize,
763
+ custom_image_tranform=custom_image_tranform,
764
+ num_workers=num_workers,
765
+ batch_size=batch_size,
766
+ device=device,
767
+ verbose=verbose,
768
+ )
769
+ return score
770
+
771
+ # compute fid for a generator, using images in fdir2
772
+ elif gen is not None and fdir2 is not None:
773
+ if verbose:
774
+ print(f"compute FID of a model, using references in fdir2")
775
+ # get all inception features for the second folder
776
+ fbname2 = os.path.basename(fdir2)
777
+ np_feats2 = get_folder_features(
778
+ fdir2,
779
+ feat_model,
780
+ num_workers=num_workers,
781
+ batch_size=batch_size,
782
+ device=device,
783
+ mode=mode,
784
+ description=f"FID {fbname2} : ",
785
+ verbose=verbose,
786
+ custom_fn_resize=custom_fn_resize,
787
+ custom_image_tranform=custom_image_tranform,
788
+ )
789
+ mu2 = np.mean(np_feats2, axis=0)
790
+ sigma2 = np.cov(np_feats2, rowvar=False)
791
+ # Generate test features
792
+ np_feats = get_model_features(
793
+ gen,
794
+ feat_model,
795
+ mode=mode,
796
+ z_dim=z_dim,
797
+ num_gen=num_gen,
798
+ custom_fn_resize=custom_fn_resize,
799
+ custom_image_tranform=custom_image_tranform,
800
+ batch_size=batch_size,
801
+ device=device,
802
+ verbose=verbose,
803
+ )
804
+
805
+ mu = np.mean(np_feats, axis=0)
806
+ sigma = np.cov(np_feats, rowvar=False)
807
+ fid = frechet_distance(mu, sigma, mu2, sigma2)
808
+ return fid
809
+
810
+ # compute fid for a generator, using reference statistics
811
+ elif gen is not None:
812
+ if verbose:
813
+ print(
814
+ f"compute FID of a model with {dataset_name}-{dataset_res} statistics"
815
+ )
816
+ score = fid_model(
817
+ gen,
818
+ dataset_name,
819
+ dataset_res,
820
+ dataset_split,
821
+ model=feat_model,
822
+ model_name=model_name,
823
+ z_dim=z_dim,
824
+ num_gen=num_gen,
825
+ mode=mode,
826
+ num_workers=num_workers,
827
+ batch_size=batch_size,
828
+ custom_image_tranform=custom_image_tranform,
829
+ custom_fn_resize=custom_fn_resize,
830
+ device=device,
831
+ verbose=verbose,
832
+ )
833
+ return score
834
+
835
+ else:
836
+ raise ValueError("invalid combination of directories and models entered")
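
A minimal usage sketch for compute_fid above, including a hypothetical center crop passed through custom_image_tranform (the parameter name is spelled this way throughout the library); the folder paths are placeholders and a CUDA device is assumed:

def center_crop(np_img):
    # np_img is an HxWx3 uint8 array with values in [0, 255]
    h, w = np_img.shape[:2]
    s = min(h, w)
    top, left = (h - s) // 2, (w - s) // 2
    return np_img[top : top + s, left : left + s]

score = compute_fid(
    fdir1="/path/to/real_images",
    fdir2="/path/to/generated_images",
    mode="clean",
    custom_image_tranform=center_crop,
)
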
kit/metrics/clean_fid/inception_pytorch.py ADDED
@@ -0,0 +1,329 @@
1
+ """
2
+ File from: https://github.com/mseitzer/pytorch-fid
3
+ """
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ import warnings
11
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
12
+
13
+ # Inception weights ported to Pytorch from
14
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
15
+ FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
16
+
17
+
18
+ class InceptionV3(nn.Module):
19
+ """Pretrained InceptionV3 network returning feature maps"""
20
+
21
+ # Index of default block of inception to return,
22
+ # corresponds to output of final average pooling
23
+ DEFAULT_BLOCK_INDEX = 3
24
+
25
+ # Maps feature dimensionality to their output blocks indices
26
+ BLOCK_INDEX_BY_DIM = {
27
+ 64: 0, # First max pooling features
28
+ 192: 1, # Second max pooling features
29
+ 768: 2, # Pre-aux classifier features
30
+ 2048: 3, # Final average pooling features
31
+ }
32
+
33
+ def __init__(
34
+ self,
35
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
36
+ resize_input=True,
37
+ normalize_input=True,
38
+ requires_grad=False,
39
+ use_fid_inception=True,
40
+ ):
41
+ """Build pretrained InceptionV3
42
+ Parameters
43
+ ----------
44
+ output_blocks : list of int
45
+ Indices of blocks to return features of. Possible values are:
46
+ - 0: corresponds to output of first max pooling
47
+ - 1: corresponds to output of second max pooling
48
+ - 2: corresponds to output which is fed to aux classifier
49
+ - 3: corresponds to output of final average pooling
50
+ resize_input : bool
51
+ If true, bilinearly resizes input to width and height 299 before
52
+ feeding input to model. As the network without fully connected
53
+ layers is fully convolutional, it should be able to handle inputs
54
+ of arbitrary size, so resizing might not be strictly needed
55
+ normalize_input : bool
56
+ If true, scales the input from range (0, 1) to the range the
57
+ pretrained Inception network expects, namely (-1, 1)
58
+ requires_grad : bool
59
+ If true, parameters of the model require gradients. Possibly useful
60
+ for finetuning the network
61
+ use_fid_inception : bool
62
+ If true, uses the pretrained Inception model used in Tensorflow's
63
+ FID implementation. If false, uses the pretrained Inception model
64
+ available in torchvision. The FID Inception model has different
65
+ weights and a slightly different structure from torchvision's
66
+ Inception model. If you want to compute FID scores, you are
67
+ strongly advised to set this parameter to true to get comparable
68
+ results.
69
+ """
70
+ super(InceptionV3, self).__init__()
71
+
72
+ self.resize_input = resize_input
73
+ self.normalize_input = normalize_input
74
+ self.output_blocks = sorted(output_blocks)
75
+ self.last_needed_block = max(output_blocks)
76
+
77
+ assert self.last_needed_block <= 3, "Last possible output block index is 3"
78
+
79
+ self.blocks = nn.ModuleList()
80
+
81
+ if use_fid_inception:
82
+ inception = fid_inception_v3()
83
+ else:
84
+ inception = _inception_v3(pretrained=True)
85
+
86
+ # Block 0: input to maxpool1
87
+ block0 = [
88
+ inception.Conv2d_1a_3x3,
89
+ inception.Conv2d_2a_3x3,
90
+ inception.Conv2d_2b_3x3,
91
+ nn.MaxPool2d(kernel_size=3, stride=2),
92
+ ]
93
+ self.blocks.append(nn.Sequential(*block0))
94
+
95
+ # Block 1: maxpool1 to maxpool2
96
+ if self.last_needed_block >= 1:
97
+ block1 = [
98
+ inception.Conv2d_3b_1x1,
99
+ inception.Conv2d_4a_3x3,
100
+ nn.MaxPool2d(kernel_size=3, stride=2),
101
+ ]
102
+ self.blocks.append(nn.Sequential(*block1))
103
+
104
+ # Block 2: maxpool2 to aux classifier
105
+ if self.last_needed_block >= 2:
106
+ block2 = [
107
+ inception.Mixed_5b,
108
+ inception.Mixed_5c,
109
+ inception.Mixed_5d,
110
+ inception.Mixed_6a,
111
+ inception.Mixed_6b,
112
+ inception.Mixed_6c,
113
+ inception.Mixed_6d,
114
+ inception.Mixed_6e,
115
+ ]
116
+ self.blocks.append(nn.Sequential(*block2))
117
+
118
+ # Block 3: aux classifier to final avgpool
119
+ if self.last_needed_block >= 3:
120
+ block3 = [
121
+ inception.Mixed_7a,
122
+ inception.Mixed_7b,
123
+ inception.Mixed_7c,
124
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
125
+ ]
126
+ self.blocks.append(nn.Sequential(*block3))
127
+
128
+ for param in self.parameters():
129
+ param.requires_grad = requires_grad
130
+
131
+ def forward(self, inp):
132
+ """Get Inception feature maps
133
+ Parameters
134
+ ----------
135
+ inp : torch.autograd.Variable
136
+ Input tensor of shape Bx3xHxW. Values are expected to be in
137
+ range (0, 1)
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ raise ValueError("should not resize here")
148
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
149
+
150
+ if self.normalize_input:
151
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
152
+
153
+ for idx, block in enumerate(self.blocks):
154
+ x = block(x)
155
+ if idx in self.output_blocks:
156
+ outp.append(x)
157
+
158
+ if idx == self.last_needed_block:
159
+ break
160
+
161
+ return outp
162
+
163
+
164
+ def _inception_v3(*args, **kwargs):
165
+ """Wraps `torchvision.models.inception_v3`
166
+ Skips default weight initialization if supported by torchvision version.
167
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
168
+ """
169
+ warnings.filterwarnings("ignore")
170
+ try:
171
+ version = tuple(map(int, torchvision.__version__.split(".")[:2]))
172
+ except ValueError:
173
+ # Just a caution against weird version strings
174
+ version = (0,)
175
+
176
+ if version >= (0, 6):
177
+ kwargs["init_weights"] = False
178
+
179
+ return torchvision.models.inception_v3(*args, **kwargs)
180
+
181
+
182
+ def fid_inception_v3():
183
+ """Build pretrained Inception model for FID computation
184
+ The Inception model for FID computation uses a different set of weights
185
+ and has a slightly different structure than torchvision's Inception.
186
+ This method first constructs torchvision's Inception and then patches the
187
+ necessary parts that are different in the FID Inception model.
188
+ """
189
+ inception = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
190
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
191
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
192
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
193
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
194
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
195
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
196
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
197
+ inception.Mixed_7b = FIDInceptionE_1(1280)
198
+ inception.Mixed_7c = FIDInceptionE_2(2048)
199
+
200
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
201
+ inception.load_state_dict(state_dict)
202
+ return inception
203
+
204
+
205
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
206
+ """InceptionA block patched for FID computation"""
207
+
208
+ def __init__(self, in_channels, pool_features):
209
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
210
+
211
+ def forward(self, x):
212
+ branch1x1 = self.branch1x1(x)
213
+
214
+ branch5x5 = self.branch5x5_1(x)
215
+ branch5x5 = self.branch5x5_2(branch5x5)
216
+
217
+ branch3x3dbl = self.branch3x3dbl_1(x)
218
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
219
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
220
+
221
+ # Patch: Tensorflow's average pool does not use the padded zeros in
222
+ # its average calculation
223
+ branch_pool = F.avg_pool2d(
224
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
225
+ )
226
+ branch_pool = self.branch_pool(branch_pool)
227
+
228
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
229
+ return torch.cat(outputs, 1)
230
+
231
+
232
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
233
+ """InceptionC block patched for FID computation"""
234
+
235
+ def __init__(self, in_channels, channels_7x7):
236
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
237
+
238
+ def forward(self, x):
239
+ branch1x1 = self.branch1x1(x)
240
+
241
+ branch7x7 = self.branch7x7_1(x)
242
+ branch7x7 = self.branch7x7_2(branch7x7)
243
+ branch7x7 = self.branch7x7_3(branch7x7)
244
+
245
+ branch7x7dbl = self.branch7x7dbl_1(x)
246
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
247
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
248
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
249
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
250
+
251
+ # Patch: Tensorflow's average pool does not use the padded zeros in
252
+ # its average calculation
253
+ branch_pool = F.avg_pool2d(
254
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
255
+ )
256
+ branch_pool = self.branch_pool(branch_pool)
257
+
258
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
259
+ return torch.cat(outputs, 1)
260
+
261
+
262
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
263
+ """First InceptionE block patched for FID computation"""
264
+
265
+ def __init__(self, in_channels):
266
+ super(FIDInceptionE_1, self).__init__(in_channels)
267
+
268
+ def forward(self, x):
269
+ branch1x1 = self.branch1x1(x)
270
+
271
+ branch3x3 = self.branch3x3_1(x)
272
+ branch3x3 = [
273
+ self.branch3x3_2a(branch3x3),
274
+ self.branch3x3_2b(branch3x3),
275
+ ]
276
+ branch3x3 = torch.cat(branch3x3, 1)
277
+
278
+ branch3x3dbl = self.branch3x3dbl_1(x)
279
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
280
+ branch3x3dbl = [
281
+ self.branch3x3dbl_3a(branch3x3dbl),
282
+ self.branch3x3dbl_3b(branch3x3dbl),
283
+ ]
284
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
285
+
286
+ # Patch: Tensorflow's average pool does not use the padded zeros in
287
+ # its average calculation
288
+ branch_pool = F.avg_pool2d(
289
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
290
+ )
291
+ branch_pool = self.branch_pool(branch_pool)
292
+
293
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294
+ return torch.cat(outputs, 1)
295
+
296
+
297
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298
+ """Second InceptionE block patched for FID computation"""
299
+
300
+ def __init__(self, in_channels):
301
+ super(FIDInceptionE_2, self).__init__(in_channels)
302
+
303
+ def forward(self, x):
304
+ branch1x1 = self.branch1x1(x)
305
+
306
+ branch3x3 = self.branch3x3_1(x)
307
+ branch3x3 = [
308
+ self.branch3x3_2a(branch3x3),
309
+ self.branch3x3_2b(branch3x3),
310
+ ]
311
+ branch3x3 = torch.cat(branch3x3, 1)
312
+
313
+ branch3x3dbl = self.branch3x3dbl_1(x)
314
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
315
+ branch3x3dbl = [
316
+ self.branch3x3dbl_3a(branch3x3dbl),
317
+ self.branch3x3dbl_3b(branch3x3dbl),
318
+ ]
319
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
320
+
321
+ # Patch: The FID Inception model uses max pooling instead of average
322
+ # pooling. This is likely an error in this specific Inception
323
+ # implementation, as other Inception models use average pooling here
324
+ # (which matches the description in the paper).
325
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
326
+ branch_pool = self.branch_pool(branch_pool)
327
+
328
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
329
+ return torch.cat(outputs, 1)
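
A minimal sketch of using the InceptionV3 wrapper above as a feature extractor; it assumes a batch already resized to 299x299 with values in [0, 1], since this patched forward() raises when resize_input is enabled (the FID Inception weights are downloaded on first use):

import torch

model = InceptionV3(output_blocks=[3], resize_input=False).eval()
x = torch.rand(4, 3, 299, 299)           # hypothetical batch in [0, 1]
with torch.no_grad():
    feats = model(x)[0]                  # shape (4, 2048, 1, 1), final average pooling
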
kit/metrics/clean_fid/inception_torchscript.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import contextlib
5
+ from .downloads_helper import *
6
+
7
+
8
+ @contextlib.contextmanager
9
+ def disable_gpu_fuser_on_pt19():
10
+ # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See
11
+ # https://github.com/GaParmar/clean-fid/issues/5
12
+ # https://github.com/pytorch/pytorch/issues/64062
13
+ if torch.__version__.startswith("1.9."):
14
+ old_val = torch._C._jit_can_fuse_on_gpu()
15
+ torch._C._jit_override_can_fuse_on_gpu(False)
16
+ yield
17
+ if torch.__version__.startswith("1.9."):
18
+ torch._C._jit_override_can_fuse_on_gpu(old_val)
19
+
20
+
21
+ class InceptionV3W(nn.Module):
22
+ """
23
+ Wrapper around Inception V3 torchscript model provided here
24
+ https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt
25
+
26
+ path: locally saved inception weights
27
+ """
28
+
29
+ def __init__(self, path, download=True, resize_inside=False):
30
+ super(InceptionV3W, self).__init__()
31
+ # download the network if it is not present at the given directory
32
+ # use the current directory by default
33
+ if download:
34
+ check_download_inception(fpath=path)
35
+ path = os.path.join(path, "inception-2015-12-05.pt")
36
+ self.base = torch.jit.load(path).eval()
37
+ self.layers = self.base.layers
38
+ self.resize_inside = resize_inside
39
+
40
+ """
41
+ Get the inception features without resizing
42
+ x: Image with values in range [0,255]
43
+ """
44
+
45
+ def forward(self, x):
46
+ with disable_gpu_fuser_on_pt19():
47
+ bs = x.shape[0]
48
+ if self.resize_inside:
49
+ features = self.base(x, return_features=True).view((bs, 2048))
50
+ else:
51
+ # make sure it is resized already
52
+ assert (x.shape[2] == 299) and (x.shape[3] == 299)
53
+ # apply normalization
54
+ x1 = x - 128
55
+ x2 = x1 / 128
56
+ features = self.layers.forward(
57
+ x2,
58
+ ).view((bs, 2048))
59
+ return features
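
A minimal sketch of the torchscript wrapper above; the weights directory is a placeholder and a CUDA device is assumed. With resize_inside=False, inputs must already be 299x299 with values in [0, 255]:

import torch

model = InceptionV3W(path="/tmp/inception_weights", download=True, resize_inside=False).to("cuda").eval()
x = torch.rand(2, 3, 299, 299, device="cuda") * 255
with torch.no_grad():
    feats = model(x)                     # shape (2, 2048)
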
kit/metrics/clean_fid/leaderboard.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import csv
3
+ import shutil
4
+ import urllib.request
5
+
6
+
7
+ def get_score(
8
+ model_name=None,
9
+ dataset_name=None,
10
+ dataset_res=None,
11
+ dataset_split=None,
12
+ task_name=None,
13
+ ):
14
+ # download the csv file from server
15
+ url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv"
16
+ local_path = "/tmp/leaderboard.csv"
17
+ with urllib.request.urlopen(url) as response, open(local_path, "wb") as f:
18
+ shutil.copyfileobj(response, f)
19
+
20
+ d_field2idx = {}
21
+ l_matches = []
22
+ with open(local_path, "r") as f:
23
+ csvreader = csv.reader(f)
24
+ l_fields = next(csvreader)
25
+ for idx, val in enumerate(l_fields):
26
+ d_field2idx[val.strip()] = idx
27
+ # iterate through all rows
28
+ for row in csvreader:
29
+ # skip empty rows
30
+ if len(row) == 0:
31
+ continue
32
+ # skip if the filter doesn't match
33
+ if model_name is not None and (
34
+ row[d_field2idx["model_name"]].strip() != model_name
35
+ ):
36
+ continue
37
+ if dataset_name is not None and (
38
+ row[d_field2idx["dataset_name"]].strip() != dataset_name
39
+ ):
40
+ continue
41
+ if dataset_res is not None and (
42
+ row[d_field2idx["dataset_res"]].strip() != dataset_res
43
+ ):
44
+ continue
45
+ if dataset_split is not None and (
46
+ row[d_field2idx["dataset_split"]].strip() != dataset_split
47
+ ):
48
+ continue
49
+ if task_name is not None and (
50
+ row[d_field2idx["task_name"]].strip() != task_name
51
+ ):
52
+ continue
53
+ curr = {}
54
+ for f in l_fields:
55
+ curr[f.strip()] = row[d_field2idx[f.strip()]].strip()
56
+ l_matches.append(curr)
57
+ os.remove(local_path)
58
+ return l_matches
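
A minimal sketch of querying the hosted leaderboard with get_score above; the filter values are examples, and note that dataset_res is compared as a string:

rows = get_score(dataset_name="FFHQ", dataset_res="1024")
for row in rows:
    print(row)                           # one dict per matching leaderboard entry
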
kit/metrics/clean_fid/resize.py ADDED
@@ -0,0 +1,108 @@
1
+ """
2
+ Helpers for resizing with multiple CPU cores
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def build_resizer(mode):
12
+ if mode == "clean":
13
+ return make_resizer("PIL", False, "bicubic", (299, 299))
14
+ # if using legacy tensorflow, do not manually resize outside the network
15
+ elif mode == "legacy_tensorflow":
16
+ return lambda x: x
17
+ elif mode == "legacy_pytorch":
18
+ return make_resizer("PyTorch", False, "bilinear", (299, 299))
19
+ else:
20
+ raise ValueError(f"Invalid mode {mode} specified")
21
+
22
+
23
+ """
24
+ Construct a function that resizes a numpy image based on the
25
+ flags passed in.
26
+ """
27
+
28
+
29
+ def make_resizer(library, quantize_after, filter, output_size):
30
+ if library == "PIL" and quantize_after:
31
+ name_to_filter = {
32
+ "bicubic": Image.BICUBIC,
33
+ "bilinear": Image.BILINEAR,
34
+ "nearest": Image.NEAREST,
35
+ "lanczos": Image.LANCZOS,
36
+ "box": Image.BOX,
37
+ }
38
+
39
+ def func(x):
40
+ x = Image.fromarray(x)
41
+ x = x.resize(output_size, resample=name_to_filter[filter])
42
+ x = np.asarray(x).clip(0, 255).astype(np.uint8)
43
+ return x
44
+
45
+ elif library == "PIL" and not quantize_after:
46
+ name_to_filter = {
47
+ "bicubic": Image.BICUBIC,
48
+ "bilinear": Image.BILINEAR,
49
+ "nearest": Image.NEAREST,
50
+ "lanczos": Image.LANCZOS,
51
+ "box": Image.BOX,
52
+ }
53
+ s1, s2 = output_size
54
+
55
+ def resize_single_channel(x_np):
56
+ img = Image.fromarray(x_np.astype(np.float32), mode="F")
57
+ img = img.resize(output_size, resample=name_to_filter[filter])
58
+ return np.asarray(img).clip(0, 255).reshape(s2, s1, 1)
59
+
60
+ def func(x):
61
+ x = [resize_single_channel(x[:, :, idx]) for idx in range(3)]
62
+ x = np.concatenate(x, axis=2).astype(np.float32)
63
+ return x
64
+
65
+ elif library == "PyTorch":
66
+ import warnings
67
+
68
+ # ignore the numpy warnings
69
+ warnings.filterwarnings("ignore")
70
+
71
+ def func(x):
72
+ x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...]
73
+ x = F.interpolate(x, size=output_size, mode=filter, align_corners=False)
74
+ x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255)
75
+ if quantize_after:
76
+ x = x.astype(np.uint8)
77
+ return x
78
+
79
+ else:
80
+ raise NotImplementedError("library [%s] is not included" % library)
81
+ return func
82
+
83
+
84
+ class FolderResizer(torch.utils.data.Dataset):
85
+ def __init__(self, files, outpath, fn_resize, output_ext=".png"):
86
+ self.files = files
87
+ self.outpath = outpath
88
+ self.output_ext = output_ext
89
+ self.fn_resize = fn_resize
90
+
91
+ def __len__(self):
92
+ return len(self.files)
93
+
94
+ def __getitem__(self, i):
95
+ path = str(self.files[i])
96
+ img_np = np.asarray(Image.open(path))
97
+ img_resize_np = self.fn_resize(img_np)
98
+ # swap the output extension
99
+ basename = os.path.basename(path).split(".")[0] + self.output_ext
100
+ outname = os.path.join(self.outpath, basename)
101
+ if self.output_ext == ".npy":
102
+ np.save(outname, img_resize_np)
103
+ elif self.output_ext == ".png":
104
+ img_resized_pil = Image.fromarray(img_resize_np)
105
+ img_resized_pil.save(outname)
106
+ else:
107
+ raise ValueError("invalid output extension")
108
+ return 0
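
A minimal sketch of the "clean" resizer above, which uses PIL bicubic resampling to 299x299 and returns float32 without quantizing; the image path is a placeholder:

import numpy as np
from PIL import Image

fn_resize = build_resizer("clean")
img_np = np.asarray(Image.open("/path/to/example.png").convert("RGB"))
img_resized = fn_resize(img_np)          # float32 array of shape (299, 299, 3)
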
kit/metrics/clean_fid/utils.py ADDED
@@ -0,0 +1,75 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchvision
4
+ from PIL import Image
5
+ import zipfile
6
+ from .resize import build_resizer
7
+
8
+
9
+ class ResizeDataset(torch.utils.data.Dataset):
10
+ """
11
+ A placeholder Dataset that enables parallelizing the resize operation
12
+ using multiple CPU cores
13
+
14
+ files: list of all files in the folder
15
+ fn_resize: function that takes an np_array as input [0,255]
16
+ """
17
+
18
+ def __init__(self, files, mode, size=(299, 299), fdir=None):
19
+ self.files = files
20
+ self.fdir = fdir
21
+ self.transforms = torchvision.transforms.ToTensor()
22
+ self.size = size
23
+ self.fn_resize = build_resizer(mode)
24
+ self.custom_image_tranform = lambda x: x
25
+ self._zipfile = None
26
+
27
+ def _get_zipfile(self):
28
+ assert self.fdir is not None and ".zip" in self.fdir
29
+ if self._zipfile is None:
30
+ self._zipfile = zipfile.ZipFile(self.fdir)
31
+ return self._zipfile
32
+
33
+ def __len__(self):
34
+ return len(self.files)
35
+
36
+ def __getitem__(self, i):
37
+ path = str(self.files[i])
38
+ if self.fdir is not None and ".zip" in self.fdir:
39
+ with self._get_zipfile().open(path, "r") as f:
40
+ img_np = np.array(Image.open(f).convert("RGB"))
41
+ elif ".npy" in path:
42
+ img_np = np.load(path)
43
+ else:
44
+ img_pil = Image.open(path).convert("RGB")
45
+ img_np = np.array(img_pil)
46
+
47
+ # apply a custom image transform before resizing the image to 299x299
48
+ img_np = self.custom_image_tranform(img_np)
49
+ # fn_resize expects a np array and returns a np array
50
+ img_resized = self.fn_resize(img_np)
51
+
52
+ # ToTensor() converts to [0,1] only if input in uint8
53
+ if img_resized.dtype == "uint8":
54
+ img_t = self.transforms(np.array(img_resized)) * 255
55
+ elif img_resized.dtype == "float32":
56
+ img_t = self.transforms(img_resized)
57
+
58
+ return img_t
59
+
60
+
61
+ EXTENSIONS = {
62
+ "bmp",
63
+ "jpg",
64
+ "jpeg",
65
+ "pgm",
66
+ "png",
67
+ "ppm",
68
+ "tif",
69
+ "tiff",
70
+ "webp",
71
+ "npy",
72
+ "JPEG",
73
+ "JPG",
74
+ "PNG",
75
+ }
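
A minimal sketch of parallel resizing with ResizeDataset above; the file paths are placeholders, and each returned tensor keeps the [0, 255] value range expected by the feature extractors:

import torch

files = ["/data/imgs/0001.png", "/data/imgs/0002.png"]   # placeholder paths
dataset = ResizeDataset(files, mode="clean")
loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2)
batch = next(iter(loader))               # float tensor of shape (2, 3, 299, 299)
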
kit/metrics/clean_fid/wrappers.py ADDED
@@ -0,0 +1,111 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ from .features import build_feature_extractor, get_reference_statistics
5
+ from .fid import get_batch_features, fid_from_feats
6
+ from .resize import build_resizer
7
+
8
+
9
+ """
10
+ A helper class that allows adding images one batch at a time.
11
+ """
12
+
13
+
14
+ class CleanFID:
15
+ def __init__(self, mode="clean", model_name="inception_v3", device="cuda"):
16
+ self.real_features = []
17
+ self.gen_features = []
18
+ self.mode = mode
19
+ self.device = device
20
+ if model_name == "inception_v3":
21
+ self.feat_model = build_feature_extractor(mode, device)
22
+ self.fn_resize = build_resizer(mode)
23
+ elif model_name == "clip_vit_b_32":
24
+ from .clip_features import CLIP_fx, img_preprocess_clip
25
+
26
+ clip_fx = CLIP_fx("ViT-B/32")
27
+ self.feat_model = clip_fx
28
+ self.fn_resize = img_preprocess_clip
29
+
30
+ """
31
+ Function that takes an image (PIL.Image or np.array or torch.tensor)
32
+ and returns the corresponding feature embedding vector.
33
+ The image x is expected to be in range [0, 255]
34
+ """
35
+
36
+ def compute_features(self, x):
37
+ # if x is a PIL Image
38
+ if isinstance(x, Image.Image):
39
+ x_np = np.array(x)
40
+ x_np_resized = self.fn_resize(x_np)
41
+ x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0)
42
+ x_feat = get_batch_features(x_t, self.feat_model, self.device)
43
+ elif isinstance(x, np.ndarray):
44
+ x_np_resized = self.fn_resize(x)
45
+ x_t = (
46
+ torch.tensor(x_np_resized.transpose((2, 0, 1)))
47
+ .unsqueeze(0)
48
+ .to(self.device)
49
+ )
50
+ # normalization happens inside the self.feat_model, expected image range here is [0,255]
51
+ x_feat = get_batch_features(x_t, self.feat_model, self.device)
52
+ elif isinstance(x, torch.Tensor):
53
+ # pdb.set_trace()
54
+ # add the batch dimension if x is passed in as C,H,W
55
+ if len(x.shape) == 3:
56
+ x = x.unsqueeze(0)
57
+ b, c, h, w = x.shape
58
+ # convert back to np array and resize
59
+ l_x_np_resized = []
60
+ for _ in range(b):
61
+ x_np = x[_].cpu().numpy().transpose((1, 2, 0))
62
+ l_x_np_resized.append(self.fn_resize(x_np)[None,])
63
+ x_np_resized = np.concatenate(l_x_np_resized)
64
+ x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device)
65
+ # normalization happens inside the self.feat_model, expected image range here is [0,255]
66
+ x_feat = get_batch_features(x_t, self.feat_model, self.device)
67
+ else:
68
+ raise ValueError("image type could not be inferred")
69
+ return x_feat
70
+
71
+ """
72
+ Extract the features from x and add to the list of reference real images
73
+ """
74
+
75
+ def add_real_images(self, x):
76
+ x_feat = self.compute_features(x)
77
+ self.real_features.append(x_feat)
78
+
79
+ """
80
+ Extract the features from x and add to the list of generated images
81
+ """
82
+
83
+ def add_gen_images(self, x):
84
+ x_feat = self.compute_features(x)
85
+ self.gen_features.append(x_feat)
86
+
87
+ """
88
+ Compute FID between the real and generated images added so far
89
+ """
90
+
91
+ def calculate_fid(self, verbose=True):
92
+ feats1 = np.concatenate(self.real_features)
93
+ feats2 = np.concatenate(self.gen_features)
94
+ if verbose:
95
+ print(f"# real images = {feats1.shape[0]}")
96
+ print(f"# generated images = {feats2.shape[0]}")
97
+ return fid_from_feats(feats1, feats2)
98
+
99
+ """
100
+ Remove the real image features added so far
101
+ """
102
+
103
+ def reset_real_features(self):
104
+ self.real_features = []
105
+
106
+ """
107
+ Remove the generated image features added so far
108
+ """
109
+
110
+ def reset_gen_features(self):
111
+ self.gen_features = []
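
A minimal sketch of the incremental wrapper above, assuming a hypothetical loader that yields (B, 3, H, W) tensors with values in [0, 255] and a CUDA device:

metric = CleanFID(mode="clean", device="cuda")
for real_batch, fake_batch in loader:    # hypothetical data loader
    metric.add_real_images(real_batch)
    metric.add_gen_images(fake_batch)
fid_value = metric.calculate_fid()
metric.reset_real_features()
metric.reset_gen_features()
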
kit/metrics/clip.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from PIL import Image
3
+ import open_clip
4
+
5
+
6
+ def load_open_clip_model_preprocess_and_tokenizer(device=torch.device("cuda")):
7
+ clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
8
+ "ViT-g-14", pretrained="laion2b_s12b_b42k", device=device
9
+ )
10
+ clip_tokenizer = open_clip.get_tokenizer("ViT-g-14")
11
+ return clip_model, clip_preprocess, clip_tokenizer
12
+
13
+
14
+ def compute_clip_score(
15
+ images,
16
+ prompts,
17
+ models,
18
+ device=torch.device("cuda"),
19
+ ):
20
+ clip_model, clip_preprocess, clip_tokenizer = models
21
+ with torch.no_grad():
22
+ tensors = [clip_preprocess(image) for image in images]
23
+ image_processed_tensor = torch.stack(tensors, 0).to(device)
24
+ image_features = clip_model.encode_image(image_processed_tensor)
25
+
26
+ encoding = clip_tokenizer(prompts).to(device)
27
+ text_features = clip_model.encode_text(encoding)
28
+
29
+ image_features /= image_features.norm(dim=-1, keepdim=True)
30
+ text_features /= text_features.norm(dim=-1, keepdim=True)
31
+
32
+ return (image_features @ text_features.T).mean(-1).cpu().numpy().tolist()
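
A minimal usage sketch for the helpers above, assuming paired lists of PIL images and prompt strings (the ViT-g-14 weights are downloaded on first use and a CUDA device is assumed):

models = load_open_clip_model_preprocess_and_tokenizer()
scores = compute_clip_score(images, prompts, models)     # images/prompts are hypothetical paired lists
print(sum(scores) / len(scores))                         # mean CLIP similarity
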
kit/metrics/distributional.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import tempfile
3
+ import torch
4
+ from PIL import Image
5
+ from tqdm.auto import tqdm
6
+ from concurrent.futures import ProcessPoolExecutor
7
+ from functools import partial
8
+ from PIL import Image
9
+ from .clean_fid import fid
10
+
11
+
12
+ def save_single_image_to_temp(i, image, temp_dir):
13
+ save_path = os.path.join(temp_dir, f"{i}.png")
14
+ image.save(save_path, "PNG")
15
+
16
+
17
+ def save_images_to_temp(images, num_workers, verbose=False):
18
+ assert isinstance(images, list) and isinstance(images[0], Image.Image)
19
+ temp_dir = tempfile.mkdtemp()
20
+
21
+ # Using ProcessPoolExecutor to save images in parallel
22
+ func = partial(save_single_image_to_temp, temp_dir=temp_dir)
23
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
24
+ tasks = executor.map(func, range(len(images)), images)
25
+ list(tasks) if not verbose else list(
26
+ tqdm(
27
+ tasks,
28
+ total=len(images),
29
+ desc="Saving images ",
30
+ )
31
+ )
32
+ return temp_dir
33
+
34
+
35
+ # Compute FID between two sets of images
36
+ def compute_fid(
37
+ images1,
38
+ images2,
39
+ mode="legacy",
40
+ device=None,
41
+ batch_size=64,
42
+ num_workers=None,
43
+ verbose=False,
44
+ ):
45
+ # Support four types of FID scores
46
+ assert mode in ["legacy", "clean", "clip"]
47
+ if mode == "legacy":
48
+ mode = "legacy_pytorch"
49
+ model_name = "inception_v3"
50
+ elif mode == "clean":
51
+ mode = "clean"
52
+ model_name = "inception_v3"
53
+ elif mode == "clip":
54
+ mode = "clean"
55
+ model_name = "clip_vit_b_32"
56
+ else:
57
+ assert False
58
+
59
+ # Set up device and num_workers
60
+ if device is None:
61
+ device = (
62
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
63
+ )
64
+ if num_workers is not None:
65
+ assert 1 <= num_workers <= os.cpu_count()
66
+ else:
67
+ num_workers = max(torch.cuda.device_count() * 4, 8)
68
+
69
+ # Check images, can be paths or lists of PIL images
70
+ if not isinstance(images1, list):
71
+ assert isinstance(images1, str) and os.path.exists(images1)
72
+ assert isinstance(images2, str) and os.path.exists(images2)
73
+ path1 = images1
74
+ path2 = images2
75
+ else:
76
+ assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
77
+ assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
78
+ # Save images to temp dir if needed
79
+ path1 = save_images_to_temp(images1, num_workers=num_workers, verbose=verbose)
80
+ path2 = save_images_to_temp(images2, num_workers=num_workers, verbose=verbose)
81
+
82
+ # Attempt to cache statistics for path1
83
+ if not fid.test_stats_exists(name=str(os.path.abspath(path1)).replace("/", "_"), mode=mode, model_name=model_name):
84
+ fid.make_custom_stats(
85
+ name=str(os.path.abspath(path1)).replace("/", "_"),
86
+ fdir=path1,
87
+ mode=mode,
88
+ model_name=model_name,
89
+ device=device,
90
+ num_workers=num_workers,
91
+ verbose=verbose,
92
+ )
93
+ fid_score = fid.compute_fid(
94
+ path2,
95
+ dataset_name=str(os.path.abspath(path1)).replace("/", "_"),
96
+ dataset_split="custom",
97
+ mode=mode,
98
+ model_name=model_name,
99
+ device=device,
100
+ batch_size=batch_size,
101
+ num_workers=num_workers,
102
+ verbose=verbose,
103
+ )
104
+ return fid_score
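
A minimal usage sketch for this compute_fid wrapper, assuming two hypothetical lists of PIL images; they are written to temporary folders and the reference statistics are cached internally:

fid_value = compute_fid(images_real, images_gen, mode="clean", verbose=True)
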
kit/metrics/image.py ADDED
@@ -0,0 +1,112 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from skimage.metrics import (
6
+ mean_squared_error,
7
+ peak_signal_noise_ratio,
8
+ structural_similarity as structural_similarity_index_measure,
9
+ normalized_mutual_information,
10
+ )
11
+ from tqdm.auto import tqdm
12
+ from concurrent.futures import ThreadPoolExecutor
13
+
14
+
15
+ # Process images to numpy arrays
16
+ def convert_image_pair_to_numpy(image1, image2):
17
+ assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
18
+
19
+ image1_np = np.array(image1)
20
+ image2_np = np.array(image2)
21
+ assert image1_np.shape == image2_np.shape
22
+
23
+ return image1_np, image2_np
24
+
25
+
26
+ # Compute MSE between two images
27
+ def compute_mse(image1, image2):
28
+ image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
29
+ return float(mean_squared_error(image1_np, image2_np))
30
+
31
+
32
+ # Compute PSNR between two images
33
+ def compute_psnr(image1, image2):
34
+ image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
35
+ return float(peak_signal_noise_ratio(image1_np, image2_np))
36
+
37
+
38
+ # Compute SSIM between two images
39
+ def compute_ssim(image1, image2):
40
+ image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
41
+ return float(
42
+ structural_similarity_index_measure(image1_np, image2_np, channel_axis=2)
43
+ )
44
+
45
+
46
+ # Compute NMI between two images
47
+ def compute_nmi(image1, image2):
48
+ image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
49
+ return float(normalized_mutual_information(image1_np, image2_np))
50
+
51
+
52
+ # Compute metrics
53
+ def compute_metric_repeated(
54
+ images1, images2, metric_func, num_workers=None, verbose=False
55
+ ):
56
+ # Accept list of PIL images
57
+ assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
58
+ assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
59
+ assert len(images1) == len(images2)
60
+
61
+ if num_workers is not None:
62
+ assert 1 <= num_workers <= os.cpu_count()
63
+ else:
64
+ num_workers = max(torch.cuda.device_count() * 4, 8)
65
+
66
+ metric_name = metric_func.__name__.split("_")[1].upper()
67
+
68
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
69
+ tasks = executor.map(metric_func, images1, images2)
70
+ values = (
71
+ list(tasks)
72
+ if not verbose
73
+ else list(
74
+ tqdm(
75
+ tasks,
76
+ total=len(images1),
77
+ desc=f"{metric_name} ",
78
+ )
79
+ )
80
+ )
81
+ return values
82
+
83
+
84
+ # Compute MSE between pairs of images
85
+ def compute_mse_repeated(images1, images2, num_workers=None, verbose=False):
86
+ return compute_metric_repeated(images1, images2, compute_mse, num_workers, verbose)
87
+
88
+
89
+ # Compute PSNR between pairs of images
90
+ def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False):
91
+ return compute_metric_repeated(images1, images2, compute_psnr, num_workers, verbose)
92
+
93
+
94
+ # Compute SSIM between pairs of images
95
+ def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False):
96
+ return compute_metric_repeated(images1, images2, compute_ssim, num_workers, verbose)
97
+
98
+
99
+ # Compute NMI between pairs of images
100
+ def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False):
101
+ return compute_metric_repeated(images1, images2, compute_nmi, num_workers, verbose)
102
+
103
+
104
+ def compute_image_distance_repeated(
105
+ images1, images2, metric_name, num_workers=None, verbose=False
106
+ ):
107
+ metric_func = {
108
+ "psnr": compute_psnr,
109
+ "ssim": compute_ssim,
110
+ "nmi": compute_nmi,
111
+ }[metric_name]
112
+ return compute_metric_repeated(images1, images2, metric_func, num_workers, verbose)
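
A minimal usage sketch for the pairwise metrics above, assuming two equal-length lists of PIL images named images1 and images2:

psnr_values = compute_psnr_repeated(images1, images2)
ssim_values = compute_image_distance_repeated(images1, images2, metric_name="ssim")
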
kit/metrics/lpips/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ From https://github.com/richzhang/PerceptualSimilarity
3
+ """
4
+ from .lpips import LPIPS
kit/metrics/lpips/lpips.py ADDED
@@ -0,0 +1,338 @@
1
+ from __future__ import absolute_import
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.autograd import Variable
5
+ import warnings
6
+ from . import pretrained_networks as pn
7
+ from .utils import normalize_tensor, l2, dssim, tensor2np, tensor2tensorlab, tensor2im
8
+
9
+
10
+ def spatial_average(in_tens, keepdim=True):
11
+ return in_tens.mean([2, 3], keepdim=keepdim)
12
+
13
+
14
+ def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
15
+ in_H, in_W = in_tens.shape[2], in_tens.shape[3]
16
+ return nn.Upsample(size=out_HW, mode="bilinear", align_corners=False)(in_tens)
17
+
18
+
19
+ # Learned perceptual metric
20
+ class LPIPS(nn.Module):
21
+ def __init__(
22
+ self,
23
+ pretrained=True,
24
+ net="alex",
25
+ version="0.1",
26
+ lpips=True,
27
+ spatial=False,
28
+ pnet_rand=False,
29
+ pnet_tune=False,
30
+ use_dropout=True,
31
+ model_path=None,
32
+ eval_mode=True,
33
+ verbose=True,
34
+ ):
35
+ """Initializes a perceptual loss torch.nn.Module
36
+
37
+ Parameters (default listed first)
38
+ ---------------------------------
39
+ lpips : bool
40
+ [True] use linear layers on top of base/trunk network
41
+ [False] means no linear layers; each layer is averaged together
42
+ pretrained : bool
43
+ This flag controls the linear layers, which are only in effect when lpips=True above
44
+ [True] means linear layers are calibrated with human perceptual judgments
45
+ [False] means linear layers are randomly initialized
46
+ pnet_rand : bool
47
+ [False] means trunk loaded with ImageNet classification weights
48
+ [True] means randomly initialized trunk
49
+ net : str
50
+ ['alex','vgg','squeeze'] are the base/trunk networks available
51
+ version : str
52
+ ['v0.1'] is the default and latest
53
+ ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
54
+ model_path : 'str'
55
+ [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
56
+
57
+ The following parameters should only be changed if training the network
58
+
59
+ eval_mode : bool
60
+ [True] is for test mode (default)
61
+ [False] is for training mode
62
+ pnet_tune
63
+ [False] keep base/trunk frozen
64
+ [True] tune the base/trunk network
65
+ use_dropout : bool
66
+ [True] to use dropout when training linear layers
67
+ [False] for no dropout when training linear layers
68
+ """
69
+
70
+ super(LPIPS, self).__init__()
71
+ warnings.filterwarnings("ignore")
72
+ if verbose:
73
+ pass
74
+ # print(
75
+ # "Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]"
76
+ # % (
77
+ # "LPIPS" if lpips else "baseline",
78
+ # net,
79
+ # version,
80
+ # "on" if spatial else "off",
81
+ # )
82
+ # )
83
+
84
+ self.pnet_type = net
85
+ self.pnet_tune = pnet_tune
86
+ self.pnet_rand = pnet_rand
87
+ self.spatial = spatial
88
+ self.lpips = lpips # false means baseline of just averaging all layers
89
+ self.version = version
90
+ self.scaling_layer = ScalingLayer()
91
+
92
+ if self.pnet_type in ["vgg", "vgg16"]:
93
+ net_type = pn.vgg16
94
+ self.chns = [64, 128, 256, 512, 512]
95
+ elif self.pnet_type == "alex":
96
+ net_type = pn.alexnet
97
+ self.chns = [64, 192, 384, 256, 256]
98
+ elif self.pnet_type == "squeeze":
99
+ net_type = pn.squeezenet
100
+ self.chns = [64, 128, 256, 384, 384, 512, 512]
101
+ self.L = len(self.chns)
102
+
103
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
104
+
105
+ if lpips:
106
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
107
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
108
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
109
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
110
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
111
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
112
+ if self.pnet_type == "squeeze": # 7 layers for squeezenet
113
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
114
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
115
+ self.lins += [self.lin5, self.lin6]
116
+ self.lins = nn.ModuleList(self.lins)
117
+
118
+ if pretrained:
119
+ if model_path is None:
120
+ import inspect
121
+ import os
122
+
123
+ model_path = os.path.abspath(
124
+ os.path.join(
125
+ inspect.getfile(self.__init__),
126
+ "..",
127
+ "weights/v%s/%s.pth" % (version, net),
128
+ )
129
+ )
130
+
131
+ if verbose:
132
+ pass
133
+ # print("Loading model from: %s" % model_path)
134
+ self.load_state_dict(
135
+ torch.load(model_path, map_location="cpu"), strict=False
136
+ )
137
+
138
+ if eval_mode:
139
+ self.eval()
140
+
141
+ def forward(self, in0, in1, retPerLayer=False, normalize=False):
142
+ if (
143
+ normalize
144
+ ): # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
145
+ in0 = 2 * in0 - 1
146
+ in1 = 2 * in1 - 1
147
+
148
+ # v0.0 - original release had a bug, where input was not scaled
149
+ in0_input, in1_input = (
150
+ (self.scaling_layer(in0), self.scaling_layer(in1))
151
+ if self.version == "0.1"
152
+ else (in0, in1)
153
+ )
154
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
155
+ feats0, feats1, diffs = {}, {}, {}
156
+
157
+ for kk in range(self.L):
158
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
159
+ outs1[kk]
160
+ )
161
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
162
+
163
+ if self.lpips:
164
+ if self.spatial:
165
+ res = [
166
+ upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])
167
+ for kk in range(self.L)
168
+ ]
169
+ else:
170
+ res = [
171
+ spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
172
+ for kk in range(self.L)
173
+ ]
174
+ else:
175
+ if self.spatial:
176
+ res = [
177
+ upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
178
+ for kk in range(self.L)
179
+ ]
180
+ else:
181
+ res = [
182
+ spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
183
+ for kk in range(self.L)
184
+ ]
185
+
186
+ val = 0
187
+ for l in range(self.L):
188
+ val += res[l]
189
+
190
+ if retPerLayer:
191
+ return (val, res)
192
+ else:
193
+ return val
194
+
195
+
196
+ class ScalingLayer(nn.Module):
197
+ def __init__(self):
198
+ super(ScalingLayer, self).__init__()
199
+ self.register_buffer(
200
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
201
+ )
202
+ self.register_buffer(
203
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
204
+ )
205
+
206
+ def forward(self, inp):
207
+ return (inp - self.shift) / self.scale
208
+
209
+
210
+ class NetLinLayer(nn.Module):
211
+ """A single linear layer which does a 1x1 conv"""
212
+
213
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
214
+ super(NetLinLayer, self).__init__()
215
+
216
+ layers = (
217
+ [
218
+ nn.Dropout(),
219
+ ]
220
+ if (use_dropout)
221
+ else []
222
+ )
223
+ layers += [
224
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
225
+ ]
226
+ self.model = nn.Sequential(*layers)
227
+
228
+ def forward(self, x):
229
+ return self.model(x)
230
+
231
+
232
+ class Dist2LogitLayer(nn.Module):
233
+ """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)"""
234
+
235
+ def __init__(self, chn_mid=32, use_sigmoid=True):
236
+ super(Dist2LogitLayer, self).__init__()
237
+
238
+ layers = [
239
+ nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
240
+ ]
241
+ layers += [
242
+ nn.LeakyReLU(0.2, True),
243
+ ]
244
+ layers += [
245
+ nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
246
+ ]
247
+ layers += [
248
+ nn.LeakyReLU(0.2, True),
249
+ ]
250
+ layers += [
251
+ nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
252
+ ]
253
+ if use_sigmoid:
254
+ layers += [
255
+ nn.Sigmoid(),
256
+ ]
257
+ self.model = nn.Sequential(*layers)
258
+
259
+ def forward(self, d0, d1, eps=0.1):
260
+ return self.model.forward(
261
+ torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
262
+ )
263
+
264
+
265
+ class BCERankingLoss(nn.Module):
266
+ def __init__(self, chn_mid=32):
267
+ super(BCERankingLoss, self).__init__()
268
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
269
+ # self.parameters = list(self.net.parameters())
270
+ self.loss = torch.nn.BCELoss()
271
+
272
+ def forward(self, d0, d1, judge):
273
+ per = (judge + 1.0) / 2.0
274
+ self.logit = self.net.forward(d0, d1)
275
+ return self.loss(self.logit, per)
276
+
277
+
278
+ # L2, DSSIM metrics
279
+ class FakeNet(nn.Module):
280
+ def __init__(self, use_gpu=True, colorspace="Lab"):
281
+ super(FakeNet, self).__init__()
282
+ self.use_gpu = use_gpu
283
+ self.colorspace = colorspace
284
+
285
+
286
+ class L2(FakeNet):
287
+ def forward(self, in0, in1, retPerLayer=None):
288
+ assert in0.size()[0] == 1 # currently only supports batchSize 1
289
+
290
+ if self.colorspace == "RGB":
291
+ (N, C, X, Y) = in0.size()
292
+ value = torch.mean(
293
+ torch.mean(
294
+ torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2
295
+ ).view(N, 1, 1, Y),
296
+ dim=3,
297
+ ).view(N)
298
+ return value
299
+ elif self.colorspace == "Lab":
300
+ value = l2(
301
+ tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
302
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
303
+ range=100.0,
304
+ ).astype("float")
305
+ ret_var = Variable(torch.Tensor((value,)))
306
+ if self.use_gpu:
307
+ ret_var = ret_var.cuda()
308
+ return ret_var
309
+
310
+
311
+ class DSSIM(FakeNet):
312
+ def forward(self, in0, in1, retPerLayer=None):
313
+ assert in0.size()[0] == 1 # currently only supports batchSize 1
314
+
315
+ if self.colorspace == "RGB":
316
+ value = dssim(
317
+ 1.0 * tensor2im(in0.data),
318
+ 1.0 * tensor2im(in1.data),
319
+ range=255.0,
320
+ ).astype("float")
321
+ elif self.colorspace == "Lab":
322
+ value = dssim(
323
+ tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
324
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
325
+ range=100.0,
326
+ ).astype("float")
327
+ ret_var = Variable(torch.Tensor((value,)))
328
+ if self.use_gpu:
329
+ ret_var = ret_var.cuda()
330
+ return ret_var
331
+
332
+
333
+ def print_network(net):
334
+ num_params = 0
335
+ for param in net.parameters():
336
+ num_params += param.numel()
337
+ print("Network", net)
338
+ print("Total number of parameters: %d" % num_params)
kit/metrics/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,188 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2, 5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple(
52
+ "SqueezeOutputs",
53
+ ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
54
+ )
55
+ out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
56
+
57
+ return out
58
+
59
+
60
+ class alexnet(torch.nn.Module):
61
+ def __init__(self, requires_grad=False, pretrained=True):
62
+ super(alexnet, self).__init__()
63
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
64
+ self.slice1 = torch.nn.Sequential()
65
+ self.slice2 = torch.nn.Sequential()
66
+ self.slice3 = torch.nn.Sequential()
67
+ self.slice4 = torch.nn.Sequential()
68
+ self.slice5 = torch.nn.Sequential()
69
+ self.N_slices = 5
70
+ for x in range(2):
71
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
72
+ for x in range(2, 5):
73
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
74
+ for x in range(5, 8):
75
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
76
+ for x in range(8, 10):
77
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
78
+ for x in range(10, 12):
79
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
80
+ if not requires_grad:
81
+ for param in self.parameters():
82
+ param.requires_grad = False
83
+
84
+ def forward(self, X):
85
+ h = self.slice1(X)
86
+ h_relu1 = h
87
+ h = self.slice2(h)
88
+ h_relu2 = h
89
+ h = self.slice3(h)
90
+ h_relu3 = h
91
+ h = self.slice4(h)
92
+ h_relu4 = h
93
+ h = self.slice5(h)
94
+ h_relu5 = h
95
+ alexnet_outputs = namedtuple(
96
+ "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
97
+ )
98
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
99
+
100
+ return out
101
+
102
+
103
+ class vgg16(torch.nn.Module):
104
+ def __init__(self, requires_grad=False, pretrained=True):
105
+ super(vgg16, self).__init__()
106
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
107
+ self.slice1 = torch.nn.Sequential()
108
+ self.slice2 = torch.nn.Sequential()
109
+ self.slice3 = torch.nn.Sequential()
110
+ self.slice4 = torch.nn.Sequential()
111
+ self.slice5 = torch.nn.Sequential()
112
+ self.N_slices = 5
113
+ for x in range(4):
114
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(4, 9):
116
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(9, 16):
118
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
119
+ for x in range(16, 23):
120
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
121
+ for x in range(23, 30):
122
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
123
+ if not requires_grad:
124
+ for param in self.parameters():
125
+ param.requires_grad = False
126
+
127
+ def forward(self, X):
128
+ h = self.slice1(X)
129
+ h_relu1_2 = h
130
+ h = self.slice2(h)
131
+ h_relu2_2 = h
132
+ h = self.slice3(h)
133
+ h_relu3_3 = h
134
+ h = self.slice4(h)
135
+ h_relu4_3 = h
136
+ h = self.slice5(h)
137
+ h_relu5_3 = h
138
+ vgg_outputs = namedtuple(
139
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
140
+ )
141
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
142
+
143
+ return out
144
+
145
+
146
+ class resnet(torch.nn.Module):
147
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
148
+ super(resnet, self).__init__()
149
+ if num == 18:
150
+ self.net = tv.resnet18(pretrained=pretrained)
151
+ elif num == 34:
152
+ self.net = tv.resnet34(pretrained=pretrained)
153
+ elif num == 50:
154
+ self.net = tv.resnet50(pretrained=pretrained)
155
+ elif num == 101:
156
+ self.net = tv.resnet101(pretrained=pretrained)
157
+ elif num == 152:
158
+ self.net = tv.resnet152(pretrained=pretrained)
159
+ self.N_slices = 5
160
+
161
+ self.conv1 = self.net.conv1
162
+ self.bn1 = self.net.bn1
163
+ self.relu = self.net.relu
164
+ self.maxpool = self.net.maxpool
165
+ self.layer1 = self.net.layer1
166
+ self.layer2 = self.net.layer2
167
+ self.layer3 = self.net.layer3
168
+ self.layer4 = self.net.layer4
169
+
170
+ def forward(self, X):
171
+ h = self.conv1(X)
172
+ h = self.bn1(h)
173
+ h = self.relu(h)
174
+ h_relu1 = h
175
+ h = self.maxpool(h)
176
+ h = self.layer1(h)
177
+ h_conv2 = h
178
+ h = self.layer2(h)
179
+ h_conv3 = h
180
+ h = self.layer3(h)
181
+ h_conv4 = h
182
+ h = self.layer4(h)
183
+ h_conv5 = h
184
+
185
+ outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
186
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
187
+
188
+ return out
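
Illustrative sketch (not in the commit): the wrappers above return named tuples of intermediate activations. The snippet below lists the five VGG16 feature maps that LPIPS compares, using random weights so the shape check needs no download; the module path is an assumption.

import torch
from kit.metrics.lpips.pretrained_networks import vgg16  # assumed module path

net = vgg16(requires_grad=False, pretrained=False).eval()  # random weights: shape check only
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    feats = net(x)
for name, f in zip(feats._fields, feats):
    print(name, tuple(f.shape))  # relu1_2 ... relu5_3, progressively smaller feature maps
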
kit/metrics/lpips/trainer.py ADDED
@@ -0,0 +1,314 @@
1
+ from __future__ import absolute_import
2
+ import numpy as np
3
+ import torch
4
+ from collections import OrderedDict
5
+ from torch.autograd import Variable
6
+ from scipy.ndimage import zoom
7
+ from tqdm import tqdm
8
+ import os
9
+ from .lpips import LPIPS, L2, DSSIM, BCERankingLoss
10
+ from .utils import tensor2im, voc_ap
11
+
12
+
13
+ class Trainer:
14
+ def name(self):
15
+ return self.model_name
16
+
17
+ def initialize(
18
+ self,
19
+ model="lpips",
20
+ net="alex",
21
+ colorspace="Lab",
22
+ pnet_rand=False,
23
+ pnet_tune=False,
24
+ model_path=None,
25
+ use_gpu=True,
26
+ printNet=False,
27
+ spatial=False,
28
+ is_train=False,
29
+ lr=0.0001,
30
+ beta1=0.5,
31
+ version="0.1",
32
+ gpu_ids=[0],
33
+ ):
34
+ """
35
+ INPUTS
36
+ model - ['lpips'] for linearly calibrated network
37
+ ['baseline'] for off-the-shelf network
38
+ ['L2'] for L2 distance in Lab colorspace
39
+ ['SSIM'] for ssim in RGB colorspace
40
+ net - ['squeeze','alex','vgg']
41
+ model_path - if None, will look in weights/[NET_NAME].pth
42
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
43
+ use_gpu - bool - whether or not to use a GPU
44
+ printNet - bool - whether or not to print network architecture out
45
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
46
+ is_train - bool - [True] for training mode
47
+ lr - float - initial learning rate
48
+ beta1 - float - initial momentum term for adam
49
+ version - 0.1 for latest, 0.0 was original (with a bug)
50
+ gpu_ids - int array - [0] by default, gpus to use
51
+ """
52
+ self.use_gpu = use_gpu
53
+ self.gpu_ids = gpu_ids
54
+ self.model = model
55
+ self.net = net
56
+ self.is_train = is_train
57
+ self.spatial = spatial
58
+ self.model_name = "%s [%s]" % (model, net)
59
+
60
+ if self.model == "lpips": # pretrained net + linear layer
61
+ self.net = LPIPS(
62
+ pretrained=not is_train,
63
+ net=net,
64
+ version=version,
65
+ lpips=True,
66
+ spatial=spatial,
67
+ pnet_rand=pnet_rand,
68
+ pnet_tune=pnet_tune,
69
+ use_dropout=True,
70
+ model_path=model_path,
71
+ eval_mode=False,
72
+ )
73
+ elif self.model == "baseline": # pretrained network
74
+ self.net = LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
75
+ elif self.model in ["L2", "l2"]:
76
+ self.net = L2(
77
+ use_gpu=use_gpu, colorspace=colorspace
78
+ ) # not really a network, only for testing
79
+ self.model_name = "L2"
80
+ elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]:
81
+ self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
82
+ self.model_name = "SSIM"
83
+ else:
84
+ raise ValueError("Model [%s] not recognized." % self.model)
85
+
86
+ self.parameters = list(self.net.parameters())
87
+
88
+ if self.is_train: # training mode
89
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
+ self.rankLoss = BCERankingLoss()
91
+ self.parameters += list(self.rankLoss.net.parameters())
92
+ self.lr = lr
93
+ self.old_lr = lr
94
+ self.optimizer_net = torch.optim.Adam(
95
+ self.parameters, lr=lr, betas=(beta1, 0.999)
96
+ )
97
+ else: # test mode
98
+ self.net.eval()
99
+
100
+ if use_gpu:
101
+ self.net.to(gpu_ids[0])
102
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
103
+ if self.is_train:
104
+ self.rankLoss = self.rankLoss.to(
105
+ device=gpu_ids[0]
106
+ ) # just put this on GPU0
107
+
108
+ if printNet:
109
+ pass
110
+
111
+ def forward(self, in0, in1, retPerLayer=False):
112
+ """Function computes the distance between image patches in0 and in1
113
+ INPUTS
114
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
115
+ OUTPUT
116
+ computed distances between in0 and in1
117
+ """
118
+
119
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
120
+
121
+ # ***** TRAINING FUNCTIONS *****
122
+ def optimize_parameters(self):
123
+ self.forward_train()
124
+ self.optimizer_net.zero_grad()
125
+ self.backward_train()
126
+ self.optimizer_net.step()
127
+ self.clamp_weights()
128
+
129
+ def clamp_weights(self):
130
+ for module in self.net.modules():
131
+ if hasattr(module, "weight") and module.kernel_size == (1, 1):
132
+ module.weight.data = torch.clamp(module.weight.data, min=0)
133
+
134
+ def set_input(self, data):
135
+ self.input_ref = data["ref"]
136
+ self.input_p0 = data["p0"]
137
+ self.input_p1 = data["p1"]
138
+ self.input_judge = data["judge"]
139
+
140
+ if self.use_gpu:
141
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
142
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
143
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
144
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
145
+
146
+ self.var_ref = Variable(self.input_ref, requires_grad=True)
147
+ self.var_p0 = Variable(self.input_p0, requires_grad=True)
148
+ self.var_p1 = Variable(self.input_p1, requires_grad=True)
149
+
150
+ def forward_train(self): # run forward pass
151
+ self.d0 = self.forward(self.var_ref, self.var_p0)
152
+ self.d1 = self.forward(self.var_ref, self.var_p1)
153
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
154
+
155
+ self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size())
156
+
157
+ self.loss_total = self.rankLoss.forward(
158
+ self.d0, self.d1, self.var_judge * 2.0 - 1.0
159
+ )
160
+
161
+ return self.loss_total
162
+
163
+ def backward_train(self):
164
+ torch.mean(self.loss_total).backward()
165
+
166
+ def compute_accuracy(self, d0, d1, judge):
167
+ """d0, d1 are Variables, judge is a Tensor"""
168
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
169
+ judge_per = judge.cpu().numpy().flatten()
170
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
171
+
172
+ def get_current_errors(self):
173
+ retDict = OrderedDict(
174
+ [("loss_total", self.loss_total.data.cpu().numpy()), ("acc_r", self.acc_r)]
175
+ )
176
+
177
+ for key in retDict.keys():
178
+ retDict[key] = np.mean(retDict[key])
179
+
180
+ return retDict
181
+
182
+ def get_current_visuals(self):
183
+ zoom_factor = 256 / self.var_ref.data.size()[2]
184
+
185
+ ref_img = tensor2im(self.var_ref.data)
186
+ p0_img = tensor2im(self.var_p0.data)
187
+ p1_img = tensor2im(self.var_p1.data)
188
+
189
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
190
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
191
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
192
+
193
+ return OrderedDict(
194
+ [("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)]
195
+ )
196
+
197
+ def save(self, path, label):
198
+ if self.use_gpu:
199
+ self.save_network(self.net.module, path, "", label)
200
+ else:
201
+ self.save_network(self.net, path, "", label)
202
+ self.save_network(self.rankLoss.net, path, "rank", label)
203
+
204
+ # helper saving function that can be used by subclasses
205
+ def save_network(self, network, path, network_label, epoch_label):
206
+ save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
207
+ save_path = os.path.join(path, save_filename)
208
+ torch.save(network.state_dict(), save_path)
209
+
210
+ # helper loading function that can be used by subclasses
211
+ def load_network(self, network, network_label, epoch_label):
212
+ save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
213
+ save_path = os.path.join(self.save_dir, save_filename)
214
+ print("Loading network from %s" % save_path)
215
+ network.load_state_dict(torch.load(save_path))
216
+
217
+ def update_learning_rate(self, nepoch_decay):
218
+ lrd = self.lr / nepoch_decay
219
+ lr = self.old_lr - lrd
220
+
221
+ for param_group in self.optimizer_net.param_groups:
222
+ param_group["lr"] = lr
223
+
224
+ print("update lr [%s] decay: %f -> %f" % (type, self.old_lr, lr))
225
+ self.old_lr = lr
226
+
227
+ def get_image_paths(self):
228
+ return self.image_paths
229
+
230
+ def save_done(self, flag=False):
231
+ np.save(os.path.join(self.save_dir, "done_flag"), flag)
232
+ np.savetxt(
233
+ os.path.join(self.save_dir, "done_flag"),
234
+ [
235
+ flag,
236
+ ],
237
+ fmt="%i",
238
+ )
239
+
240
+
241
+ def score_2afc_dataset(data_loader, func, name=""):
242
+ """Function computes Two Alternative Forced Choice (2AFC) score using
243
+ distance function 'func' in dataset 'data_loader'
244
+ INPUTS
245
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
246
+ func - callable distance function - calling d=func(in0,in1) should take 2
247
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
248
+ OUTPUTS
249
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
250
+ [1] - dictionary with following elements
251
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
252
+ gts - N array in [0,1], preferred patch selected by human evaluators
253
+ (closer to "0" for left patch p0, "1" for right patch p1,
254
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
255
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
256
+ CONSTS
257
+ N - number of test triplets in data_loader
258
+ """
259
+
260
+ d0s = []
261
+ d1s = []
262
+ gts = []
263
+
264
+ for data in tqdm(data_loader.load_data(), desc=name):
265
+ d0s += func(data["ref"], data["p0"]).data.cpu().numpy().flatten().tolist()
266
+ d1s += func(data["ref"], data["p1"]).data.cpu().numpy().flatten().tolist()
267
+ gts += data["judge"].cpu().numpy().flatten().tolist()
268
+
269
+ d0s = np.array(d0s)
270
+ d1s = np.array(d1s)
271
+ gts = np.array(gts)
272
+ scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5
273
+
274
+ return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
275
+
276
+
277
+ def score_jnd_dataset(data_loader, func, name=""):
278
+ """Function computes JND score using distance function 'func' in dataset 'data_loader'
279
+ INPUTS
280
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
281
+ func - callable distance function - calling d=func(in0,in1) should take 2
282
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
283
+ OUTPUTS
284
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
285
+ [1] - dictionary with following elements
286
+ ds - N array containing distances between two patches shown to human evaluator
287
+ sames - N array containing fraction of people who thought the two patches were identical
288
+ CONSTS
289
+ N - number of test triplets in data_loader
290
+ """
291
+
292
+ ds = []
293
+ gts = []
294
+
295
+ for data in tqdm(data_loader.load_data(), desc=name):
296
+ ds += func(data["p0"], data["p1"]).data.cpu().numpy().tolist()
297
+ gts += data["same"].cpu().numpy().flatten().tolist()
298
+
299
+ sames = np.array(gts)
300
+ ds = np.array(ds)
301
+
302
+ sorted_inds = np.argsort(ds)
303
+ ds_sorted = ds[sorted_inds]
304
+ sames_sorted = sames[sorted_inds]
305
+
306
+ TPs = np.cumsum(sames_sorted)
307
+ FPs = np.cumsum(1 - sames_sorted)
308
+ FNs = np.sum(sames_sorted) - TPs
309
+
310
+ precs = TPs / (TPs + FPs)
311
+ recs = TPs / (TPs + FNs)
312
+ score = voc_ap(recs, precs)
313
+
314
+ return (score, dict(ds=ds, sames=sames))
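
Usage sketch (not in the commit): score_2afc_dataset only needs an object with a load_data() iterable and a distance callable, so a toy in-memory loader with a plain squared-error distance is enough to exercise it. The stub class and random data below are made up for illustration; the module path is assumed.

import torch
from kit.metrics.lpips.trainer import score_2afc_dataset  # assumed module path

class TinyLoader:                    # hypothetical stand-in for CustomDatasetDataLoader
    def __init__(self, batches):
        self.batches = batches
    def load_data(self):
        return self.batches

def l2_dist(x, y):                   # per-sample mean squared difference, shape (N,)
    return ((x - y) ** 2).flatten(1).mean(dim=1)

batch = {
    "ref": torch.rand(2, 3, 8, 8),
    "p0": torch.rand(2, 3, 8, 8),
    "p1": torch.rand(2, 3, 8, 8),
    "judge": torch.rand(2, 1),
}
score, details = score_2afc_dataset(TinyLoader([batch]), l2_dist, name="toy")
print(score, details["scores"].shape)
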
kit/metrics/lpips/utils.py ADDED
@@ -0,0 +1,137 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def normalize_tensor(in_feat, eps=1e-10):
9
+ norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
10
+ return in_feat / (norm_factor + eps)
11
+
12
+
13
+ def l2(p0, p1, range=255.0):
14
+ return 0.5 * np.mean((p0 / range - p1 / range) ** 2)
15
+
16
+
17
+ def psnr(p0, p1, peak=255.0):
18
+ return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2))
19
+
20
+
21
+ def dssim(p0, p1, range=255.0):
22
+ from skimage.metrics import structural_similarity as compare_ssim  # moved from skimage.measure in scikit-image >= 0.18
23
+
24
+ return (1 - compare_ssim(p0, p1, data_range=range, channel_axis=2)) / 2.0  # channel_axis replaces the removed multichannel flag
25
+
26
+
27
+ def tensor2np(tensor_obj):
28
+ # change dimension of a tensor object into a numpy array
29
+ return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
30
+
31
+
32
+ def np2tensor(np_obj):
33
+ # change dimension of np array into tensor array
34
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
35
+
36
+
37
+ def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
38
+ # image tensor to lab tensor
39
+ from skimage import color
40
+
41
+ img = tensor2im(image_tensor)
42
+ img_lab = color.rgb2lab(img)
43
+ if mc_only:
44
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
45
+ if to_norm and not mc_only:
46
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
47
+ img_lab = img_lab / 100.0
48
+
49
+ return np2tensor(img_lab)
50
+
51
+
52
+ def tensorlab2tensor(lab_tensor, return_inbnd=False):
53
+ from skimage import color
54
+ import warnings
55
+
56
+ warnings.filterwarnings("ignore")
57
+
58
+ lab = tensor2np(lab_tensor) * 100.0
59
+ lab[:, :, 0] = lab[:, :, 0] + 50
60
+
61
+ rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1)
62
+ if return_inbnd:
63
+ # convert back to lab, see if we match
64
+ lab_back = color.rgb2lab(rgb_back.astype("uint8"))
65
+ mask = 1.0 * np.isclose(lab_back, lab, atol=2.0)
66
+ mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
67
+ return (im2tensor(rgb_back), mask)
68
+ else:
69
+ return im2tensor(rgb_back)
70
+
71
+
72
+ def load_image(path):
73
+ if (
74
+ path[-3:] == "bmp"
75
+ or path[-3:] == "jpg"
76
+ or path[-3:] == "png"
77
+ or path[-4:] == "jpeg"
78
+ ):
79
+ import cv2
80
+
81
+ return cv2.imread(path)[:, :, ::-1]
82
+ else:
83
+ import matplotlib.pyplot as plt
84
+
85
+ img = (255 * plt.imread(path)[:, :, :3]).astype("uint8")
86
+
87
+ return img
88
+
89
+
90
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
91
+ image_numpy = image_tensor[0].cpu().float().numpy()
92
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
93
+ return image_numpy.astype(imtype)
94
+
95
+
96
+ def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
97
+ return torch.Tensor(
98
+ (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
99
+ )
100
+
101
+
102
+ def tensor2vec(vector_tensor):
103
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
104
+
105
+
106
+ def voc_ap(rec, prec, use_07_metric=False):
107
+ """ap = voc_ap(rec, prec, [use_07_metric])
108
+ Compute VOC AP given precision and recall.
109
+ If use_07_metric is true, uses the
110
+ VOC 07 11 point method (default:False).
111
+ """
112
+ if use_07_metric:
113
+ # 11 point metric
114
+ ap = 0.0
115
+ for t in np.arange(0.0, 1.1, 0.1):
116
+ if np.sum(rec >= t) == 0:
117
+ p = 0
118
+ else:
119
+ p = np.max(prec[rec >= t])
120
+ ap = ap + p / 11.0
121
+ else:
122
+ # correct AP calculation
123
+ # first append sentinel values at the end
124
+ mrec = np.concatenate(([0.0], rec, [1.0]))
125
+ mpre = np.concatenate(([0.0], prec, [0.0]))
126
+
127
+ # compute the precision envelope
128
+ for i in range(mpre.size - 1, 0, -1):
129
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
130
+
131
+ # to calculate area under PR curve, look for points
132
+ # where X axis (recall) changes value
133
+ i = np.where(mrec[1:] != mrec[:-1])[0]
134
+
135
+ # and sum (\Delta recall) * prec
136
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
137
+ return ap
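
Illustrative sketch (not in the commit): voc_ap applied to a tiny hand-written precision/recall curve, comparing the exact area computation with the 11-point VOC 2007 approximation; the module path is assumed.

import numpy as np
from kit.metrics.lpips.utils import voc_ap  # assumed module path

rec = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
prec = np.array([1.0, 1.0, 0.8, 0.6, 0.5])
print(voc_ap(rec, prec))                      # exact area under the interpolated PR curve
print(voc_ap(rec, prec, use_07_metric=True))  # coarser 11-point approximation
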
kit/metrics/lpips/weights/v0.0/alex.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
3
+ size 5455
kit/metrics/lpips/weights/v0.0/squeeze.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
3
+ size 10057
kit/metrics/lpips/weights/v0.0/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
3
+ size 6735
kit/metrics/lpips/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
+ size 6009
kit/metrics/lpips/weights/v0.1/squeeze.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
3
+ size 10811
kit/metrics/lpips/weights/v0.1/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
kit/metrics/perceptual.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from .lpips import LPIPS
5
+
6
+
7
+ # Normalize image tensors
8
+ def normalize_tensor(images, norm_type):
9
+ assert norm_type in ["imagenet", "naive"]
10
+ # Two possible normalization conventions
11
+ if norm_type == "imagenet":
12
+ mean = [0.485, 0.456, 0.406]
13
+ std = [0.229, 0.224, 0.225]
14
+ normalize = transforms.Normalize(mean, std)
15
+ elif norm_type == "naive":
16
+ mean = [0.5, 0.5, 0.5]
17
+ std = [0.5, 0.5, 0.5]
18
+ normalize = transforms.Normalize(mean, std)
19
+ else:
20
+ assert False
21
+ return torch.stack([normalize(image) for image in images])
22
+
23
+
24
+ def to_tensor(images, norm_type="naive"):
25
+ assert isinstance(images, list) and all(
26
+ [isinstance(image, Image.Image) for image in images]
27
+ )
28
+ images = torch.stack([transforms.ToTensor()(image) for image in images])
29
+ if norm_type is not None:
30
+ images = normalize_tensor(images, norm_type)
31
+ return images
32
+
33
+
34
+ def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
35
+ assert metric_name in ["lpips"]
36
+ if metric_name == "lpips":
37
+ assert mode in ["vgg", "alex"]
38
+ perceptual_model = LPIPS(net=mode).to(device)
39
+ else:
40
+ assert False
41
+ return perceptual_model
42
+
43
+
44
+ # Compute metric between two images
45
+ def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
46
+ assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
47
+ image1_tensor = to_tensor([image1]).to(device)
48
+ image2_tensor = to_tensor([image2]).to(device)
49
+ return perceptual_model(image1_tensor, image2_tensor).cpu().item()
50
+
51
+
52
+ # Compute LPIPS distance between two images
53
+ def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
54
+ perceptual_model = load_perceptual_models("lpips", mode, device)
55
+ return compute_metric(image1, image2, perceptual_model, device)
56
+
57
+
58
+ # Compute metrics between pairs of images
59
+ def compute_perceptual_metric_repeated(
60
+ images1,
61
+ images2,
62
+ metric_name,
63
+ mode,
64
+ model,
65
+ device,
66
+ ):
67
+ # Accept list of PIL images
68
+ assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
69
+ assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
70
+ assert len(images1) == len(images2)
71
+ if model is None:
72
+ model = load_perceptual_models(metric_name, mode).to(device)
73
+ return (
74
+ model(to_tensor(images1).to(device), to_tensor(images2).to(device))
75
+ .detach()
76
+ .cpu()
77
+ .numpy()
78
+ .flatten()
79
+ .tolist()
80
+ )
81
+
82
+
83
+ # Compute LPIPS distance between pairs of images
84
+ def compute_lpips_repeated(
85
+ images1,
86
+ images2,
87
+ mode="alex",
88
+ model=None,
89
+ device=torch.device("cuda"),
90
+ ):
91
+ return compute_perceptual_metric_repeated(
92
+ images1, images2, "lpips", mode, model, device
93
+ )
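
Usage sketch (not in the commit): computing an LPIPS distance between two PIL images with the helpers above. The solid-colour test images are placeholders, and the explicit device fallback is added only because the module defaults assume CUDA.

import torch
from PIL import Image
from kit.metrics.perceptual import compute_lpips  # assumed module path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_a = Image.new("RGB", (256, 256), color=(200, 30, 30))
img_b = Image.new("RGB", (256, 256), color=(195, 35, 35))
print(compute_lpips(img_a, img_b, mode="alex", device=device))  # near 0 for near-identical images
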
kit/metrics/prompt.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
3
+
4
+
5
+ # Load GPT-2 large model and tokenizer
6
+ def load_perplexity_model_and_tokenizer():
7
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
+ ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device)
9
+ ppl_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
10
+ return ppl_model, ppl_tokenizer
11
+
12
+
13
+ # Compute perplexity for a single prompt
14
+ def compute_prompt_perplexity(prompt, models, stride=512):
15
+ assert isinstance(prompt, str)
16
+ assert isinstance(models, tuple) and len(models) == 2
17
+ ppl_model, ppl_tokenizer = models
18
+ encodings = ppl_tokenizer(prompt, return_tensors="pt")
19
+ max_length = ppl_model.config.n_positions
20
+ seq_len = encodings.input_ids.size(1)
21
+ nlls = []
22
+ prev_end_loc = 0
23
+ for begin_loc in range(0, seq_len, stride):
24
+ end_loc = min(begin_loc + max_length, seq_len)
25
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
26
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(
27
+ next(ppl_model.parameters()).device
28
+ )
29
+ target_ids = input_ids.clone()
30
+ target_ids[:, :-trg_len] = -100
31
+ with torch.no_grad():
32
+ outputs = ppl_model(input_ids, labels=target_ids)
33
+ neg_log_likelihood = outputs.loss
34
+ nlls.append(neg_log_likelihood)
35
+ prev_end_loc = end_loc
36
+ if end_loc == seq_len:
37
+ break
38
+ ppl = torch.exp(torch.stack(nlls).mean()).item()
39
+ return ppl
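
Usage sketch (not in the commit): scoring a single text-to-image prompt with the sliding-window perplexity above. gpt2-large is downloaded on first use; the example prompt is arbitrary and the module path is assumed.

from kit.metrics.prompt import (  # assumed module path
    load_perplexity_model_and_tokenizer,
    compute_prompt_perplexity,
)

models = load_perplexity_model_and_tokenizer()
ppl = compute_prompt_perplexity("a watercolor painting of a fox in the snow", models)
print(ppl)  # lower perplexity indicates a more natural-sounding prompt
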