Spaces:
Runtime error
Runtime error
mcding
commited on
Commit
·
ad552d8
0
Parent(s):
published version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +164 -0
- README.md +10 -0
- app.py +426 -0
- icon.jpg +0 -0
- image.png +0 -0
- kit/__init__.py +121 -0
- kit/metrics/__init__.py +27 -0
- kit/metrics/aesthetics.py +38 -0
- kit/metrics/aesthetics_scorer/__init__.py +4 -0
- kit/metrics/aesthetics_scorer/model.py +104 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.config +8 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth +3 -0
- kit/metrics/clean_fid/__init__.py +3 -0
- kit/metrics/clean_fid/clip_features.py +40 -0
- kit/metrics/clean_fid/downloads_helper.py +75 -0
- kit/metrics/clean_fid/features.py +117 -0
- kit/metrics/clean_fid/fid.py +836 -0
- kit/metrics/clean_fid/inception_pytorch.py +329 -0
- kit/metrics/clean_fid/inception_torchscript.py +59 -0
- kit/metrics/clean_fid/leaderboard.py +58 -0
- kit/metrics/clean_fid/resize.py +108 -0
- kit/metrics/clean_fid/utils.py +75 -0
- kit/metrics/clean_fid/wrappers.py +111 -0
- kit/metrics/clip.py +32 -0
- kit/metrics/distributional.py +104 -0
- kit/metrics/image.py +112 -0
- kit/metrics/lpips/__init__.py +4 -0
- kit/metrics/lpips/lpips.py +338 -0
- kit/metrics/lpips/pretrained_networks.py +188 -0
- kit/metrics/lpips/trainer.py +314 -0
- kit/metrics/lpips/utils.py +137 -0
- kit/metrics/lpips/weights/v0.0/alex.pth +3 -0
- kit/metrics/lpips/weights/v0.0/squeeze.pth +3 -0
- kit/metrics/lpips/weights/v0.0/vgg.pth +3 -0
- kit/metrics/lpips/weights/v0.1/alex.pth +3 -0
- kit/metrics/lpips/weights/v0.1/squeeze.pth +3 -0
- kit/metrics/lpips/weights/v0.1/vgg.pth +3 -0
- kit/metrics/perceptual.py +93 -0
- kit/metrics/prompt.py +39 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
163 |
+
|
164 |
+
.env
|
README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Erasing the Invisible
|
3 |
+
emoji: 🌊
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.44.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
---
|
app.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
import redis
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
from datetime import datetime
|
8 |
+
from PIL import Image
|
9 |
+
from kit import compute_performance, compute_quality
|
10 |
+
import dotenv
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
dotenv.load_dotenv()
|
14 |
+
|
15 |
+
CSS = """
|
16 |
+
.tabs button{
|
17 |
+
font-size: 20px;
|
18 |
+
}
|
19 |
+
#download_btn {
|
20 |
+
height: 91.6px;
|
21 |
+
}
|
22 |
+
#submit_btn {
|
23 |
+
height: 91.6px;
|
24 |
+
}
|
25 |
+
#original_image {
|
26 |
+
display: block;
|
27 |
+
margin-left: auto;
|
28 |
+
margin-right: auto;
|
29 |
+
}
|
30 |
+
#uploaded_image {
|
31 |
+
display: block;
|
32 |
+
margin-left: auto;
|
33 |
+
margin-right: auto;
|
34 |
+
}
|
35 |
+
#leaderboard_plot {
|
36 |
+
display: block;
|
37 |
+
margin-left: auto;
|
38 |
+
margin-right: auto;
|
39 |
+
width: 640px; /* Adjust width as needed */
|
40 |
+
height: 640px; /* Adjust height as needed */
|
41 |
+
#leaderboard_table {
|
42 |
+
display: block;
|
43 |
+
margin-left: auto;
|
44 |
+
margin-right: auto;
|
45 |
+
}
|
46 |
+
"""
|
47 |
+
|
48 |
+
JS = """
|
49 |
+
function refresh() {
|
50 |
+
const url = new URL(window.location);
|
51 |
+
|
52 |
+
if (url.searchParams.get('__theme') !== 'dark') {
|
53 |
+
url.searchParams.set('__theme', 'dark');
|
54 |
+
window.location.href = url.href;
|
55 |
+
}
|
56 |
+
}
|
57 |
+
"""
|
58 |
+
|
59 |
+
QUALITY_POST_FUNC = lambda x: x / 4 * 8
|
60 |
+
PERFORMANCE_POST_FUNC = lambda x: abs(x - 0.5) * 2
|
61 |
+
|
62 |
+
|
63 |
+
# Connect to Redis
|
64 |
+
redis_client = redis.Redis(
|
65 |
+
host=os.getenv("REDIS_HOST"),
|
66 |
+
port=os.getenv("REDIS_PORT"),
|
67 |
+
username=os.getenv("REDIS_USERNAME"),
|
68 |
+
password=os.getenv("REDIS_PASSWORD"),
|
69 |
+
decode_responses=True,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
def save_to_redis(name, performance, quality):
|
74 |
+
submission = {
|
75 |
+
"name": name,
|
76 |
+
"performance": performance,
|
77 |
+
"quality": quality,
|
78 |
+
"timestamp": datetime.now().isoformat(),
|
79 |
+
}
|
80 |
+
redis_client.lpush("submissions", json.dumps(submission))
|
81 |
+
|
82 |
+
|
83 |
+
def get_submissions_from_redis():
|
84 |
+
submissions = redis_client.lrange("submissions", 0, -1)
|
85 |
+
submissions = [json.loads(submission) for submission in submissions]
|
86 |
+
for s in submissions:
|
87 |
+
s["quality"] = QUALITY_POST_FUNC(s["quality"])
|
88 |
+
s["performance"] = PERFORMANCE_POST_FUNC(s["performance"])
|
89 |
+
s["score"] = np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2)
|
90 |
+
return submissions
|
91 |
+
|
92 |
+
|
93 |
+
def update_plot(
|
94 |
+
submissions,
|
95 |
+
current_name=None,
|
96 |
+
):
|
97 |
+
names = [sub["name"] for sub in submissions]
|
98 |
+
performances = [float(sub["performance"]) for sub in submissions]
|
99 |
+
qualities = [float(sub["quality"]) for sub in submissions]
|
100 |
+
|
101 |
+
# Create scatter plot
|
102 |
+
fig = go.Figure()
|
103 |
+
|
104 |
+
for name, quality, performance in zip(names, qualities, performances):
|
105 |
+
if name == current_name:
|
106 |
+
marker = dict(symbol="star", size=15, color="orange")
|
107 |
+
elif name.startswith("Baseline: "):
|
108 |
+
marker = dict(symbol="square", size=8, color="blue")
|
109 |
+
else:
|
110 |
+
marker = dict(symbol="circle", size=10, color="green")
|
111 |
+
|
112 |
+
fig.add_trace(
|
113 |
+
go.Scatter(
|
114 |
+
x=[quality],
|
115 |
+
y=[performance],
|
116 |
+
mode="markers+text",
|
117 |
+
text=[name if not name.startswith("Baseline: ") else ""],
|
118 |
+
textposition="top center",
|
119 |
+
name=name,
|
120 |
+
marker=marker,
|
121 |
+
customdata=[
|
122 |
+
name if name.startswith("Baseline: ") else f"User: {name}",
|
123 |
+
],
|
124 |
+
hovertemplate="<b>%{customdata}</b><br>"
|
125 |
+
+ "Performance: %{y:.3f}<br>"
|
126 |
+
+ "Quality: %{x:.3f}<br>"
|
127 |
+
+ "<extra></extra>",
|
128 |
+
)
|
129 |
+
)
|
130 |
+
|
131 |
+
# Add circles
|
132 |
+
circle_radii = np.linspace(0, 1, 5)
|
133 |
+
for radius in circle_radii:
|
134 |
+
theta = np.linspace(0, 2 * np.pi, 100)
|
135 |
+
x = radius * np.cos(theta)
|
136 |
+
y = radius * np.sin(theta)
|
137 |
+
fig.add_trace(
|
138 |
+
go.Scatter(
|
139 |
+
x=x,
|
140 |
+
y=y,
|
141 |
+
mode="lines",
|
142 |
+
line=dict(color="gray", dash="dash"),
|
143 |
+
showlegend=False,
|
144 |
+
hovertemplate="Performance: %{x:.3f}<br>"
|
145 |
+
+ "Quality: %{y:.3f}<br>"
|
146 |
+
+ "<extra></extra>",
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
# Update layout
|
151 |
+
fig.update_layout(
|
152 |
+
xaxis_title="Image Quality Degredation",
|
153 |
+
yaxis_title="Watermark Detection Performance",
|
154 |
+
xaxis=dict(
|
155 |
+
range=[0, 1.1], titlefont=dict(size=16) # Adjust this value as needed
|
156 |
+
),
|
157 |
+
yaxis=dict(
|
158 |
+
range=[0, 1.1], titlefont=dict(size=16) # Adjust this value as needed
|
159 |
+
),
|
160 |
+
width=640,
|
161 |
+
height=640,
|
162 |
+
showlegend=False, # Remove legend
|
163 |
+
modebar=dict(remove=["all"]),
|
164 |
+
)
|
165 |
+
fig.update_xaxes(title_font_size=20)
|
166 |
+
fig.update_yaxes(title_font_size=20)
|
167 |
+
|
168 |
+
return fig
|
169 |
+
|
170 |
+
|
171 |
+
def update_table(
|
172 |
+
submissions,
|
173 |
+
current_name=None,
|
174 |
+
):
|
175 |
+
def tp(timestamp):
|
176 |
+
return timestamp.replace("T", " ").split(".")[0]
|
177 |
+
|
178 |
+
names = [
|
179 |
+
(
|
180 |
+
sub["name"][len("Baseline: ") :]
|
181 |
+
if sub["name"].startswith("Baseline: ")
|
182 |
+
else sub["name"]
|
183 |
+
)
|
184 |
+
for sub in submissions
|
185 |
+
]
|
186 |
+
times = [
|
187 |
+
(
|
188 |
+
""
|
189 |
+
if sub["name"].startswith("Baseline: ")
|
190 |
+
else (
|
191 |
+
tp(sub["timestamp"]) + " (Current)"
|
192 |
+
if sub["name"] == current_name
|
193 |
+
else tp(sub["timestamp"])
|
194 |
+
)
|
195 |
+
)
|
196 |
+
for sub in submissions
|
197 |
+
]
|
198 |
+
performances = ["%.4f" % (float(sub["performance"])) for sub in submissions]
|
199 |
+
qualities = ["%.4f" % (float(sub["quality"])) for sub in submissions]
|
200 |
+
scores = ["%.4f" % (float(sub["score"])) for sub in submissions]
|
201 |
+
df = pd.DataFrame(
|
202 |
+
{
|
203 |
+
"Name": names,
|
204 |
+
"Submission Time": times,
|
205 |
+
"Performance": performances,
|
206 |
+
"Quality": qualities,
|
207 |
+
"Score": scores,
|
208 |
+
}
|
209 |
+
).sort_values(by=["Score"])
|
210 |
+
df.insert(0, "Rank #", list(np.arange(len(names)) + 1), True)
|
211 |
+
|
212 |
+
def highlight_null(s):
|
213 |
+
con = s.copy()
|
214 |
+
con[:] = None
|
215 |
+
if s["Submission Time"] == "":
|
216 |
+
con[:] = "background-color: darkgrey"
|
217 |
+
return con
|
218 |
+
|
219 |
+
return df.style.apply(highlight_null, axis=1)
|
220 |
+
|
221 |
+
|
222 |
+
def process_submission(name, image):
|
223 |
+
original_image = Image.open("./image.png")
|
224 |
+
progress = gr.Progress()
|
225 |
+
progress(0, desc="Detecting Watermark")
|
226 |
+
performance = compute_performance(image)
|
227 |
+
progress(0.4, desc="Evaluating Image Quality")
|
228 |
+
quality = compute_quality(image, original_image)
|
229 |
+
progress(1.0, desc="Uploading Results")
|
230 |
+
|
231 |
+
# Save unprocessed values but display processed values
|
232 |
+
save_to_redis(name, performance, quality)
|
233 |
+
quality = QUALITY_POST_FUNC(quality)
|
234 |
+
performance = PERFORMANCE_POST_FUNC(performance)
|
235 |
+
|
236 |
+
submissions = get_submissions_from_redis()
|
237 |
+
leaderboard_table = update_table(submissions, current_name=name)
|
238 |
+
leaderboard_plot = update_plot(submissions, current_name=name)
|
239 |
+
|
240 |
+
# Calculate rank
|
241 |
+
distances = [
|
242 |
+
np.sqrt(float(s["quality"]) ** 2 + float(s["performance"]) ** 2)
|
243 |
+
for s in submissions
|
244 |
+
]
|
245 |
+
rank = sorted(distances).index(np.sqrt(quality**2 + performance**2)) + 1
|
246 |
+
gr.Info(f"You ranked {rank} out of {len(submissions)}!")
|
247 |
+
return (
|
248 |
+
leaderboard_plot,
|
249 |
+
leaderboard_table,
|
250 |
+
f"{rank} out of {len(submissions)}",
|
251 |
+
name,
|
252 |
+
f"{performance:.3f}",
|
253 |
+
f"{quality:.3f}",
|
254 |
+
f"{np.sqrt(quality**2 + performance**2):.3f}",
|
255 |
+
)
|
256 |
+
|
257 |
+
|
258 |
+
def upload_and_evaluate(name, image):
|
259 |
+
if name == "":
|
260 |
+
raise gr.Error("Please enter your name before submitting.")
|
261 |
+
if image is None:
|
262 |
+
raise gr.Error("Please upload an image before submitting.")
|
263 |
+
return process_submission(name, image)
|
264 |
+
|
265 |
+
|
266 |
+
def create_interface():
|
267 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=CSS, js=JS) as demo:
|
268 |
+
gr.Markdown(
|
269 |
+
"""
|
270 |
+
# Erasing the Invisible (Demo of NeurIPS'24 competition)
|
271 |
+
|
272 |
+
Welcome to the demo of the NeurIPS'24 competition [Erasing the Invisible: A Stress-Test Challenge for Image Watermarks](https://erasinginvisible.github.io/).
|
273 |
+
|
274 |
+
You could use this demo to better understand the competition pipeline or just for fun! 🎮
|
275 |
+
|
276 |
+
Here, we provide an image embedded with an invisible watermark. You only need to:
|
277 |
+
|
278 |
+
1. **Download** the original watermarked image. 🌊
|
279 |
+
2. **Remove** the invisible watermark using your preferred attack. 🧼
|
280 |
+
3. **Upload** your image. We will evaluate and rank your attack. 📊
|
281 |
+
|
282 |
+
That's it! 🚀
|
283 |
+
|
284 |
+
*Note: This is just a demo. The watermark used here is not necessarily representative of those used for the competition. To officially participate in the competition, please follow the guidelines [here](https://erasinginvisible.github.io/).*
|
285 |
+
"""
|
286 |
+
)
|
287 |
+
|
288 |
+
with gr.Tabs(elem_classes=["tabs"]) as tabs:
|
289 |
+
with gr.Tab("Original Watermarked Image", id="download"):
|
290 |
+
with gr.Column():
|
291 |
+
original_image = gr.Image(
|
292 |
+
value="./image.png",
|
293 |
+
format="png",
|
294 |
+
label="Original Watermarked Image",
|
295 |
+
show_label=True,
|
296 |
+
height=512,
|
297 |
+
width=512,
|
298 |
+
type="filepath",
|
299 |
+
show_download_button=False,
|
300 |
+
show_share_button=False,
|
301 |
+
show_fullscreen_button=False,
|
302 |
+
container=True,
|
303 |
+
elem_id="original_image",
|
304 |
+
)
|
305 |
+
with gr.Row():
|
306 |
+
download_btn = gr.DownloadButton(
|
307 |
+
"Download Watermarked Image",
|
308 |
+
value="./image.png",
|
309 |
+
elem_id="download_btn",
|
310 |
+
)
|
311 |
+
submit_btn = gr.Button(
|
312 |
+
"Submit Your Removal", elem_id="submit_btn"
|
313 |
+
)
|
314 |
+
|
315 |
+
with gr.Tab(
|
316 |
+
"Submit Watermark Removed Image",
|
317 |
+
id="submit",
|
318 |
+
elem_classes="gr-tab-header",
|
319 |
+
):
|
320 |
+
|
321 |
+
with gr.Column():
|
322 |
+
uploaded_image = gr.Image(
|
323 |
+
label="Your Watermark Removed Image",
|
324 |
+
format="png",
|
325 |
+
show_label=True,
|
326 |
+
height=512,
|
327 |
+
width=512,
|
328 |
+
sources=["upload"],
|
329 |
+
type="pil",
|
330 |
+
show_download_button=False,
|
331 |
+
show_share_button=False,
|
332 |
+
show_fullscreen_button=False,
|
333 |
+
container=True,
|
334 |
+
placeholder="Upload your watermark removed image",
|
335 |
+
elem_id="uploaded_image",
|
336 |
+
)
|
337 |
+
with gr.Row():
|
338 |
+
name_input = gr.Textbox(
|
339 |
+
label="Your Name", placeholder="Anonymous"
|
340 |
+
)
|
341 |
+
upload_btn = gr.Button("Upload and Evaluate")
|
342 |
+
|
343 |
+
with gr.Tab(
|
344 |
+
"Evaluation Results",
|
345 |
+
id="plot",
|
346 |
+
elem_classes="gr-tab-header",
|
347 |
+
):
|
348 |
+
gr.Markdown(
|
349 |
+
"The evaluation is based on two metrics, watermark performance ($$A$$) and image quality degradation ($$Q$$).",
|
350 |
+
latex_delimiters=[{"left": "$$", "right": "$$", "display": False}],
|
351 |
+
)
|
352 |
+
gr.Markdown(
|
353 |
+
"The lower the watermark performance and less quality degradation, the more effective the attack is. The overall score is $$\sqrt{Q^2+A^2}$$, the smaller the better.",
|
354 |
+
latex_delimiters=[{"left": "$$", "right": "$$", "display": False}],
|
355 |
+
)
|
356 |
+
gr.Markdown(
|
357 |
+
"""
|
358 |
+
<p>
|
359 |
+
<span style="display: inline-block; width: 20px;"></span>🟦: Baseline attacks
|
360 |
+
<span style="display: inline-block; width: 20px;"></span>🟢: Users' submissions
|
361 |
+
<span style="display: inline-block; width: 20px;"></span>⭐: Your current submission
|
362 |
+
</p>
|
363 |
+
<p><em>Note: The performance and quality metrics differ from those in the competition (as only one image is used here), but they still give you an idea of how effective your attack is.</em></p>
|
364 |
+
"""
|
365 |
+
)
|
366 |
+
with gr.Column():
|
367 |
+
leaderboard_plot = gr.Plot(
|
368 |
+
value=update_plot(get_submissions_from_redis()),
|
369 |
+
show_label=False,
|
370 |
+
elem_id="leaderboard_plot",
|
371 |
+
)
|
372 |
+
with gr.Row():
|
373 |
+
rank_output = gr.Textbox(label="Your Ranking")
|
374 |
+
name_output = gr.Textbox(label="Your Name")
|
375 |
+
performance_output = gr.Textbox(label="Watermark Performance")
|
376 |
+
quality_output = gr.Textbox(label="Quality Degredation")
|
377 |
+
overall_output = gr.Textbox(label="Overall Score")
|
378 |
+
with gr.Tab(
|
379 |
+
"Leaderboard",
|
380 |
+
id="leaderboard",
|
381 |
+
elem_classes="gr-tab-header",
|
382 |
+
):
|
383 |
+
gr.Markdown("Find your ranking on the leaderboard!")
|
384 |
+
gr.Markdown(
|
385 |
+
"Gray-shaded rows are baseline results provided by the organziers."
|
386 |
+
)
|
387 |
+
with gr.Column():
|
388 |
+
leaderboard_table = gr.Dataframe(
|
389 |
+
value=update_table(get_submissions_from_redis()),
|
390 |
+
show_label=False,
|
391 |
+
elem_id="leaderboard_table",
|
392 |
+
)
|
393 |
+
submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
|
394 |
+
|
395 |
+
upload_btn.click(lambda: gr.Tabs(selected="plot"), None, tabs).then(
|
396 |
+
upload_and_evaluate,
|
397 |
+
inputs=[name_input, uploaded_image],
|
398 |
+
outputs=[
|
399 |
+
leaderboard_plot,
|
400 |
+
leaderboard_table,
|
401 |
+
rank_output,
|
402 |
+
name_output,
|
403 |
+
performance_output,
|
404 |
+
quality_output,
|
405 |
+
overall_output,
|
406 |
+
],
|
407 |
+
)
|
408 |
+
|
409 |
+
demo.load(
|
410 |
+
lambda: [
|
411 |
+
gr.Image(value="./image.png", height=512, width=512),
|
412 |
+
gr.Plot(update_plot(get_submissions_from_redis())),
|
413 |
+
gr.Dataframe(update_table(get_submissions_from_redis())),
|
414 |
+
],
|
415 |
+
outputs=[original_image, leaderboard_plot, leaderboard_table],
|
416 |
+
)
|
417 |
+
|
418 |
+
return demo
|
419 |
+
|
420 |
+
|
421 |
+
# Create the demo object
|
422 |
+
demo = create_interface()
|
423 |
+
|
424 |
+
# Launch the app
|
425 |
+
if __name__ == "__main__":
|
426 |
+
demo.launch(share=False)
|
icon.jpg
ADDED
![]() |
image.png
ADDED
![]() |
kit/__init__.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime as ort
|
5 |
+
from PIL import Image
|
6 |
+
import dotenv
|
7 |
+
|
8 |
+
dotenv.load_dotenv()
|
9 |
+
|
10 |
+
GT_MESSAGE = os.environ["GT_MESSAGE"]
|
11 |
+
|
12 |
+
|
13 |
+
QUALITY_COEFFICIENTS = {
|
14 |
+
"psnr": -0.0022186489180419534,
|
15 |
+
"ssim": -0.11337077856710862,
|
16 |
+
"nmi": -0.09878221979274945,
|
17 |
+
"lpips": 0.3412626374646173,
|
18 |
+
}
|
19 |
+
|
20 |
+
QUALITY_OFFSETS = {
|
21 |
+
"psnr": 43.54757854447622,
|
22 |
+
"ssim": 0.984229018845295,
|
23 |
+
"nmi": 1.7536553655336136,
|
24 |
+
"lpips": 0.014247652621287854,
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
def compute_performance(image):
|
29 |
+
session_options = ort.SessionOptions()
|
30 |
+
session_options.intra_op_num_threads = 1
|
31 |
+
session_options.inter_op_num_threads = 1
|
32 |
+
session_options.log_severity_level = 3
|
33 |
+
model = ort.InferenceSession(
|
34 |
+
"./kit/models/stable_signature.onnx",
|
35 |
+
sess_options=session_options,
|
36 |
+
)
|
37 |
+
inputs = np.stack(
|
38 |
+
[
|
39 |
+
(
|
40 |
+
(
|
41 |
+
np.array(
|
42 |
+
image,
|
43 |
+
dtype=np.float32,
|
44 |
+
)
|
45 |
+
/ 255.0
|
46 |
+
- [0.485, 0.456, 0.406]
|
47 |
+
)
|
48 |
+
/ [0.229, 0.224, 0.225]
|
49 |
+
)
|
50 |
+
.transpose((2, 0, 1))
|
51 |
+
.astype(np.float32)
|
52 |
+
],
|
53 |
+
axis=0,
|
54 |
+
)
|
55 |
+
|
56 |
+
outputs = model.run(
|
57 |
+
None,
|
58 |
+
{
|
59 |
+
"image": inputs,
|
60 |
+
},
|
61 |
+
)
|
62 |
+
decoded = (outputs[0] > 0).astype(int)[0]
|
63 |
+
gt_message = np.array([int(bit) for bit in GT_MESSAGE])
|
64 |
+
return 1 - np.mean(gt_message != decoded)
|
65 |
+
|
66 |
+
|
67 |
+
from .metrics import (
|
68 |
+
compute_image_distance_repeated,
|
69 |
+
load_perceptual_models,
|
70 |
+
compute_perceptual_metric_repeated,
|
71 |
+
load_aesthetics_and_artifacts_models,
|
72 |
+
compute_aesthetics_and_artifacts_scores,
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def compute_quality(attacked_image, clean_image, quiet=True):
|
77 |
+
|
78 |
+
# Compress the image
|
79 |
+
buffer = io.BytesIO()
|
80 |
+
attacked_image.save(buffer, format="JPEG", quality=95)
|
81 |
+
buffer.seek(0)
|
82 |
+
|
83 |
+
# Update attacked_image with the compressed version
|
84 |
+
attacked_image = Image.open(buffer)
|
85 |
+
|
86 |
+
modes = ["psnr", "ssim", "nmi", "lpips"]
|
87 |
+
|
88 |
+
results = {}
|
89 |
+
for mode in modes:
|
90 |
+
if mode in ["psnr", "ssim", "nmi"]:
|
91 |
+
metrics = compute_image_distance_repeated(
|
92 |
+
[clean_image],
|
93 |
+
[attacked_image],
|
94 |
+
metric_name=mode,
|
95 |
+
num_workers=1,
|
96 |
+
verbose=not quiet,
|
97 |
+
)
|
98 |
+
results[mode] = metrics
|
99 |
+
|
100 |
+
elif mode == "lpips":
|
101 |
+
model = load_perceptual_models(
|
102 |
+
mode,
|
103 |
+
mode="alex",
|
104 |
+
device="cpu",
|
105 |
+
)
|
106 |
+
metrics = compute_perceptual_metric_repeated(
|
107 |
+
[clean_image],
|
108 |
+
[attacked_image],
|
109 |
+
metric_name=mode,
|
110 |
+
mode="alex",
|
111 |
+
model=model,
|
112 |
+
device="cpu",
|
113 |
+
)
|
114 |
+
results[mode] = metrics
|
115 |
+
|
116 |
+
normalized_quality = 0
|
117 |
+
for key, value in results.items():
|
118 |
+
normalized_quality += (value[0] - QUALITY_OFFSETS[key]) * QUALITY_COEFFICIENTS[
|
119 |
+
key
|
120 |
+
]
|
121 |
+
return normalized_quality
|
kit/metrics/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .distributional import compute_fid
|
2 |
+
from .image import (
|
3 |
+
compute_mse,
|
4 |
+
compute_psnr,
|
5 |
+
compute_ssim,
|
6 |
+
compute_nmi,
|
7 |
+
compute_mse_repeated,
|
8 |
+
compute_psnr_repeated,
|
9 |
+
compute_ssim_repeated,
|
10 |
+
compute_nmi_repeated,
|
11 |
+
compute_image_distance_repeated,
|
12 |
+
)
|
13 |
+
from .perceptual import (
|
14 |
+
load_perceptual_models,
|
15 |
+
compute_lpips,
|
16 |
+
compute_lpips_repeated,
|
17 |
+
compute_perceptual_metric_repeated,
|
18 |
+
)
|
19 |
+
from .aesthetics import (
|
20 |
+
load_aesthetics_and_artifacts_models,
|
21 |
+
compute_aesthetics_and_artifacts_scores,
|
22 |
+
)
|
23 |
+
from .clip import load_open_clip_model_preprocess_and_tokenizer, compute_clip_score
|
24 |
+
from .prompt import (
|
25 |
+
load_perplexity_model_and_tokenizer,
|
26 |
+
compute_prompt_perplexity,
|
27 |
+
)
|
kit/metrics/aesthetics.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from transformers import CLIPModel, CLIPProcessor
|
4 |
+
from .aesthetics_scorer import preprocess, load_model
|
5 |
+
|
6 |
+
|
7 |
+
def load_aesthetics_and_artifacts_models(device=torch.device("cuda")):
|
8 |
+
model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
9 |
+
vision_model = model.vision_model
|
10 |
+
vision_model.to(device)
|
11 |
+
del model
|
12 |
+
clip_processor = CLIPProcessor.from_pretrained(
|
13 |
+
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
14 |
+
)
|
15 |
+
rating_model = load_model("aesthetics_scorer_rating_openclip_vit_h_14").to(device)
|
16 |
+
artifacts_model = load_model("aesthetics_scorer_artifacts_openclip_vit_h_14").to(
|
17 |
+
device
|
18 |
+
)
|
19 |
+
return vision_model, clip_processor, rating_model, artifacts_model
|
20 |
+
|
21 |
+
|
22 |
+
def compute_aesthetics_and_artifacts_scores(
|
23 |
+
images, models, device=torch.device("cuda")
|
24 |
+
):
|
25 |
+
vision_model, clip_processor, rating_model, artifacts_model = models
|
26 |
+
|
27 |
+
inputs = clip_processor(images=images, return_tensors="pt").to(device)
|
28 |
+
with torch.no_grad():
|
29 |
+
vision_output = vision_model(**inputs)
|
30 |
+
pooled_output = vision_output.pooler_output
|
31 |
+
embedding = preprocess(pooled_output)
|
32 |
+
with torch.no_grad():
|
33 |
+
rating = rating_model(embedding)
|
34 |
+
artifact = artifacts_model(embedding)
|
35 |
+
return (
|
36 |
+
rating.detach().cpu().numpy().flatten().tolist(),
|
37 |
+
artifact.detach().cpu().numpy().flatten().tolist(),
|
38 |
+
)
|
kit/metrics/aesthetics_scorer/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
From https://github.com/kenjiqq/aesthetics-scorer#validation-split-of-diffusiondb-dataset
|
3 |
+
"""
|
4 |
+
from .model import preprocess, load_model
|
kit/metrics/aesthetics_scorer/model.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import inspect
|
6 |
+
|
7 |
+
|
8 |
+
class AestheticScorer(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
input_size=0,
|
12 |
+
use_activation=False,
|
13 |
+
dropout=0.2,
|
14 |
+
config=None,
|
15 |
+
hidden_dim=1024,
|
16 |
+
reduce_dims=False,
|
17 |
+
output_activation=None,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.config = {
|
21 |
+
"input_size": input_size,
|
22 |
+
"use_activation": use_activation,
|
23 |
+
"dropout": dropout,
|
24 |
+
"hidden_dim": hidden_dim,
|
25 |
+
"reduce_dims": reduce_dims,
|
26 |
+
"output_activation": output_activation,
|
27 |
+
}
|
28 |
+
if config != None:
|
29 |
+
self.config.update(config)
|
30 |
+
|
31 |
+
layers = [
|
32 |
+
nn.Linear(self.config["input_size"], self.config["hidden_dim"]),
|
33 |
+
nn.ReLU() if self.config["use_activation"] else None,
|
34 |
+
nn.Dropout(self.config["dropout"]),
|
35 |
+
nn.Linear(
|
36 |
+
self.config["hidden_dim"],
|
37 |
+
round(self.config["hidden_dim"] / (2 if reduce_dims else 1)),
|
38 |
+
),
|
39 |
+
nn.ReLU() if self.config["use_activation"] else None,
|
40 |
+
nn.Dropout(self.config["dropout"]),
|
41 |
+
nn.Linear(
|
42 |
+
round(self.config["hidden_dim"] / (2 if reduce_dims else 1)),
|
43 |
+
round(self.config["hidden_dim"] / (4 if reduce_dims else 1)),
|
44 |
+
),
|
45 |
+
nn.ReLU() if self.config["use_activation"] else None,
|
46 |
+
nn.Dropout(self.config["dropout"]),
|
47 |
+
nn.Linear(
|
48 |
+
round(self.config["hidden_dim"] / (4 if reduce_dims else 1)),
|
49 |
+
round(self.config["hidden_dim"] / (8 if reduce_dims else 1)),
|
50 |
+
),
|
51 |
+
nn.ReLU() if self.config["use_activation"] else None,
|
52 |
+
nn.Linear(round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 1),
|
53 |
+
]
|
54 |
+
if self.config["output_activation"] == "sigmoid":
|
55 |
+
layers.append(nn.Sigmoid())
|
56 |
+
layers = [x for x in layers if x is not None]
|
57 |
+
self.layers = nn.Sequential(*layers)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
if self.config["output_activation"] == "sigmoid":
|
61 |
+
upper, lower = 10, 1
|
62 |
+
scale = upper - lower
|
63 |
+
return (self.layers(x) * scale) + lower
|
64 |
+
else:
|
65 |
+
return self.layers(x)
|
66 |
+
|
67 |
+
def save(self, save_name):
|
68 |
+
split_name = os.path.splitext(save_name)
|
69 |
+
with open(f"{split_name[0]}.config", "w") as outfile:
|
70 |
+
outfile.write(json.dumps(self.config, indent=4))
|
71 |
+
|
72 |
+
for i in range(
|
73 |
+
6
|
74 |
+
): # saving sometiles fails, so retry 5 times, might be windows issue
|
75 |
+
try:
|
76 |
+
torch.save(self.state_dict(), save_name)
|
77 |
+
break
|
78 |
+
except RuntimeError as e:
|
79 |
+
# check if error contains string "File"
|
80 |
+
if "cannot be opened" in str(e) and i < 5:
|
81 |
+
print("Model save failed, retrying...")
|
82 |
+
else:
|
83 |
+
raise e
|
84 |
+
|
85 |
+
|
86 |
+
def preprocess(embeddings):
|
87 |
+
return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
|
88 |
+
|
89 |
+
|
90 |
+
def load_model(weight_name, device="cuda" if torch.cuda.is_available() else "cpu"):
|
91 |
+
weight_folder = os.path.abspath(
|
92 |
+
os.path.join(
|
93 |
+
inspect.getfile(load_model),
|
94 |
+
"../weights",
|
95 |
+
)
|
96 |
+
)
|
97 |
+
weight_path = os.path.join(weight_folder, f"{weight_name}.pth")
|
98 |
+
config_path = os.path.join(weight_folder, f"{weight_name}.config")
|
99 |
+
with open(config_path, "r") as config_file:
|
100 |
+
config = json.load(config_file)
|
101 |
+
model = AestheticScorer(config=config)
|
102 |
+
model.load_state_dict(torch.load(weight_path, map_location=device))
|
103 |
+
model.eval()
|
104 |
+
return model
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1664,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:39a5d014670226d52c408e0dfec840b7626d80a73d003a6a144caafd5e02d031
|
3 |
+
size 19423219
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1280,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc48a8a2315cfdbc7bb8278be55f645e8a995e1a2fa234baec5eb41c4d33e070
|
3 |
+
size 17850319
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1024,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c4a9481fdbce5ff02b252bcb25109b9f3b29841289fadf7e79e884d59f9357d5
|
3 |
+
size 16801743
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1664,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:19b016304f54ae866e27f1eb498c0861f704958e7c37693adc5ce094e63904a8
|
3 |
+
size 19423099
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1280,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03603eee1864c2e5e97ef7079229609653db5b10594ca8b1de9e541d838cae9c
|
3 |
+
size 17850199
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.config
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input_size": 1024,
|
3 |
+
"use_activation": false,
|
4 |
+
"dropout": 0.0,
|
5 |
+
"hidden_dim": 1024,
|
6 |
+
"reduce_dims": false,
|
7 |
+
"output_activation": null
|
8 |
+
}
|
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eb7fe561369ab6c7dad34b9316a56d2c6070582f0323656148e1107a242cd666
|
3 |
+
size 16801623
|
kit/metrics/clean_fid/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
From https://github.com/GaParmar/clean-fid/tree/main
|
3 |
+
"""
|
kit/metrics/clean_fid/clip_features.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pip install git+https://github.com/openai/CLIP.git
|
2 |
+
import pdb
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import clip
|
8 |
+
from .fid import compute_fid
|
9 |
+
|
10 |
+
|
11 |
+
def img_preprocess_clip(img_np):
|
12 |
+
x = Image.fromarray(img_np.astype(np.uint8)).convert("RGB")
|
13 |
+
T = transforms.Compose(
|
14 |
+
[
|
15 |
+
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
|
16 |
+
transforms.CenterCrop(224),
|
17 |
+
]
|
18 |
+
)
|
19 |
+
return np.asarray(T(x)).clip(0, 255).astype(np.uint8)
|
20 |
+
|
21 |
+
|
22 |
+
class CLIP_fx:
|
23 |
+
def __init__(self, name="ViT-B/32", device="cuda"):
|
24 |
+
self.model, _ = clip.load(name, device=device)
|
25 |
+
self.model.eval()
|
26 |
+
self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_")
|
27 |
+
|
28 |
+
def __call__(self, img_t):
|
29 |
+
img_x = img_t / 255.0
|
30 |
+
T_norm = transforms.Normalize(
|
31 |
+
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
|
32 |
+
)
|
33 |
+
img_x = T_norm(img_x)
|
34 |
+
assert torch.is_tensor(img_x)
|
35 |
+
if len(img_x.shape) == 3:
|
36 |
+
img_x = img_x.unsqueeze(0)
|
37 |
+
B, C, H, W = img_x.shape
|
38 |
+
with torch.no_grad():
|
39 |
+
z = self.model.encode_image(img_x)
|
40 |
+
return z
|
kit/metrics/clean_fid/downloads_helper.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import urllib.request
|
3 |
+
import requests
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
|
7 |
+
inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt"
|
8 |
+
|
9 |
+
|
10 |
+
"""
|
11 |
+
Download the pretrined inception weights if it does not exists
|
12 |
+
ARGS:
|
13 |
+
fpath - output folder path
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
def check_download_inception(fpath="./"):
|
18 |
+
inception_path = os.path.join(fpath, "inception-2015-12-05.pt")
|
19 |
+
if not os.path.exists(inception_path):
|
20 |
+
# download the file
|
21 |
+
with urllib.request.urlopen(inception_url) as response, open(
|
22 |
+
inception_path, "wb"
|
23 |
+
) as f:
|
24 |
+
shutil.copyfileobj(response, f)
|
25 |
+
return inception_path
|
26 |
+
|
27 |
+
|
28 |
+
"""
|
29 |
+
Download any url if it does not exist
|
30 |
+
ARGS:
|
31 |
+
local_folder - output folder path
|
32 |
+
url - the weburl to download
|
33 |
+
"""
|
34 |
+
|
35 |
+
|
36 |
+
def check_download_url(local_folder, url):
|
37 |
+
name = os.path.basename(url)
|
38 |
+
local_path = os.path.join(local_folder, name)
|
39 |
+
if not os.path.exists(local_path):
|
40 |
+
os.makedirs(local_folder, exist_ok=True)
|
41 |
+
print(f"downloading statistics to {local_path}")
|
42 |
+
with urllib.request.urlopen(url) as response, open(local_path, "wb") as f:
|
43 |
+
shutil.copyfileobj(response, f)
|
44 |
+
return local_path
|
45 |
+
|
46 |
+
|
47 |
+
"""
|
48 |
+
Download a file from google drive
|
49 |
+
ARGS:
|
50 |
+
file_id - id of the google drive file
|
51 |
+
out_path - output folder path
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
def download_google_drive(file_id, out_path):
|
56 |
+
def get_confirm_token(response):
|
57 |
+
for key, value in response.cookies.items():
|
58 |
+
if key.startswith("download_warning"):
|
59 |
+
return value
|
60 |
+
return None
|
61 |
+
|
62 |
+
URL = "https://drive.google.com/uc?export=download"
|
63 |
+
session = requests.Session()
|
64 |
+
response = session.get(URL, params={"id": file_id}, stream=True)
|
65 |
+
token = get_confirm_token(response)
|
66 |
+
|
67 |
+
if token:
|
68 |
+
params = {"id": file_id, "confirm": token}
|
69 |
+
response = session.get(URL, params=params, stream=True)
|
70 |
+
|
71 |
+
CHUNK_SIZE = 32768
|
72 |
+
with open(out_path, "wb") as f:
|
73 |
+
for chunk in response.iter_content(CHUNK_SIZE):
|
74 |
+
if chunk:
|
75 |
+
f.write(chunk)
|
kit/metrics/clean_fid/features.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
helpers for extracting features from image
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import platform
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.hub import get_dir
|
9 |
+
from .downloads_helper import check_download_url
|
10 |
+
from .inception_pytorch import InceptionV3
|
11 |
+
from .inception_torchscript import InceptionV3W
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
returns a functions that takes an image in range [0,255]
|
16 |
+
and outputs a feature embedding vector
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
def feature_extractor(
|
21 |
+
name="torchscript_inception",
|
22 |
+
device=torch.device("cuda"),
|
23 |
+
resize_inside=False,
|
24 |
+
use_dataparallel=True,
|
25 |
+
):
|
26 |
+
if name == "torchscript_inception":
|
27 |
+
path = "./" if platform.system() == "Windows" else "/tmp"
|
28 |
+
model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(
|
29 |
+
device
|
30 |
+
)
|
31 |
+
model.eval()
|
32 |
+
if use_dataparallel:
|
33 |
+
model = torch.nn.DataParallel(model)
|
34 |
+
|
35 |
+
def model_fn(x):
|
36 |
+
return model(x)
|
37 |
+
|
38 |
+
elif name == "pytorch_inception":
|
39 |
+
model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
|
40 |
+
model.eval()
|
41 |
+
if use_dataparallel:
|
42 |
+
model = torch.nn.DataParallel(model)
|
43 |
+
|
44 |
+
def model_fn(x):
|
45 |
+
return model(x / 255)[0].squeeze(-1).squeeze(-1)
|
46 |
+
|
47 |
+
else:
|
48 |
+
raise ValueError(f"{name} feature extractor not implemented")
|
49 |
+
return model_fn
|
50 |
+
|
51 |
+
|
52 |
+
"""
|
53 |
+
Build a feature extractor for each of the modes
|
54 |
+
"""
|
55 |
+
|
56 |
+
|
57 |
+
def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
|
58 |
+
if mode == "legacy_pytorch":
|
59 |
+
feat_model = feature_extractor(
|
60 |
+
name="pytorch_inception",
|
61 |
+
resize_inside=False,
|
62 |
+
device=device,
|
63 |
+
use_dataparallel=use_dataparallel,
|
64 |
+
)
|
65 |
+
elif mode == "legacy_tensorflow":
|
66 |
+
feat_model = feature_extractor(
|
67 |
+
name="torchscript_inception",
|
68 |
+
resize_inside=True,
|
69 |
+
device=device,
|
70 |
+
use_dataparallel=use_dataparallel,
|
71 |
+
)
|
72 |
+
elif mode == "clean":
|
73 |
+
feat_model = feature_extractor(
|
74 |
+
name="torchscript_inception",
|
75 |
+
resize_inside=False,
|
76 |
+
device=device,
|
77 |
+
use_dataparallel=use_dataparallel,
|
78 |
+
)
|
79 |
+
return feat_model
|
80 |
+
|
81 |
+
|
82 |
+
"""
|
83 |
+
Load precomputed reference statistics for commonly used datasets
|
84 |
+
"""
|
85 |
+
|
86 |
+
|
87 |
+
def get_reference_statistics(
|
88 |
+
name,
|
89 |
+
res,
|
90 |
+
mode="clean",
|
91 |
+
model_name="inception_v3",
|
92 |
+
seed=0,
|
93 |
+
split="test",
|
94 |
+
metric="FID",
|
95 |
+
):
|
96 |
+
base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
|
97 |
+
if split == "custom":
|
98 |
+
res = "na"
|
99 |
+
if model_name == "inception_v3":
|
100 |
+
model_modifier = ""
|
101 |
+
else:
|
102 |
+
model_modifier = "_" + model_name
|
103 |
+
if metric == "FID":
|
104 |
+
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz")
|
105 |
+
url = f"{base_url}/{rel_path}"
|
106 |
+
stats_folder = os.path.join(get_dir(), "fid_stats")
|
107 |
+
fpath = check_download_url(local_folder=stats_folder, url=url)
|
108 |
+
stats = np.load(fpath)
|
109 |
+
mu, sigma = stats["mu"], stats["sigma"]
|
110 |
+
return mu, sigma
|
111 |
+
elif metric == "KID":
|
112 |
+
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz")
|
113 |
+
url = f"{base_url}/{rel_path}"
|
114 |
+
stats_folder = os.path.join(get_dir(), "fid_stats")
|
115 |
+
fpath = check_download_url(local_folder=stats_folder, url=url)
|
116 |
+
stats = np.load(fpath)
|
117 |
+
return stats["feats"]
|
kit/metrics/clean_fid/fid.py
ADDED
@@ -0,0 +1,836 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from tqdm.auto import tqdm
|
4 |
+
from glob import glob
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from scipy import linalg
|
9 |
+
import zipfile
|
10 |
+
from torch.hub import get_dir
|
11 |
+
from .utils import *
|
12 |
+
from .features import build_feature_extractor, get_reference_statistics
|
13 |
+
from .resize import *
|
14 |
+
|
15 |
+
|
16 |
+
"""
|
17 |
+
Numpy implementation of the Frechet Distance.
|
18 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
19 |
+
and X_2 ~ N(mu_2, C_2) is
|
20 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
21 |
+
Stable version by Danica J. Sutherland.
|
22 |
+
Params:
|
23 |
+
mu1 : Numpy array containing the activations of a layer of the
|
24 |
+
inception net (like returned by the function 'get_predictions')
|
25 |
+
for generated samples.
|
26 |
+
mu2 : The sample mean over activations, precalculated on an
|
27 |
+
representative data set.
|
28 |
+
sigma1: The covariance matrix over activations for generated samples.
|
29 |
+
sigma2: The covariance matrix over activations, precalculated on an
|
30 |
+
representative data set.
|
31 |
+
"""
|
32 |
+
|
33 |
+
|
34 |
+
def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
35 |
+
mu1 = np.atleast_1d(mu1)
|
36 |
+
mu2 = np.atleast_1d(mu2)
|
37 |
+
sigma1 = np.atleast_2d(sigma1)
|
38 |
+
sigma2 = np.atleast_2d(sigma2)
|
39 |
+
|
40 |
+
assert (
|
41 |
+
mu1.shape == mu2.shape
|
42 |
+
), "Training and test mean vectors have different lengths"
|
43 |
+
assert (
|
44 |
+
sigma1.shape == sigma2.shape
|
45 |
+
), "Training and test covariances have different dimensions"
|
46 |
+
|
47 |
+
diff = mu1 - mu2
|
48 |
+
|
49 |
+
# Product might be almost singular
|
50 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
51 |
+
if not np.isfinite(covmean).all():
|
52 |
+
msg = (
|
53 |
+
"fid calculation produces singular product; "
|
54 |
+
"adding %s to diagonal of cov estimates"
|
55 |
+
) % eps
|
56 |
+
print(msg)
|
57 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
58 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
59 |
+
|
60 |
+
# Numerical error might give slight imaginary component
|
61 |
+
if np.iscomplexobj(covmean):
|
62 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
63 |
+
m = np.max(np.abs(covmean.imag))
|
64 |
+
raise ValueError("Imaginary component {}".format(m))
|
65 |
+
covmean = covmean.real
|
66 |
+
|
67 |
+
tr_covmean = np.trace(covmean)
|
68 |
+
|
69 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
70 |
+
|
71 |
+
|
72 |
+
"""
|
73 |
+
Compute the KID score given the sets of features
|
74 |
+
"""
|
75 |
+
|
76 |
+
|
77 |
+
def kernel_distance(feats1, feats2, num_subsets=100, max_subset_size=1000):
|
78 |
+
n = feats1.shape[1]
|
79 |
+
m = min(min(feats1.shape[0], feats2.shape[0]), max_subset_size)
|
80 |
+
t = 0
|
81 |
+
for _subset_idx in range(num_subsets):
|
82 |
+
x = feats2[np.random.choice(feats2.shape[0], m, replace=False)]
|
83 |
+
y = feats1[np.random.choice(feats1.shape[0], m, replace=False)]
|
84 |
+
a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
|
85 |
+
b = (x @ y.T / n + 1) ** 3
|
86 |
+
t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
|
87 |
+
kid = t / num_subsets / m
|
88 |
+
return float(kid)
|
89 |
+
|
90 |
+
|
91 |
+
"""
|
92 |
+
Compute the inception features for a batch of images
|
93 |
+
"""
|
94 |
+
|
95 |
+
|
96 |
+
def get_batch_features(batch, model, device):
|
97 |
+
with torch.no_grad():
|
98 |
+
feat = model(batch.to(device))
|
99 |
+
return feat.detach().cpu().numpy()
|
100 |
+
|
101 |
+
|
102 |
+
"""
|
103 |
+
Compute the inception features for a list of files
|
104 |
+
"""
|
105 |
+
|
106 |
+
|
107 |
+
def get_files_features(
|
108 |
+
l_files,
|
109 |
+
model=None,
|
110 |
+
num_workers=12,
|
111 |
+
batch_size=128,
|
112 |
+
device=torch.device("cuda"),
|
113 |
+
mode="clean",
|
114 |
+
custom_fn_resize=None,
|
115 |
+
description="",
|
116 |
+
fdir=None,
|
117 |
+
verbose=True,
|
118 |
+
custom_image_tranform=None,
|
119 |
+
):
|
120 |
+
# wrap the images in a dataloader for parallelizing the resize operation
|
121 |
+
dataset = ResizeDataset(l_files, fdir=fdir, mode=mode)
|
122 |
+
if custom_image_tranform is not None:
|
123 |
+
dataset.custom_image_tranform = custom_image_tranform
|
124 |
+
if custom_fn_resize is not None:
|
125 |
+
dataset.fn_resize = custom_fn_resize
|
126 |
+
|
127 |
+
dataloader = torch.utils.data.DataLoader(
|
128 |
+
dataset,
|
129 |
+
batch_size=batch_size,
|
130 |
+
shuffle=False,
|
131 |
+
drop_last=False,
|
132 |
+
num_workers=num_workers,
|
133 |
+
)
|
134 |
+
|
135 |
+
# collect all inception features
|
136 |
+
l_feats = []
|
137 |
+
if verbose:
|
138 |
+
pbar = tqdm(dataloader, desc=description)
|
139 |
+
else:
|
140 |
+
pbar = dataloader
|
141 |
+
|
142 |
+
for batch in pbar:
|
143 |
+
l_feats.append(get_batch_features(batch, model, device))
|
144 |
+
np_feats = np.concatenate(l_feats)
|
145 |
+
return np_feats
|
146 |
+
|
147 |
+
|
148 |
+
"""
|
149 |
+
Compute the inception features for a folder of image files
|
150 |
+
"""
|
151 |
+
|
152 |
+
|
153 |
+
def get_folder_features(
|
154 |
+
fdir,
|
155 |
+
model=None,
|
156 |
+
num_workers=12,
|
157 |
+
num=None,
|
158 |
+
shuffle=False,
|
159 |
+
seed=0,
|
160 |
+
batch_size=128,
|
161 |
+
device=torch.device("cuda"),
|
162 |
+
mode="clean",
|
163 |
+
custom_fn_resize=None,
|
164 |
+
description="",
|
165 |
+
verbose=True,
|
166 |
+
custom_image_tranform=None,
|
167 |
+
):
|
168 |
+
# get all relevant files in the dataset
|
169 |
+
if ".zip" in fdir:
|
170 |
+
files = list(set(zipfile.ZipFile(fdir).namelist()))
|
171 |
+
# remove the non-image files inside the zip
|
172 |
+
files = [x for x in files if os.path.splitext(x)[1].lower()[1:] in EXTENSIONS]
|
173 |
+
else:
|
174 |
+
files = sorted(
|
175 |
+
[
|
176 |
+
file
|
177 |
+
for ext in EXTENSIONS
|
178 |
+
for file in glob(os.path.join(fdir, f"**/*.{ext}"), recursive=True)
|
179 |
+
]
|
180 |
+
)
|
181 |
+
# use a subset number of files if needed
|
182 |
+
if num is not None:
|
183 |
+
if shuffle:
|
184 |
+
random.seed(seed)
|
185 |
+
random.shuffle(files)
|
186 |
+
files = files[:num]
|
187 |
+
np_feats = get_files_features(
|
188 |
+
files,
|
189 |
+
model,
|
190 |
+
num_workers=num_workers,
|
191 |
+
batch_size=batch_size,
|
192 |
+
device=device,
|
193 |
+
mode=mode,
|
194 |
+
custom_fn_resize=custom_fn_resize,
|
195 |
+
custom_image_tranform=custom_image_tranform,
|
196 |
+
description=description,
|
197 |
+
fdir=fdir,
|
198 |
+
verbose=verbose,
|
199 |
+
)
|
200 |
+
return np_feats
|
201 |
+
|
202 |
+
|
203 |
+
"""
|
204 |
+
Compute the FID score given the inception features stack
|
205 |
+
"""
|
206 |
+
|
207 |
+
|
208 |
+
def fid_from_feats(feats1, feats2):
|
209 |
+
mu1, sig1 = np.mean(feats1, axis=0), np.cov(feats1, rowvar=False)
|
210 |
+
mu2, sig2 = np.mean(feats2, axis=0), np.cov(feats2, rowvar=False)
|
211 |
+
return frechet_distance(mu1, sig1, mu2, sig2)
|
212 |
+
|
213 |
+
|
214 |
+
"""
|
215 |
+
Computes the FID score for a folder of images for a specific dataset
|
216 |
+
and a specific resolution
|
217 |
+
"""
|
218 |
+
|
219 |
+
|
220 |
+
def fid_folder(
|
221 |
+
fdir,
|
222 |
+
dataset_name,
|
223 |
+
dataset_res,
|
224 |
+
dataset_split,
|
225 |
+
model=None,
|
226 |
+
mode="clean",
|
227 |
+
model_name="inception_v3",
|
228 |
+
num_workers=12,
|
229 |
+
batch_size=128,
|
230 |
+
device=torch.device("cuda"),
|
231 |
+
verbose=True,
|
232 |
+
custom_image_tranform=None,
|
233 |
+
custom_fn_resize=None,
|
234 |
+
):
|
235 |
+
# Load reference FID statistics (download if needed)
|
236 |
+
ref_mu, ref_sigma = get_reference_statistics(
|
237 |
+
dataset_name,
|
238 |
+
dataset_res,
|
239 |
+
mode=mode,
|
240 |
+
model_name=model_name,
|
241 |
+
seed=0,
|
242 |
+
split=dataset_split,
|
243 |
+
)
|
244 |
+
fbname = os.path.basename(fdir)
|
245 |
+
# get all inception features for folder images
|
246 |
+
np_feats = get_folder_features(
|
247 |
+
fdir,
|
248 |
+
model,
|
249 |
+
num_workers=num_workers,
|
250 |
+
batch_size=batch_size,
|
251 |
+
device=device,
|
252 |
+
mode=mode,
|
253 |
+
description=f"FID {fbname} : ",
|
254 |
+
verbose=verbose,
|
255 |
+
custom_image_tranform=custom_image_tranform,
|
256 |
+
custom_fn_resize=custom_fn_resize,
|
257 |
+
)
|
258 |
+
mu = np.mean(np_feats, axis=0)
|
259 |
+
sigma = np.cov(np_feats, rowvar=False)
|
260 |
+
fid = frechet_distance(mu, sigma, ref_mu, ref_sigma)
|
261 |
+
return fid
|
262 |
+
|
263 |
+
|
264 |
+
"""
|
265 |
+
Compute the FID stats from a generator model
|
266 |
+
"""
|
267 |
+
|
268 |
+
|
269 |
+
def get_model_features(
|
270 |
+
G,
|
271 |
+
model,
|
272 |
+
mode="clean",
|
273 |
+
z_dim=512,
|
274 |
+
num_gen=50_000,
|
275 |
+
batch_size=128,
|
276 |
+
device=torch.device("cuda"),
|
277 |
+
desc="FID model: ",
|
278 |
+
verbose=True,
|
279 |
+
return_z=False,
|
280 |
+
custom_image_tranform=None,
|
281 |
+
custom_fn_resize=None,
|
282 |
+
):
|
283 |
+
if custom_fn_resize is None:
|
284 |
+
fn_resize = build_resizer(mode)
|
285 |
+
else:
|
286 |
+
fn_resize = custom_fn_resize
|
287 |
+
|
288 |
+
# Generate test features
|
289 |
+
num_iters = int(np.ceil(num_gen / batch_size))
|
290 |
+
l_feats = []
|
291 |
+
latents = []
|
292 |
+
if verbose:
|
293 |
+
pbar = tqdm(range(num_iters), desc=desc)
|
294 |
+
else:
|
295 |
+
pbar = range(num_iters)
|
296 |
+
for idx in pbar:
|
297 |
+
with torch.no_grad():
|
298 |
+
z_batch = torch.randn((batch_size, z_dim)).to(device)
|
299 |
+
if return_z:
|
300 |
+
latents.append(z_batch)
|
301 |
+
# generated image is in range [0,255]
|
302 |
+
img_batch = G(z_batch)
|
303 |
+
# split into individual batches for resizing if needed
|
304 |
+
if mode != "legacy_tensorflow":
|
305 |
+
l_resized_batch = []
|
306 |
+
for idx in range(batch_size):
|
307 |
+
curr_img = img_batch[idx]
|
308 |
+
img_np = curr_img.cpu().numpy().transpose((1, 2, 0))
|
309 |
+
if custom_image_tranform is not None:
|
310 |
+
img_np = custom_image_tranform(img_np)
|
311 |
+
img_resize = fn_resize(img_np)
|
312 |
+
l_resized_batch.append(
|
313 |
+
torch.tensor(img_resize.transpose((2, 0, 1))).unsqueeze(0)
|
314 |
+
)
|
315 |
+
resized_batch = torch.cat(l_resized_batch, dim=0)
|
316 |
+
else:
|
317 |
+
resized_batch = img_batch
|
318 |
+
feat = get_batch_features(resized_batch, model, device)
|
319 |
+
l_feats.append(feat)
|
320 |
+
np_feats = np.concatenate(l_feats)[:num_gen]
|
321 |
+
if return_z:
|
322 |
+
latents = torch.cat(latents, 0)
|
323 |
+
return np_feats, latents
|
324 |
+
return np_feats
|
325 |
+
|
326 |
+
|
327 |
+
"""
|
328 |
+
Computes the FID score for a generator model for a specific dataset
|
329 |
+
and a specific resolution
|
330 |
+
"""
|
331 |
+
|
332 |
+
|
333 |
+
def fid_model(
|
334 |
+
G,
|
335 |
+
dataset_name,
|
336 |
+
dataset_res,
|
337 |
+
dataset_split,
|
338 |
+
model=None,
|
339 |
+
model_name="inception_v3",
|
340 |
+
z_dim=512,
|
341 |
+
num_gen=50_000,
|
342 |
+
mode="clean",
|
343 |
+
num_workers=0,
|
344 |
+
batch_size=128,
|
345 |
+
device=torch.device("cuda"),
|
346 |
+
verbose=True,
|
347 |
+
custom_image_tranform=None,
|
348 |
+
custom_fn_resize=None,
|
349 |
+
):
|
350 |
+
# Load reference FID statistics (download if needed)
|
351 |
+
ref_mu, ref_sigma = get_reference_statistics(
|
352 |
+
dataset_name,
|
353 |
+
dataset_res,
|
354 |
+
mode=mode,
|
355 |
+
model_name=model_name,
|
356 |
+
seed=0,
|
357 |
+
split=dataset_split,
|
358 |
+
)
|
359 |
+
# Generate features of images generated by the model
|
360 |
+
np_feats = get_model_features(
|
361 |
+
G,
|
362 |
+
model,
|
363 |
+
mode=mode,
|
364 |
+
z_dim=z_dim,
|
365 |
+
num_gen=num_gen,
|
366 |
+
batch_size=batch_size,
|
367 |
+
device=device,
|
368 |
+
verbose=verbose,
|
369 |
+
custom_image_tranform=custom_image_tranform,
|
370 |
+
custom_fn_resize=custom_fn_resize,
|
371 |
+
)
|
372 |
+
mu = np.mean(np_feats, axis=0)
|
373 |
+
sigma = np.cov(np_feats, rowvar=False)
|
374 |
+
fid = frechet_distance(mu, sigma, ref_mu, ref_sigma)
|
375 |
+
return fid
|
376 |
+
|
377 |
+
|
378 |
+
"""
|
379 |
+
Computes the FID score between the two given folders
|
380 |
+
"""
|
381 |
+
|
382 |
+
|
383 |
+
def compare_folders(
|
384 |
+
fdir1,
|
385 |
+
fdir2,
|
386 |
+
feat_model,
|
387 |
+
mode,
|
388 |
+
num_workers=0,
|
389 |
+
batch_size=8,
|
390 |
+
device=torch.device("cuda"),
|
391 |
+
verbose=True,
|
392 |
+
custom_image_tranform=None,
|
393 |
+
custom_fn_resize=None,
|
394 |
+
):
|
395 |
+
# get all inception features for the first folder
|
396 |
+
fbname1 = os.path.basename(fdir1)
|
397 |
+
np_feats1 = get_folder_features(
|
398 |
+
fdir1,
|
399 |
+
feat_model,
|
400 |
+
num_workers=num_workers,
|
401 |
+
batch_size=batch_size,
|
402 |
+
device=device,
|
403 |
+
mode=mode,
|
404 |
+
description=f"FID {fbname1} : ",
|
405 |
+
verbose=verbose,
|
406 |
+
custom_image_tranform=custom_image_tranform,
|
407 |
+
custom_fn_resize=custom_fn_resize,
|
408 |
+
)
|
409 |
+
mu1 = np.mean(np_feats1, axis=0)
|
410 |
+
sigma1 = np.cov(np_feats1, rowvar=False)
|
411 |
+
# get all inception features for the second folder
|
412 |
+
fbname2 = os.path.basename(fdir2)
|
413 |
+
np_feats2 = get_folder_features(
|
414 |
+
fdir2,
|
415 |
+
feat_model,
|
416 |
+
num_workers=num_workers,
|
417 |
+
batch_size=batch_size,
|
418 |
+
device=device,
|
419 |
+
mode=mode,
|
420 |
+
description=f"FID {fbname2} : ",
|
421 |
+
verbose=verbose,
|
422 |
+
custom_image_tranform=custom_image_tranform,
|
423 |
+
custom_fn_resize=custom_fn_resize,
|
424 |
+
)
|
425 |
+
mu2 = np.mean(np_feats2, axis=0)
|
426 |
+
sigma2 = np.cov(np_feats2, rowvar=False)
|
427 |
+
fid = frechet_distance(mu1, sigma1, mu2, sigma2)
|
428 |
+
return fid
|
429 |
+
|
430 |
+
|
431 |
+
"""
|
432 |
+
Test if a custom statistic exists
|
433 |
+
"""
|
434 |
+
|
435 |
+
|
436 |
+
def test_stats_exists(name, mode, model_name="inception_v3", metric="FID"):
|
437 |
+
stats_folder = os.path.join(get_dir(), "fid_stats")
|
438 |
+
split, res = "custom", "na"
|
439 |
+
if model_name == "inception_v3":
|
440 |
+
model_modifier = ""
|
441 |
+
else:
|
442 |
+
model_modifier = "_" + model_name
|
443 |
+
if metric == "FID":
|
444 |
+
fname = f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
|
445 |
+
elif metric == "KID":
|
446 |
+
fname = f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
|
447 |
+
fpath = os.path.join(stats_folder, fname)
|
448 |
+
return os.path.exists(fpath)
|
449 |
+
|
450 |
+
|
451 |
+
"""
|
452 |
+
Remove the custom FID features from the stats folder
|
453 |
+
"""
|
454 |
+
|
455 |
+
|
456 |
+
def remove_custom_stats(name, mode="clean", model_name="inception_v3"):
|
457 |
+
stats_folder = os.path.join(get_dir(), "fid_stats")
|
458 |
+
# remove the FID stats
|
459 |
+
split, res = "custom", "na"
|
460 |
+
if model_name == "inception_v3":
|
461 |
+
model_modifier = ""
|
462 |
+
else:
|
463 |
+
model_modifier = "_" + model_name
|
464 |
+
outf = os.path.join(
|
465 |
+
stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
|
466 |
+
)
|
467 |
+
if not os.path.exists(outf):
|
468 |
+
msg = f"The stats file {name} does not exist."
|
469 |
+
raise Exception(msg)
|
470 |
+
os.remove(outf)
|
471 |
+
# remove the KID stats
|
472 |
+
outf = os.path.join(
|
473 |
+
stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
|
474 |
+
)
|
475 |
+
if not os.path.exists(outf):
|
476 |
+
msg = f"The stats file {name} does not exist."
|
477 |
+
raise Exception(msg)
|
478 |
+
os.remove(outf)
|
479 |
+
|
480 |
+
|
481 |
+
"""
|
482 |
+
Cache a custom dataset statistics file
|
483 |
+
"""
|
484 |
+
|
485 |
+
|
486 |
+
def make_custom_stats(
|
487 |
+
name,
|
488 |
+
fdir,
|
489 |
+
num=None,
|
490 |
+
mode="clean",
|
491 |
+
model_name="inception_v3",
|
492 |
+
num_workers=0,
|
493 |
+
batch_size=64,
|
494 |
+
device=torch.device("cuda"),
|
495 |
+
verbose=True,
|
496 |
+
):
|
497 |
+
stats_folder = os.path.join(get_dir(), "fid_stats")
|
498 |
+
os.makedirs(stats_folder, exist_ok=True)
|
499 |
+
split, res = "custom", "na"
|
500 |
+
if model_name == "inception_v3":
|
501 |
+
model_modifier = ""
|
502 |
+
else:
|
503 |
+
model_modifier = "_" + model_name
|
504 |
+
outf = os.path.join(
|
505 |
+
stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
|
506 |
+
)
|
507 |
+
# if the custom stat file already exists
|
508 |
+
if os.path.exists(outf):
|
509 |
+
msg = f"The statistics file {name} already exists. "
|
510 |
+
msg += "Use remove_custom_stats function to delete it first."
|
511 |
+
raise Exception(msg)
|
512 |
+
if model_name == "inception_v3":
|
513 |
+
feat_model = build_feature_extractor(mode, device)
|
514 |
+
custom_fn_resize = None
|
515 |
+
custom_image_tranform = None
|
516 |
+
elif model_name == "clip_vit_b_32":
|
517 |
+
from .clip_features import CLIP_fx, img_preprocess_clip
|
518 |
+
|
519 |
+
clip_fx = CLIP_fx("ViT-B/32")
|
520 |
+
feat_model = clip_fx
|
521 |
+
custom_fn_resize = img_preprocess_clip
|
522 |
+
custom_image_tranform = None
|
523 |
+
else:
|
524 |
+
raise ValueError(f"The entered model name - {model_name} was not recognized.")
|
525 |
+
|
526 |
+
# get all inception features for folder images
|
527 |
+
np_feats = get_folder_features(
|
528 |
+
fdir,
|
529 |
+
feat_model,
|
530 |
+
num_workers=num_workers,
|
531 |
+
num=num,
|
532 |
+
batch_size=batch_size,
|
533 |
+
device=device,
|
534 |
+
verbose=verbose,
|
535 |
+
mode=mode,
|
536 |
+
description=f"custom stats: {os.path.basename(fdir)} : ",
|
537 |
+
custom_image_tranform=custom_image_tranform,
|
538 |
+
custom_fn_resize=custom_fn_resize,
|
539 |
+
)
|
540 |
+
|
541 |
+
mu = np.mean(np_feats, axis=0)
|
542 |
+
sigma = np.cov(np_feats, rowvar=False)
|
543 |
+
# print(f"saving custom FID stats to {outf}")
|
544 |
+
np.savez_compressed(outf, mu=mu, sigma=sigma)
|
545 |
+
|
546 |
+
# KID stats
|
547 |
+
outf = os.path.join(
|
548 |
+
stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
|
549 |
+
)
|
550 |
+
# print(f"saving custom KID stats to {outf}")
|
551 |
+
np.savez_compressed(outf, feats=np_feats)
|
552 |
+
|
553 |
+
|
554 |
+
def compute_kid(
|
555 |
+
fdir1=None,
|
556 |
+
fdir2=None,
|
557 |
+
gen=None,
|
558 |
+
mode="clean",
|
559 |
+
num_workers=12,
|
560 |
+
batch_size=32,
|
561 |
+
device=torch.device("cuda"),
|
562 |
+
dataset_name="FFHQ",
|
563 |
+
dataset_res=1024,
|
564 |
+
dataset_split="train",
|
565 |
+
num_gen=50_000,
|
566 |
+
z_dim=512,
|
567 |
+
verbose=True,
|
568 |
+
use_dataparallel=True,
|
569 |
+
):
|
570 |
+
# build the feature extractor based on the mode
|
571 |
+
feat_model = build_feature_extractor(
|
572 |
+
mode, device, use_dataparallel=use_dataparallel
|
573 |
+
)
|
574 |
+
|
575 |
+
# if both dirs are specified, compute KID between folders
|
576 |
+
if fdir1 is not None and fdir2 is not None:
|
577 |
+
# get all inception features for the first folder
|
578 |
+
fbname1 = os.path.basename(fdir1)
|
579 |
+
np_feats1 = get_folder_features(
|
580 |
+
fdir1,
|
581 |
+
feat_model,
|
582 |
+
num_workers=num_workers,
|
583 |
+
batch_size=batch_size,
|
584 |
+
device=device,
|
585 |
+
mode=mode,
|
586 |
+
description=f"KID {fbname1} : ",
|
587 |
+
verbose=verbose,
|
588 |
+
)
|
589 |
+
# get all inception features for the second folder
|
590 |
+
fbname2 = os.path.basename(fdir2)
|
591 |
+
np_feats2 = get_folder_features(
|
592 |
+
fdir2,
|
593 |
+
feat_model,
|
594 |
+
num_workers=num_workers,
|
595 |
+
batch_size=batch_size,
|
596 |
+
device=device,
|
597 |
+
mode=mode,
|
598 |
+
description=f"KID {fbname2} : ",
|
599 |
+
verbose=verbose,
|
600 |
+
)
|
601 |
+
score = kernel_distance(np_feats1, np_feats2)
|
602 |
+
return score
|
603 |
+
|
604 |
+
# compute kid of a folder
|
605 |
+
elif fdir1 is not None and fdir2 is None:
|
606 |
+
if verbose:
|
607 |
+
print(f"compute KID of a folder with {dataset_name} statistics")
|
608 |
+
ref_feats = get_reference_statistics(
|
609 |
+
dataset_name,
|
610 |
+
dataset_res,
|
611 |
+
mode=mode,
|
612 |
+
seed=0,
|
613 |
+
split=dataset_split,
|
614 |
+
metric="KID",
|
615 |
+
)
|
616 |
+
fbname = os.path.basename(fdir1)
|
617 |
+
# get all inception features for folder images
|
618 |
+
np_feats = get_folder_features(
|
619 |
+
fdir1,
|
620 |
+
feat_model,
|
621 |
+
num_workers=num_workers,
|
622 |
+
batch_size=batch_size,
|
623 |
+
device=device,
|
624 |
+
mode=mode,
|
625 |
+
description=f"KID {fbname} : ",
|
626 |
+
verbose=verbose,
|
627 |
+
)
|
628 |
+
score = kernel_distance(ref_feats, np_feats)
|
629 |
+
return score
|
630 |
+
|
631 |
+
# compute kid for a generator, using images in fdir2
|
632 |
+
elif gen is not None and fdir2 is not None:
|
633 |
+
if verbose:
|
634 |
+
print(f"compute KID of a model, using references in fdir2")
|
635 |
+
# get all inception features for the second folder
|
636 |
+
fbname2 = os.path.basename(fdir2)
|
637 |
+
ref_feats = get_folder_features(
|
638 |
+
fdir2,
|
639 |
+
feat_model,
|
640 |
+
num_workers=num_workers,
|
641 |
+
batch_size=batch_size,
|
642 |
+
device=device,
|
643 |
+
mode=mode,
|
644 |
+
description=f"KID {fbname2} : ",
|
645 |
+
)
|
646 |
+
# Generate test features
|
647 |
+
np_feats = get_model_features(
|
648 |
+
gen,
|
649 |
+
feat_model,
|
650 |
+
mode=mode,
|
651 |
+
z_dim=z_dim,
|
652 |
+
num_gen=num_gen,
|
653 |
+
desc="KID model: ",
|
654 |
+
batch_size=batch_size,
|
655 |
+
device=device,
|
656 |
+
)
|
657 |
+
score = kernel_distance(ref_feats, np_feats)
|
658 |
+
return score
|
659 |
+
|
660 |
+
# compute fid for a generator, using reference statistics
|
661 |
+
elif gen is not None:
|
662 |
+
if verbose:
|
663 |
+
print(
|
664 |
+
f"compute KID of a model with {dataset_name}-{dataset_res} statistics"
|
665 |
+
)
|
666 |
+
ref_feats = get_reference_statistics(
|
667 |
+
dataset_name,
|
668 |
+
dataset_res,
|
669 |
+
mode=mode,
|
670 |
+
seed=0,
|
671 |
+
split=dataset_split,
|
672 |
+
metric="KID",
|
673 |
+
)
|
674 |
+
# Generate test features
|
675 |
+
np_feats = get_model_features(
|
676 |
+
gen,
|
677 |
+
feat_model,
|
678 |
+
mode=mode,
|
679 |
+
z_dim=z_dim,
|
680 |
+
num_gen=num_gen,
|
681 |
+
desc="KID model: ",
|
682 |
+
batch_size=batch_size,
|
683 |
+
device=device,
|
684 |
+
verbose=verbose,
|
685 |
+
)
|
686 |
+
score = kernel_distance(ref_feats, np_feats)
|
687 |
+
return score
|
688 |
+
|
689 |
+
else:
|
690 |
+
raise ValueError("invalid combination of directories and models entered")
|
691 |
+
|
692 |
+
|
693 |
+
"""
|
694 |
+
custom_image_tranform:
|
695 |
+
function that takes an np_array image as input [0,255] and
|
696 |
+
applies a custom transform such as cropping
|
697 |
+
"""
|
698 |
+
|
699 |
+
|
700 |
+
def compute_fid(
|
701 |
+
fdir1=None,
|
702 |
+
fdir2=None,
|
703 |
+
gen=None,
|
704 |
+
mode="clean",
|
705 |
+
model_name="inception_v3",
|
706 |
+
num_workers=12,
|
707 |
+
batch_size=32,
|
708 |
+
device=torch.device("cuda"),
|
709 |
+
dataset_name="FFHQ",
|
710 |
+
dataset_res=1024,
|
711 |
+
dataset_split="train",
|
712 |
+
num_gen=50_000,
|
713 |
+
z_dim=512,
|
714 |
+
custom_feat_extractor=None,
|
715 |
+
verbose=True,
|
716 |
+
custom_image_tranform=None,
|
717 |
+
custom_fn_resize=None,
|
718 |
+
use_dataparallel=True,
|
719 |
+
):
|
720 |
+
# build the feature extractor based on the mode and the model to be used
|
721 |
+
if custom_feat_extractor is None and model_name == "inception_v3":
|
722 |
+
feat_model = build_feature_extractor(
|
723 |
+
mode, device, use_dataparallel=use_dataparallel
|
724 |
+
)
|
725 |
+
elif custom_feat_extractor is None and model_name == "clip_vit_b_32":
|
726 |
+
from .clip_features import CLIP_fx, img_preprocess_clip
|
727 |
+
|
728 |
+
clip_fx = CLIP_fx("ViT-B/32", device=device)
|
729 |
+
feat_model = clip_fx
|
730 |
+
custom_fn_resize = img_preprocess_clip
|
731 |
+
else:
|
732 |
+
feat_model = custom_feat_extractor
|
733 |
+
|
734 |
+
# if both dirs are specified, compute FID between folders
|
735 |
+
if fdir1 is not None and fdir2 is not None:
|
736 |
+
score = compare_folders(
|
737 |
+
fdir1,
|
738 |
+
fdir2,
|
739 |
+
feat_model,
|
740 |
+
mode=mode,
|
741 |
+
batch_size=batch_size,
|
742 |
+
num_workers=num_workers,
|
743 |
+
device=device,
|
744 |
+
custom_image_tranform=custom_image_tranform,
|
745 |
+
custom_fn_resize=custom_fn_resize,
|
746 |
+
verbose=verbose,
|
747 |
+
)
|
748 |
+
return score
|
749 |
+
|
750 |
+
# compute fid of a folder
|
751 |
+
elif fdir1 is not None and fdir2 is None:
|
752 |
+
if verbose:
|
753 |
+
print(f"compute FID of a folder with {dataset_name} statistics")
|
754 |
+
score = fid_folder(
|
755 |
+
fdir1,
|
756 |
+
dataset_name,
|
757 |
+
dataset_res,
|
758 |
+
dataset_split,
|
759 |
+
model=feat_model,
|
760 |
+
mode=mode,
|
761 |
+
model_name=model_name,
|
762 |
+
custom_fn_resize=custom_fn_resize,
|
763 |
+
custom_image_tranform=custom_image_tranform,
|
764 |
+
num_workers=num_workers,
|
765 |
+
batch_size=batch_size,
|
766 |
+
device=device,
|
767 |
+
verbose=verbose,
|
768 |
+
)
|
769 |
+
return score
|
770 |
+
|
771 |
+
# compute fid for a generator, using images in fdir2
|
772 |
+
elif gen is not None and fdir2 is not None:
|
773 |
+
if verbose:
|
774 |
+
print(f"compute FID of a model, using references in fdir2")
|
775 |
+
# get all inception features for the second folder
|
776 |
+
fbname2 = os.path.basename(fdir2)
|
777 |
+
np_feats2 = get_folder_features(
|
778 |
+
fdir2,
|
779 |
+
feat_model,
|
780 |
+
num_workers=num_workers,
|
781 |
+
batch_size=batch_size,
|
782 |
+
device=device,
|
783 |
+
mode=mode,
|
784 |
+
description=f"FID {fbname2} : ",
|
785 |
+
verbose=verbose,
|
786 |
+
custom_fn_resize=custom_fn_resize,
|
787 |
+
custom_image_tranform=custom_image_tranform,
|
788 |
+
)
|
789 |
+
mu2 = np.mean(np_feats2, axis=0)
|
790 |
+
sigma2 = np.cov(np_feats2, rowvar=False)
|
791 |
+
# Generate test features
|
792 |
+
np_feats = get_model_features(
|
793 |
+
gen,
|
794 |
+
feat_model,
|
795 |
+
mode=mode,
|
796 |
+
z_dim=z_dim,
|
797 |
+
num_gen=num_gen,
|
798 |
+
custom_fn_resize=custom_fn_resize,
|
799 |
+
custom_image_tranform=custom_image_tranform,
|
800 |
+
batch_size=batch_size,
|
801 |
+
device=device,
|
802 |
+
verbose=verbose,
|
803 |
+
)
|
804 |
+
|
805 |
+
mu = np.mean(np_feats, axis=0)
|
806 |
+
sigma = np.cov(np_feats, rowvar=False)
|
807 |
+
fid = frechet_distance(mu, sigma, mu2, sigma2)
|
808 |
+
return fid
|
809 |
+
|
810 |
+
# compute fid for a generator, using reference statistics
|
811 |
+
elif gen is not None:
|
812 |
+
if verbose:
|
813 |
+
print(
|
814 |
+
f"compute FID of a model with {dataset_name}-{dataset_res} statistics"
|
815 |
+
)
|
816 |
+
score = fid_model(
|
817 |
+
gen,
|
818 |
+
dataset_name,
|
819 |
+
dataset_res,
|
820 |
+
dataset_split,
|
821 |
+
model=feat_model,
|
822 |
+
model_name=model_name,
|
823 |
+
z_dim=z_dim,
|
824 |
+
num_gen=num_gen,
|
825 |
+
mode=mode,
|
826 |
+
num_workers=num_workers,
|
827 |
+
batch_size=batch_size,
|
828 |
+
custom_image_tranform=custom_image_tranform,
|
829 |
+
custom_fn_resize=custom_fn_resize,
|
830 |
+
device=device,
|
831 |
+
verbose=verbose,
|
832 |
+
)
|
833 |
+
return score
|
834 |
+
|
835 |
+
else:
|
836 |
+
raise ValueError("invalid combination of directories and models entered")
|
kit/metrics/clean_fid/inception_pytorch.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
File from: https://github.com/mseitzer/pytorch-fid
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision
|
10 |
+
import warnings
|
11 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
12 |
+
|
13 |
+
# Inception weights ported to Pytorch from
|
14 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
15 |
+
FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
|
16 |
+
|
17 |
+
|
18 |
+
class InceptionV3(nn.Module):
|
19 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
20 |
+
|
21 |
+
# Index of default block of inception to return,
|
22 |
+
# corresponds to output of final average pooling
|
23 |
+
DEFAULT_BLOCK_INDEX = 3
|
24 |
+
|
25 |
+
# Maps feature dimensionality to their output blocks indices
|
26 |
+
BLOCK_INDEX_BY_DIM = {
|
27 |
+
64: 0, # First max pooling features
|
28 |
+
192: 1, # Second max pooling featurs
|
29 |
+
768: 2, # Pre-aux classifier features
|
30 |
+
2048: 3, # Final average pooling features
|
31 |
+
}
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
36 |
+
resize_input=True,
|
37 |
+
normalize_input=True,
|
38 |
+
requires_grad=False,
|
39 |
+
use_fid_inception=True,
|
40 |
+
):
|
41 |
+
"""Build pretrained InceptionV3
|
42 |
+
Parameters
|
43 |
+
----------
|
44 |
+
output_blocks : list of int
|
45 |
+
Indices of blocks to return features of. Possible values are:
|
46 |
+
- 0: corresponds to output of first max pooling
|
47 |
+
- 1: corresponds to output of second max pooling
|
48 |
+
- 2: corresponds to output which is fed to aux classifier
|
49 |
+
- 3: corresponds to output of final average pooling
|
50 |
+
resize_input : bool
|
51 |
+
If true, bilinearly resizes input to width and height 299 before
|
52 |
+
feeding input to model. As the network without fully connected
|
53 |
+
layers is fully convolutional, it should be able to handle inputs
|
54 |
+
of arbitrary size, so resizing might not be strictly needed
|
55 |
+
normalize_input : bool
|
56 |
+
If true, scales the input from range (0, 1) to the range the
|
57 |
+
pretrained Inception network expects, namely (-1, 1)
|
58 |
+
requires_grad : bool
|
59 |
+
If true, parameters of the model require gradients. Possibly useful
|
60 |
+
for finetuning the network
|
61 |
+
use_fid_inception : bool
|
62 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
63 |
+
FID implementation. If false, uses the pretrained Inception model
|
64 |
+
available in torchvision. The FID Inception model has different
|
65 |
+
weights and a slightly different structure from torchvision's
|
66 |
+
Inception model. If you want to compute FID scores, you are
|
67 |
+
strongly advised to set this parameter to true to get comparable
|
68 |
+
results.
|
69 |
+
"""
|
70 |
+
super(InceptionV3, self).__init__()
|
71 |
+
|
72 |
+
self.resize_input = resize_input
|
73 |
+
self.normalize_input = normalize_input
|
74 |
+
self.output_blocks = sorted(output_blocks)
|
75 |
+
self.last_needed_block = max(output_blocks)
|
76 |
+
|
77 |
+
assert self.last_needed_block <= 3, "Last possible output block index is 3"
|
78 |
+
|
79 |
+
self.blocks = nn.ModuleList()
|
80 |
+
|
81 |
+
if use_fid_inception:
|
82 |
+
inception = fid_inception_v3()
|
83 |
+
else:
|
84 |
+
inception = _inception_v3(pretrained=True)
|
85 |
+
|
86 |
+
# Block 0: input to maxpool1
|
87 |
+
block0 = [
|
88 |
+
inception.Conv2d_1a_3x3,
|
89 |
+
inception.Conv2d_2a_3x3,
|
90 |
+
inception.Conv2d_2b_3x3,
|
91 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
92 |
+
]
|
93 |
+
self.blocks.append(nn.Sequential(*block0))
|
94 |
+
|
95 |
+
# Block 1: maxpool1 to maxpool2
|
96 |
+
if self.last_needed_block >= 1:
|
97 |
+
block1 = [
|
98 |
+
inception.Conv2d_3b_1x1,
|
99 |
+
inception.Conv2d_4a_3x3,
|
100 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
101 |
+
]
|
102 |
+
self.blocks.append(nn.Sequential(*block1))
|
103 |
+
|
104 |
+
# Block 2: maxpool2 to aux classifier
|
105 |
+
if self.last_needed_block >= 2:
|
106 |
+
block2 = [
|
107 |
+
inception.Mixed_5b,
|
108 |
+
inception.Mixed_5c,
|
109 |
+
inception.Mixed_5d,
|
110 |
+
inception.Mixed_6a,
|
111 |
+
inception.Mixed_6b,
|
112 |
+
inception.Mixed_6c,
|
113 |
+
inception.Mixed_6d,
|
114 |
+
inception.Mixed_6e,
|
115 |
+
]
|
116 |
+
self.blocks.append(nn.Sequential(*block2))
|
117 |
+
|
118 |
+
# Block 3: aux classifier to final avgpool
|
119 |
+
if self.last_needed_block >= 3:
|
120 |
+
block3 = [
|
121 |
+
inception.Mixed_7a,
|
122 |
+
inception.Mixed_7b,
|
123 |
+
inception.Mixed_7c,
|
124 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
|
125 |
+
]
|
126 |
+
self.blocks.append(nn.Sequential(*block3))
|
127 |
+
|
128 |
+
for param in self.parameters():
|
129 |
+
param.requires_grad = requires_grad
|
130 |
+
|
131 |
+
def forward(self, inp):
|
132 |
+
"""Get Inception feature maps
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
inp : torch.autograd.Variable
|
136 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
137 |
+
range (0, 1)
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
141 |
+
block, sorted ascending by index
|
142 |
+
"""
|
143 |
+
outp = []
|
144 |
+
x = inp
|
145 |
+
|
146 |
+
if self.resize_input:
|
147 |
+
raise ValueError("should not resize here")
|
148 |
+
x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
|
149 |
+
|
150 |
+
if self.normalize_input:
|
151 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
152 |
+
|
153 |
+
for idx, block in enumerate(self.blocks):
|
154 |
+
x = block(x)
|
155 |
+
if idx in self.output_blocks:
|
156 |
+
outp.append(x)
|
157 |
+
|
158 |
+
if idx == self.last_needed_block:
|
159 |
+
break
|
160 |
+
|
161 |
+
return outp
|
162 |
+
|
163 |
+
|
164 |
+
def _inception_v3(*args, **kwargs):
|
165 |
+
"""Wraps `torchvision.models.inception_v3`
|
166 |
+
Skips default weight inititialization if supported by torchvision version.
|
167 |
+
See https://github.com/mseitzer/pytorch-fid/issues/28.
|
168 |
+
"""
|
169 |
+
warnings.filterwarnings("ignore")
|
170 |
+
try:
|
171 |
+
version = tuple(map(int, torchvision.__version__.split(".")[:2]))
|
172 |
+
except ValueError:
|
173 |
+
# Just a caution against weird version strings
|
174 |
+
version = (0,)
|
175 |
+
|
176 |
+
if version >= (0, 6):
|
177 |
+
kwargs["init_weights"] = False
|
178 |
+
|
179 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
180 |
+
|
181 |
+
|
182 |
+
def fid_inception_v3():
|
183 |
+
"""Build pretrained Inception model for FID computation
|
184 |
+
The Inception model for FID computation uses a different set of weights
|
185 |
+
and has a slightly different structure than torchvision's Inception.
|
186 |
+
This method first constructs torchvision's Inception and then patches the
|
187 |
+
necessary parts that are different in the FID Inception model.
|
188 |
+
"""
|
189 |
+
inception = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
|
190 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
191 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
192 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
193 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
194 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
195 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
196 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
197 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
198 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
199 |
+
|
200 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
|
201 |
+
inception.load_state_dict(state_dict)
|
202 |
+
return inception
|
203 |
+
|
204 |
+
|
205 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
206 |
+
"""InceptionA block patched for FID computation"""
|
207 |
+
|
208 |
+
def __init__(self, in_channels, pool_features):
|
209 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
branch1x1 = self.branch1x1(x)
|
213 |
+
|
214 |
+
branch5x5 = self.branch5x5_1(x)
|
215 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
216 |
+
|
217 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
218 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
219 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
220 |
+
|
221 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
222 |
+
# its average calculation
|
223 |
+
branch_pool = F.avg_pool2d(
|
224 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
225 |
+
)
|
226 |
+
branch_pool = self.branch_pool(branch_pool)
|
227 |
+
|
228 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
229 |
+
return torch.cat(outputs, 1)
|
230 |
+
|
231 |
+
|
232 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
233 |
+
"""InceptionC block patched for FID computation"""
|
234 |
+
|
235 |
+
def __init__(self, in_channels, channels_7x7):
|
236 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
branch1x1 = self.branch1x1(x)
|
240 |
+
|
241 |
+
branch7x7 = self.branch7x7_1(x)
|
242 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
243 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
244 |
+
|
245 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
246 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
247 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
248 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
249 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
250 |
+
|
251 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
252 |
+
# its average calculation
|
253 |
+
branch_pool = F.avg_pool2d(
|
254 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
255 |
+
)
|
256 |
+
branch_pool = self.branch_pool(branch_pool)
|
257 |
+
|
258 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
259 |
+
return torch.cat(outputs, 1)
|
260 |
+
|
261 |
+
|
262 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
263 |
+
"""First InceptionE block patched for FID computation"""
|
264 |
+
|
265 |
+
def __init__(self, in_channels):
|
266 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
branch1x1 = self.branch1x1(x)
|
270 |
+
|
271 |
+
branch3x3 = self.branch3x3_1(x)
|
272 |
+
branch3x3 = [
|
273 |
+
self.branch3x3_2a(branch3x3),
|
274 |
+
self.branch3x3_2b(branch3x3),
|
275 |
+
]
|
276 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
277 |
+
|
278 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
279 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
280 |
+
branch3x3dbl = [
|
281 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
282 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
283 |
+
]
|
284 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
285 |
+
|
286 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
287 |
+
# its average calculation
|
288 |
+
branch_pool = F.avg_pool2d(
|
289 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
290 |
+
)
|
291 |
+
branch_pool = self.branch_pool(branch_pool)
|
292 |
+
|
293 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
294 |
+
return torch.cat(outputs, 1)
|
295 |
+
|
296 |
+
|
297 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
298 |
+
"""Second InceptionE block patched for FID computation"""
|
299 |
+
|
300 |
+
def __init__(self, in_channels):
|
301 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
branch1x1 = self.branch1x1(x)
|
305 |
+
|
306 |
+
branch3x3 = self.branch3x3_1(x)
|
307 |
+
branch3x3 = [
|
308 |
+
self.branch3x3_2a(branch3x3),
|
309 |
+
self.branch3x3_2b(branch3x3),
|
310 |
+
]
|
311 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
312 |
+
|
313 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
314 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
315 |
+
branch3x3dbl = [
|
316 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
317 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
318 |
+
]
|
319 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
320 |
+
|
321 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
322 |
+
# pooling. This is likely an error in this specific Inception
|
323 |
+
# implementation, as other Inception models use average pooling here
|
324 |
+
# (which matches the description in the paper).
|
325 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
326 |
+
branch_pool = self.branch_pool(branch_pool)
|
327 |
+
|
328 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
329 |
+
return torch.cat(outputs, 1)
|
kit/metrics/clean_fid/inception_torchscript.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import contextlib
|
5 |
+
from .downloads_helper import *
|
6 |
+
|
7 |
+
|
8 |
+
@contextlib.contextmanager
|
9 |
+
def disable_gpu_fuser_on_pt19():
|
10 |
+
# On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See
|
11 |
+
# https://github.com/GaParmar/clean-fid/issues/5
|
12 |
+
# https://github.com/pytorch/pytorch/issues/64062
|
13 |
+
if torch.__version__.startswith("1.9."):
|
14 |
+
old_val = torch._C._jit_can_fuse_on_gpu()
|
15 |
+
torch._C._jit_override_can_fuse_on_gpu(False)
|
16 |
+
yield
|
17 |
+
if torch.__version__.startswith("1.9."):
|
18 |
+
torch._C._jit_override_can_fuse_on_gpu(old_val)
|
19 |
+
|
20 |
+
|
21 |
+
class InceptionV3W(nn.Module):
|
22 |
+
"""
|
23 |
+
Wrapper around Inception V3 torchscript model provided here
|
24 |
+
https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt
|
25 |
+
|
26 |
+
path: locally saved inception weights
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, path, download=True, resize_inside=False):
|
30 |
+
super(InceptionV3W, self).__init__()
|
31 |
+
# download the network if it is not present at the given directory
|
32 |
+
# use the current directory by default
|
33 |
+
if download:
|
34 |
+
check_download_inception(fpath=path)
|
35 |
+
path = os.path.join(path, "inception-2015-12-05.pt")
|
36 |
+
self.base = torch.jit.load(path).eval()
|
37 |
+
self.layers = self.base.layers
|
38 |
+
self.resize_inside = resize_inside
|
39 |
+
|
40 |
+
"""
|
41 |
+
Get the inception features without resizing
|
42 |
+
x: Image with values in range [0,255]
|
43 |
+
"""
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
with disable_gpu_fuser_on_pt19():
|
47 |
+
bs = x.shape[0]
|
48 |
+
if self.resize_inside:
|
49 |
+
features = self.base(x, return_features=True).view((bs, 2048))
|
50 |
+
else:
|
51 |
+
# make sure it is resized already
|
52 |
+
assert (x.shape[2] == 299) and (x.shape[3] == 299)
|
53 |
+
# apply normalization
|
54 |
+
x1 = x - 128
|
55 |
+
x2 = x1 / 128
|
56 |
+
features = self.layers.forward(
|
57 |
+
x2,
|
58 |
+
).view((bs, 2048))
|
59 |
+
return features
|
kit/metrics/clean_fid/leaderboard.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import csv
|
3 |
+
import shutil
|
4 |
+
import urllib.request
|
5 |
+
|
6 |
+
|
7 |
+
def get_score(
|
8 |
+
model_name=None,
|
9 |
+
dataset_name=None,
|
10 |
+
dataset_res=None,
|
11 |
+
dataset_split=None,
|
12 |
+
task_name=None,
|
13 |
+
):
|
14 |
+
# download the csv file from server
|
15 |
+
url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv"
|
16 |
+
local_path = "/tmp/leaderboard.csv"
|
17 |
+
with urllib.request.urlopen(url) as response, open(local_path, "wb") as f:
|
18 |
+
shutil.copyfileobj(response, f)
|
19 |
+
|
20 |
+
d_field2idx = {}
|
21 |
+
l_matches = []
|
22 |
+
with open(local_path, "r") as f:
|
23 |
+
csvreader = csv.reader(f)
|
24 |
+
l_fields = next(csvreader)
|
25 |
+
for idx, val in enumerate(l_fields):
|
26 |
+
d_field2idx[val.strip()] = idx
|
27 |
+
# iterate through all rows
|
28 |
+
for row in csvreader:
|
29 |
+
# skip empty rows
|
30 |
+
if len(row) == 0:
|
31 |
+
continue
|
32 |
+
# skip if the filter doesn't match
|
33 |
+
if model_name is not None and (
|
34 |
+
row[d_field2idx["model_name"]].strip() != model_name
|
35 |
+
):
|
36 |
+
continue
|
37 |
+
if dataset_name is not None and (
|
38 |
+
row[d_field2idx["dataset_name"]].strip() != dataset_name
|
39 |
+
):
|
40 |
+
continue
|
41 |
+
if dataset_res is not None and (
|
42 |
+
row[d_field2idx["dataset_res"]].strip() != dataset_res
|
43 |
+
):
|
44 |
+
continue
|
45 |
+
if dataset_split is not None and (
|
46 |
+
row[d_field2idx["dataset_split"]].strip() != dataset_split
|
47 |
+
):
|
48 |
+
continue
|
49 |
+
if task_name is not None and (
|
50 |
+
row[d_field2idx["task_name"]].strip() != task_name
|
51 |
+
):
|
52 |
+
continue
|
53 |
+
curr = {}
|
54 |
+
for f in l_fields:
|
55 |
+
curr[f.strip()] = row[d_field2idx[f.strip()]].strip()
|
56 |
+
l_matches.append(curr)
|
57 |
+
os.remove(local_path)
|
58 |
+
return l_matches
|
kit/metrics/clean_fid/resize.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helpers for resizing with multiple CPU cores
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def build_resizer(mode):
|
12 |
+
if mode == "clean":
|
13 |
+
return make_resizer("PIL", False, "bicubic", (299, 299))
|
14 |
+
# if using legacy tensorflow, do not manually resize outside the network
|
15 |
+
elif mode == "legacy_tensorflow":
|
16 |
+
return lambda x: x
|
17 |
+
elif mode == "legacy_pytorch":
|
18 |
+
return make_resizer("PyTorch", False, "bilinear", (299, 299))
|
19 |
+
else:
|
20 |
+
raise ValueError(f"Invalid mode {mode} specified")
|
21 |
+
|
22 |
+
|
23 |
+
"""
|
24 |
+
Construct a function that resizes a numpy image based on the
|
25 |
+
flags passed in.
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
def make_resizer(library, quantize_after, filter, output_size):
|
30 |
+
if library == "PIL" and quantize_after:
|
31 |
+
name_to_filter = {
|
32 |
+
"bicubic": Image.BICUBIC,
|
33 |
+
"bilinear": Image.BILINEAR,
|
34 |
+
"nearest": Image.NEAREST,
|
35 |
+
"lanczos": Image.LANCZOS,
|
36 |
+
"box": Image.BOX,
|
37 |
+
}
|
38 |
+
|
39 |
+
def func(x):
|
40 |
+
x = Image.fromarray(x)
|
41 |
+
x = x.resize(output_size, resample=name_to_filter[filter])
|
42 |
+
x = np.asarray(x).clip(0, 255).astype(np.uint8)
|
43 |
+
return x
|
44 |
+
|
45 |
+
elif library == "PIL" and not quantize_after:
|
46 |
+
name_to_filter = {
|
47 |
+
"bicubic": Image.BICUBIC,
|
48 |
+
"bilinear": Image.BILINEAR,
|
49 |
+
"nearest": Image.NEAREST,
|
50 |
+
"lanczos": Image.LANCZOS,
|
51 |
+
"box": Image.BOX,
|
52 |
+
}
|
53 |
+
s1, s2 = output_size
|
54 |
+
|
55 |
+
def resize_single_channel(x_np):
|
56 |
+
img = Image.fromarray(x_np.astype(np.float32), mode="F")
|
57 |
+
img = img.resize(output_size, resample=name_to_filter[filter])
|
58 |
+
return np.asarray(img).clip(0, 255).reshape(s2, s1, 1)
|
59 |
+
|
60 |
+
def func(x):
|
61 |
+
x = [resize_single_channel(x[:, :, idx]) for idx in range(3)]
|
62 |
+
x = np.concatenate(x, axis=2).astype(np.float32)
|
63 |
+
return x
|
64 |
+
|
65 |
+
elif library == "PyTorch":
|
66 |
+
import warnings
|
67 |
+
|
68 |
+
# ignore the numpy warnings
|
69 |
+
warnings.filterwarnings("ignore")
|
70 |
+
|
71 |
+
def func(x):
|
72 |
+
x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...]
|
73 |
+
x = F.interpolate(x, size=output_size, mode=filter, align_corners=False)
|
74 |
+
x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255)
|
75 |
+
if quantize_after:
|
76 |
+
x = x.astype(np.uint8)
|
77 |
+
return x
|
78 |
+
|
79 |
+
else:
|
80 |
+
raise NotImplementedError("library [%s] is not include" % library)
|
81 |
+
return func
|
82 |
+
|
83 |
+
|
84 |
+
class FolderResizer(torch.utils.data.Dataset):
|
85 |
+
def __init__(self, files, outpath, fn_resize, output_ext=".png"):
|
86 |
+
self.files = files
|
87 |
+
self.outpath = outpath
|
88 |
+
self.output_ext = output_ext
|
89 |
+
self.fn_resize = fn_resize
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
return len(self.files)
|
93 |
+
|
94 |
+
def __getitem__(self, i):
|
95 |
+
path = str(self.files[i])
|
96 |
+
img_np = np.asarray(Image.open(path))
|
97 |
+
img_resize_np = self.fn_resize(img_np)
|
98 |
+
# swap the output extension
|
99 |
+
basename = os.path.basename(path).split(".")[0] + self.output_ext
|
100 |
+
outname = os.path.join(self.outpath, basename)
|
101 |
+
if self.output_ext == ".npy":
|
102 |
+
np.save(outname, img_resize_np)
|
103 |
+
elif self.output_ext == ".png":
|
104 |
+
img_resized_pil = Image.fromarray(img_resize_np)
|
105 |
+
img_resized_pil.save(outname)
|
106 |
+
else:
|
107 |
+
raise ValueError("invalid output extension")
|
108 |
+
return 0
|
kit/metrics/clean_fid/utils.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
+
from PIL import Image
|
5 |
+
import zipfile
|
6 |
+
from .resize import build_resizer
|
7 |
+
|
8 |
+
|
9 |
+
class ResizeDataset(torch.utils.data.Dataset):
|
10 |
+
"""
|
11 |
+
A placeholder Dataset that enables parallelizing the resize operation
|
12 |
+
using multiple CPU cores
|
13 |
+
|
14 |
+
files: list of all files in the folder
|
15 |
+
fn_resize: function that takes an np_array as input [0,255]
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, files, mode, size=(299, 299), fdir=None):
|
19 |
+
self.files = files
|
20 |
+
self.fdir = fdir
|
21 |
+
self.transforms = torchvision.transforms.ToTensor()
|
22 |
+
self.size = size
|
23 |
+
self.fn_resize = build_resizer(mode)
|
24 |
+
self.custom_image_tranform = lambda x: x
|
25 |
+
self._zipfile = None
|
26 |
+
|
27 |
+
def _get_zipfile(self):
|
28 |
+
assert self.fdir is not None and ".zip" in self.fdir
|
29 |
+
if self._zipfile is None:
|
30 |
+
self._zipfile = zipfile.ZipFile(self.fdir)
|
31 |
+
return self._zipfile
|
32 |
+
|
33 |
+
def __len__(self):
|
34 |
+
return len(self.files)
|
35 |
+
|
36 |
+
def __getitem__(self, i):
|
37 |
+
path = str(self.files[i])
|
38 |
+
if self.fdir is not None and ".zip" in self.fdir:
|
39 |
+
with self._get_zipfile().open(path, "r") as f:
|
40 |
+
img_np = np.array(Image.open(f).convert("RGB"))
|
41 |
+
elif ".npy" in path:
|
42 |
+
img_np = np.load(path)
|
43 |
+
else:
|
44 |
+
img_pil = Image.open(path).convert("RGB")
|
45 |
+
img_np = np.array(img_pil)
|
46 |
+
|
47 |
+
# apply a custom image transform before resizing the image to 299x299
|
48 |
+
img_np = self.custom_image_tranform(img_np)
|
49 |
+
# fn_resize expects a np array and returns a np array
|
50 |
+
img_resized = self.fn_resize(img_np)
|
51 |
+
|
52 |
+
# ToTensor() converts to [0,1] only if input in uint8
|
53 |
+
if img_resized.dtype == "uint8":
|
54 |
+
img_t = self.transforms(np.array(img_resized)) * 255
|
55 |
+
elif img_resized.dtype == "float32":
|
56 |
+
img_t = self.transforms(img_resized)
|
57 |
+
|
58 |
+
return img_t
|
59 |
+
|
60 |
+
|
61 |
+
EXTENSIONS = {
|
62 |
+
"bmp",
|
63 |
+
"jpg",
|
64 |
+
"jpeg",
|
65 |
+
"pgm",
|
66 |
+
"png",
|
67 |
+
"ppm",
|
68 |
+
"tif",
|
69 |
+
"tiff",
|
70 |
+
"webp",
|
71 |
+
"npy",
|
72 |
+
"JPEG",
|
73 |
+
"JPG",
|
74 |
+
"PNG",
|
75 |
+
}
|
kit/metrics/clean_fid/wrappers.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from .features import build_feature_extractor, get_reference_statistics
|
5 |
+
from .fid import get_batch_features, fid_from_feats
|
6 |
+
from .resize import build_resizer
|
7 |
+
|
8 |
+
|
9 |
+
"""
|
10 |
+
A helper class that allowing adding the images one batch at a time.
|
11 |
+
"""
|
12 |
+
|
13 |
+
|
14 |
+
class CleanFID:
|
15 |
+
def __init__(self, mode="clean", model_name="inception_v3", device="cuda"):
|
16 |
+
self.real_features = []
|
17 |
+
self.gen_features = []
|
18 |
+
self.mode = mode
|
19 |
+
self.device = device
|
20 |
+
if model_name == "inception_v3":
|
21 |
+
self.feat_model = build_feature_extractor(mode, device)
|
22 |
+
self.fn_resize = build_resizer(mode)
|
23 |
+
elif model_name == "clip_vit_b_32":
|
24 |
+
from .clip_features import CLIP_fx, img_preprocess_clip
|
25 |
+
|
26 |
+
clip_fx = CLIP_fx("ViT-B/32")
|
27 |
+
self.feat_model = clip_fx
|
28 |
+
self.fn_resize = img_preprocess_clip
|
29 |
+
|
30 |
+
"""
|
31 |
+
Funtion that takes an image (PIL.Image or np.array or torch.tensor)
|
32 |
+
and returns the corresponding feature embedding vector.
|
33 |
+
The image x is expected to be in range [0, 255]
|
34 |
+
"""
|
35 |
+
|
36 |
+
def compute_features(self, x):
|
37 |
+
# if x is a PIL Image
|
38 |
+
if isinstance(x, Image.Image):
|
39 |
+
x_np = np.array(x)
|
40 |
+
x_np_resized = self.fn_resize(x_np)
|
41 |
+
x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0)
|
42 |
+
x_feat = get_batch_features(x_t, self.feat_model, self.device)
|
43 |
+
elif isinstance(x, np.ndarray):
|
44 |
+
x_np_resized = self.fn_resize(x)
|
45 |
+
x_t = (
|
46 |
+
torch.tensor(x_np_resized.transpose((2, 0, 1)))
|
47 |
+
.unsqueeze(0)
|
48 |
+
.to(self.device)
|
49 |
+
)
|
50 |
+
# normalization happens inside the self.feat_model, expected image range here is [0,255]
|
51 |
+
x_feat = get_batch_features(x_t, self.feat_model, self.device)
|
52 |
+
elif isinstance(x, torch.Tensor):
|
53 |
+
# pdb.set_trace()
|
54 |
+
# add the batch dimension if x is passed in as C,H,W
|
55 |
+
if len(x.shape) == 3:
|
56 |
+
x = x.unsqueeze(0)
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
# convert back to np array and resize
|
59 |
+
l_x_np_resized = []
|
60 |
+
for _ in range(b):
|
61 |
+
x_np = x[_].cpu().numpy().transpose((1, 2, 0))
|
62 |
+
l_x_np_resized.append(self.fn_resize(x_np)[None,])
|
63 |
+
x_np_resized = np.concatenate(l_x_np_resized)
|
64 |
+
x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device)
|
65 |
+
# normalization happens inside the self.feat_model, expected image range here is [0,255]
|
66 |
+
x_feat = get_batch_features(x_t, self.feat_model, self.device)
|
67 |
+
else:
|
68 |
+
raise ValueError("image type could not be inferred")
|
69 |
+
return x_feat
|
70 |
+
|
71 |
+
"""
|
72 |
+
Extract the faetures from x and add to the list of reference real images
|
73 |
+
"""
|
74 |
+
|
75 |
+
def add_real_images(self, x):
|
76 |
+
x_feat = self.compute_features(x)
|
77 |
+
self.real_features.append(x_feat)
|
78 |
+
|
79 |
+
"""
|
80 |
+
Extract the faetures from x and add to the list of generated images
|
81 |
+
"""
|
82 |
+
|
83 |
+
def add_gen_images(self, x):
|
84 |
+
x_feat = self.compute_features(x)
|
85 |
+
self.gen_features.append(x_feat)
|
86 |
+
|
87 |
+
"""
|
88 |
+
Compute FID between the real and generated images added so far
|
89 |
+
"""
|
90 |
+
|
91 |
+
def calculate_fid(self, verbose=True):
|
92 |
+
feats1 = np.concatenate(self.real_features)
|
93 |
+
feats2 = np.concatenate(self.gen_features)
|
94 |
+
if verbose:
|
95 |
+
print(f"# real images = {feats1.shape[0]}")
|
96 |
+
print(f"# generated images = {feats2.shape[0]}")
|
97 |
+
return fid_from_feats(feats1, feats2)
|
98 |
+
|
99 |
+
"""
|
100 |
+
Remove the real image features added so far
|
101 |
+
"""
|
102 |
+
|
103 |
+
def reset_real_features(self):
|
104 |
+
self.real_features = []
|
105 |
+
|
106 |
+
"""
|
107 |
+
Remove the generated image features added so far
|
108 |
+
"""
|
109 |
+
|
110 |
+
def reset_gen_features(self):
|
111 |
+
self.gen_features = []
|
kit/metrics/clip.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import open_clip
|
4 |
+
|
5 |
+
|
6 |
+
def load_open_clip_model_preprocess_and_tokenizer(device=torch.device("cuda")):
|
7 |
+
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
|
8 |
+
"ViT-g-14", pretrained="laion2b_s12b_b42k", device=device
|
9 |
+
)
|
10 |
+
clip_tokenizer = open_clip.get_tokenizer("ViT-g-14")
|
11 |
+
return clip_model, clip_preprocess, clip_tokenizer
|
12 |
+
|
13 |
+
|
14 |
+
def compute_clip_score(
|
15 |
+
images,
|
16 |
+
prompts,
|
17 |
+
models,
|
18 |
+
device=torch.device("cuda"),
|
19 |
+
):
|
20 |
+
clip_model, clip_preprocess, clip_tokenizer = models
|
21 |
+
with torch.no_grad():
|
22 |
+
tensors = [clip_preprocess(image) for image in images]
|
23 |
+
image_processed_tensor = torch.stack(tensors, 0).to(device)
|
24 |
+
image_features = clip_model.encode_image(image_processed_tensor)
|
25 |
+
|
26 |
+
encoding = clip_tokenizer(prompts).to(device)
|
27 |
+
text_features = clip_model.encode_text(encoding)
|
28 |
+
|
29 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
30 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
return (image_features @ text_features.T).mean(-1).cpu().numpy().tolist()
|
kit/metrics/distributional.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
from concurrent.futures import ProcessPoolExecutor
|
7 |
+
from functools import partial
|
8 |
+
from PIL import Image
|
9 |
+
from .clean_fid import fid
|
10 |
+
|
11 |
+
|
12 |
+
def save_single_image_to_temp(i, image, temp_dir):
|
13 |
+
save_path = os.path.join(temp_dir, f"{i}.png")
|
14 |
+
image.save(save_path, "PNG")
|
15 |
+
|
16 |
+
|
17 |
+
def save_images_to_temp(images, num_workers, verbose=False):
|
18 |
+
assert isinstance(images, list) and isinstance(images[0], Image.Image)
|
19 |
+
temp_dir = tempfile.mkdtemp()
|
20 |
+
|
21 |
+
# Using ProcessPoolExecutor to save images in parallel
|
22 |
+
func = partial(save_single_image_to_temp, temp_dir=temp_dir)
|
23 |
+
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
24 |
+
tasks = executor.map(func, range(len(images)), images)
|
25 |
+
list(tasks) if not verbose else list(
|
26 |
+
tqdm(
|
27 |
+
tasks,
|
28 |
+
total=len(images),
|
29 |
+
desc="Saving images ",
|
30 |
+
)
|
31 |
+
)
|
32 |
+
return temp_dir
|
33 |
+
|
34 |
+
|
35 |
+
# Compute FID between two sets of images
|
36 |
+
def compute_fid(
|
37 |
+
images1,
|
38 |
+
images2,
|
39 |
+
mode="legacy",
|
40 |
+
device=None,
|
41 |
+
batch_size=64,
|
42 |
+
num_workers=None,
|
43 |
+
verbose=False,
|
44 |
+
):
|
45 |
+
# Support four types of FID scores
|
46 |
+
assert mode in ["legacy", "clean", "clip"]
|
47 |
+
if mode == "legacy":
|
48 |
+
mode = "legacy_pytorch"
|
49 |
+
model_name = "inception_v3"
|
50 |
+
elif mode == "clean":
|
51 |
+
mode = "clean"
|
52 |
+
model_name = "inception_v3"
|
53 |
+
elif mode == "clip":
|
54 |
+
mode = "clean"
|
55 |
+
model_name = "clip_vit_b_32"
|
56 |
+
else:
|
57 |
+
assert False
|
58 |
+
|
59 |
+
# Set up device and num_workers
|
60 |
+
if device is None:
|
61 |
+
device = (
|
62 |
+
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
63 |
+
)
|
64 |
+
if num_workers is not None:
|
65 |
+
assert 1 <= num_workers <= os.cpu_count()
|
66 |
+
else:
|
67 |
+
num_workers = max(torch.cuda.device_count() * 4, 8)
|
68 |
+
|
69 |
+
# Check images, can be paths or lists of PIL images
|
70 |
+
if not isinstance(images1, list):
|
71 |
+
assert isinstance(images1, str) and os.path.exists(images1)
|
72 |
+
assert isinstance(images2, str) and os.path.exists(images2)
|
73 |
+
path1 = images1
|
74 |
+
path2 = images2
|
75 |
+
else:
|
76 |
+
assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
|
77 |
+
assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
|
78 |
+
# Save images to temp dir if needed
|
79 |
+
path1 = save_images_to_temp(images1, num_workers=num_workers, verbose=verbose)
|
80 |
+
path2 = save_images_to_temp(images2, num_workers=num_workers, verbose=verbose)
|
81 |
+
|
82 |
+
# Attempt to cache statistics for path1
|
83 |
+
if not fid.test_stats_exists(name=str(os.path.abspath(path1)).replace("/", "_"), mode=mode, model_name=model_name):
|
84 |
+
fid.make_custom_stats(
|
85 |
+
name=str(os.path.abspath(path1)).replace("/", "_"),
|
86 |
+
fdir=path1,
|
87 |
+
mode=mode,
|
88 |
+
model_name=model_name,
|
89 |
+
device=device,
|
90 |
+
num_workers=num_workers,
|
91 |
+
verbose=verbose,
|
92 |
+
)
|
93 |
+
fid_score = fid.compute_fid(
|
94 |
+
path2,
|
95 |
+
dataset_name=str(os.path.abspath(path1)).replace("/", "_"),
|
96 |
+
dataset_split="custom",
|
97 |
+
mode=mode,
|
98 |
+
model_name=model_name,
|
99 |
+
device=device,
|
100 |
+
batch_size=batch_size,
|
101 |
+
num_workers=num_workers,
|
102 |
+
verbose=verbose,
|
103 |
+
)
|
104 |
+
return fid_score
|
kit/metrics/image.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from skimage.metrics import (
|
6 |
+
mean_squared_error,
|
7 |
+
peak_signal_noise_ratio,
|
8 |
+
structural_similarity as structural_similarity_index_measure,
|
9 |
+
normalized_mutual_information,
|
10 |
+
)
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from concurrent.futures import ThreadPoolExecutor
|
13 |
+
|
14 |
+
|
15 |
+
# Process images to numpy arrays
|
16 |
+
def convert_image_pair_to_numpy(image1, image2):
|
17 |
+
assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
|
18 |
+
|
19 |
+
image1_np = np.array(image1)
|
20 |
+
image2_np = np.array(image2)
|
21 |
+
assert image1_np.shape == image2_np.shape
|
22 |
+
|
23 |
+
return image1_np, image2_np
|
24 |
+
|
25 |
+
|
26 |
+
# Compute MSE between two images
|
27 |
+
def compute_mse(image1, image2):
|
28 |
+
image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
|
29 |
+
return float(mean_squared_error(image1_np, image2_np))
|
30 |
+
|
31 |
+
|
32 |
+
# Compute PSNR between two images
|
33 |
+
def compute_psnr(image1, image2):
|
34 |
+
image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
|
35 |
+
return float(peak_signal_noise_ratio(image1_np, image2_np))
|
36 |
+
|
37 |
+
|
38 |
+
# Compute SSIM between two images
|
39 |
+
def compute_ssim(image1, image2):
|
40 |
+
image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
|
41 |
+
return float(
|
42 |
+
structural_similarity_index_measure(image1_np, image2_np, channel_axis=2)
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
# Compute NMI between two images
|
47 |
+
def compute_nmi(image1, image2):
|
48 |
+
image1_np, image2_np = convert_image_pair_to_numpy(image1, image2)
|
49 |
+
return float(normalized_mutual_information(image1_np, image2_np))
|
50 |
+
|
51 |
+
|
52 |
+
# Compute metrics
|
53 |
+
def compute_metric_repeated(
|
54 |
+
images1, images2, metric_func, num_workers=None, verbose=False
|
55 |
+
):
|
56 |
+
# Accept list of PIL images
|
57 |
+
assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
|
58 |
+
assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
|
59 |
+
assert len(images1) == len(images2)
|
60 |
+
|
61 |
+
if num_workers is not None:
|
62 |
+
assert 1 <= num_workers <= os.cpu_count()
|
63 |
+
else:
|
64 |
+
num_workers = max(torch.cuda.device_count() * 4, 8)
|
65 |
+
|
66 |
+
metric_name = metric_func.__name__.split("_")[1].upper()
|
67 |
+
|
68 |
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
69 |
+
tasks = executor.map(metric_func, images1, images2)
|
70 |
+
values = (
|
71 |
+
list(tasks)
|
72 |
+
if not verbose
|
73 |
+
else list(
|
74 |
+
tqdm(
|
75 |
+
tasks,
|
76 |
+
total=len(images1),
|
77 |
+
desc=f"{metric_name} ",
|
78 |
+
)
|
79 |
+
)
|
80 |
+
)
|
81 |
+
return values
|
82 |
+
|
83 |
+
|
84 |
+
# Compute MSE between pairs of images
|
85 |
+
def compute_mse_repeated(images1, images2, num_workers=None, verbose=False):
|
86 |
+
return compute_metric_repeated(images1, images2, compute_mse, num_workers, verbose)
|
87 |
+
|
88 |
+
|
89 |
+
# Compute PSNR between pairs of images
|
90 |
+
def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False):
|
91 |
+
return compute_metric_repeated(images1, images2, compute_psnr, num_workers, verbose)
|
92 |
+
|
93 |
+
|
94 |
+
# Compute SSIM between pairs of images
|
95 |
+
def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False):
|
96 |
+
return compute_metric_repeated(images1, images2, compute_ssim, num_workers, verbose)
|
97 |
+
|
98 |
+
|
99 |
+
# Compute NMI between pairs of images
|
100 |
+
def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False):
|
101 |
+
return compute_metric_repeated(images1, images2, compute_nmi, num_workers, verbose)
|
102 |
+
|
103 |
+
|
104 |
+
def compute_image_distance_repeated(
|
105 |
+
images1, images2, metric_name, num_workers=None, verbose=False
|
106 |
+
):
|
107 |
+
metric_func = {
|
108 |
+
"psnr": compute_psnr,
|
109 |
+
"ssim": compute_ssim,
|
110 |
+
"nmi": compute_nmi,
|
111 |
+
}[metric_name]
|
112 |
+
return compute_metric_repeated(images1, images2, metric_func, num_workers, verbose)
|
kit/metrics/lpips/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
From https://github.com/richzhang/PerceptualSimilarity
|
3 |
+
"""
|
4 |
+
from .lpips import LPIPS
|
kit/metrics/lpips/lpips.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.autograd import Variable
|
5 |
+
import warnings
|
6 |
+
from . import pretrained_networks as pn
|
7 |
+
from .utils import normalize_tensor, l2, dssim, tensor2np, tensor2tensorlab, tensor2im
|
8 |
+
|
9 |
+
|
10 |
+
def spatial_average(in_tens, keepdim=True):
|
11 |
+
return in_tens.mean([2, 3], keepdim=keepdim)
|
12 |
+
|
13 |
+
|
14 |
+
def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
|
15 |
+
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
|
16 |
+
return nn.Upsample(size=out_HW, mode="bilinear", align_corners=False)(in_tens)
|
17 |
+
|
18 |
+
|
19 |
+
# Learned perceptual metric
|
20 |
+
class LPIPS(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
pretrained=True,
|
24 |
+
net="alex",
|
25 |
+
version="0.1",
|
26 |
+
lpips=True,
|
27 |
+
spatial=False,
|
28 |
+
pnet_rand=False,
|
29 |
+
pnet_tune=False,
|
30 |
+
use_dropout=True,
|
31 |
+
model_path=None,
|
32 |
+
eval_mode=True,
|
33 |
+
verbose=True,
|
34 |
+
):
|
35 |
+
"""Initializes a perceptual loss torch.nn.Module
|
36 |
+
|
37 |
+
Parameters (default listed first)
|
38 |
+
---------------------------------
|
39 |
+
lpips : bool
|
40 |
+
[True] use linear layers on top of base/trunk network
|
41 |
+
[False] means no linear layers; each layer is averaged together
|
42 |
+
pretrained : bool
|
43 |
+
This flag controls the linear layers, which are only in effect when lpips=True above
|
44 |
+
[True] means linear layers are calibrated with human perceptual judgments
|
45 |
+
[False] means linear layers are randomly initialized
|
46 |
+
pnet_rand : bool
|
47 |
+
[False] means trunk loaded with ImageNet classification weights
|
48 |
+
[True] means randomly initialized trunk
|
49 |
+
net : str
|
50 |
+
['alex','vgg','squeeze'] are the base/trunk networks available
|
51 |
+
version : str
|
52 |
+
['v0.1'] is the default and latest
|
53 |
+
['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
|
54 |
+
model_path : 'str'
|
55 |
+
[None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
|
56 |
+
|
57 |
+
The following parameters should only be changed if training the network
|
58 |
+
|
59 |
+
eval_mode : bool
|
60 |
+
[True] is for test mode (default)
|
61 |
+
[False] is for training mode
|
62 |
+
pnet_tune
|
63 |
+
[False] keep base/trunk frozen
|
64 |
+
[True] tune the base/trunk network
|
65 |
+
use_dropout : bool
|
66 |
+
[True] to use dropout when training linear layers
|
67 |
+
[False] for no dropout when training linear layers
|
68 |
+
"""
|
69 |
+
|
70 |
+
super(LPIPS, self).__init__()
|
71 |
+
warnings.filterwarnings("ignore")
|
72 |
+
if verbose:
|
73 |
+
pass
|
74 |
+
# print(
|
75 |
+
# "Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]"
|
76 |
+
# % (
|
77 |
+
# "LPIPS" if lpips else "baseline",
|
78 |
+
# net,
|
79 |
+
# version,
|
80 |
+
# "on" if spatial else "off",
|
81 |
+
# )
|
82 |
+
# )
|
83 |
+
|
84 |
+
self.pnet_type = net
|
85 |
+
self.pnet_tune = pnet_tune
|
86 |
+
self.pnet_rand = pnet_rand
|
87 |
+
self.spatial = spatial
|
88 |
+
self.lpips = lpips # false means baseline of just averaging all layers
|
89 |
+
self.version = version
|
90 |
+
self.scaling_layer = ScalingLayer()
|
91 |
+
|
92 |
+
if self.pnet_type in ["vgg", "vgg16"]:
|
93 |
+
net_type = pn.vgg16
|
94 |
+
self.chns = [64, 128, 256, 512, 512]
|
95 |
+
elif self.pnet_type == "alex":
|
96 |
+
net_type = pn.alexnet
|
97 |
+
self.chns = [64, 192, 384, 256, 256]
|
98 |
+
elif self.pnet_type == "squeeze":
|
99 |
+
net_type = pn.squeezenet
|
100 |
+
self.chns = [64, 128, 256, 384, 384, 512, 512]
|
101 |
+
self.L = len(self.chns)
|
102 |
+
|
103 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
104 |
+
|
105 |
+
if lpips:
|
106 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
107 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
108 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
109 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
110 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
111 |
+
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
112 |
+
if self.pnet_type == "squeeze": # 7 layers for squeezenet
|
113 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
114 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
115 |
+
self.lins += [self.lin5, self.lin6]
|
116 |
+
self.lins = nn.ModuleList(self.lins)
|
117 |
+
|
118 |
+
if pretrained:
|
119 |
+
if model_path is None:
|
120 |
+
import inspect
|
121 |
+
import os
|
122 |
+
|
123 |
+
model_path = os.path.abspath(
|
124 |
+
os.path.join(
|
125 |
+
inspect.getfile(self.__init__),
|
126 |
+
"..",
|
127 |
+
"weights/v%s/%s.pth" % (version, net),
|
128 |
+
)
|
129 |
+
)
|
130 |
+
|
131 |
+
if verbose:
|
132 |
+
pass
|
133 |
+
# print("Loading model from: %s" % model_path)
|
134 |
+
self.load_state_dict(
|
135 |
+
torch.load(model_path, map_location="cpu"), strict=False
|
136 |
+
)
|
137 |
+
|
138 |
+
if eval_mode:
|
139 |
+
self.eval()
|
140 |
+
|
141 |
+
def forward(self, in0, in1, retPerLayer=False, normalize=False):
|
142 |
+
if (
|
143 |
+
normalize
|
144 |
+
): # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
|
145 |
+
in0 = 2 * in0 - 1
|
146 |
+
in1 = 2 * in1 - 1
|
147 |
+
|
148 |
+
# v0.0 - original release had a bug, where input was not scaled
|
149 |
+
in0_input, in1_input = (
|
150 |
+
(self.scaling_layer(in0), self.scaling_layer(in1))
|
151 |
+
if self.version == "0.1"
|
152 |
+
else (in0, in1)
|
153 |
+
)
|
154 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
155 |
+
feats0, feats1, diffs = {}, {}, {}
|
156 |
+
|
157 |
+
for kk in range(self.L):
|
158 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
159 |
+
outs1[kk]
|
160 |
+
)
|
161 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
162 |
+
|
163 |
+
if self.lpips:
|
164 |
+
if self.spatial:
|
165 |
+
res = [
|
166 |
+
upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])
|
167 |
+
for kk in range(self.L)
|
168 |
+
]
|
169 |
+
else:
|
170 |
+
res = [
|
171 |
+
spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
|
172 |
+
for kk in range(self.L)
|
173 |
+
]
|
174 |
+
else:
|
175 |
+
if self.spatial:
|
176 |
+
res = [
|
177 |
+
upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:])
|
178 |
+
for kk in range(self.L)
|
179 |
+
]
|
180 |
+
else:
|
181 |
+
res = [
|
182 |
+
spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
|
183 |
+
for kk in range(self.L)
|
184 |
+
]
|
185 |
+
|
186 |
+
val = 0
|
187 |
+
for l in range(self.L):
|
188 |
+
val += res[l]
|
189 |
+
|
190 |
+
if retPerLayer:
|
191 |
+
return (val, res)
|
192 |
+
else:
|
193 |
+
return val
|
194 |
+
|
195 |
+
|
196 |
+
class ScalingLayer(nn.Module):
|
197 |
+
def __init__(self):
|
198 |
+
super(ScalingLayer, self).__init__()
|
199 |
+
self.register_buffer(
|
200 |
+
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
201 |
+
)
|
202 |
+
self.register_buffer(
|
203 |
+
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
204 |
+
)
|
205 |
+
|
206 |
+
def forward(self, inp):
|
207 |
+
return (inp - self.shift) / self.scale
|
208 |
+
|
209 |
+
|
210 |
+
class NetLinLayer(nn.Module):
|
211 |
+
"""A single linear layer which does a 1x1 conv"""
|
212 |
+
|
213 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
214 |
+
super(NetLinLayer, self).__init__()
|
215 |
+
|
216 |
+
layers = (
|
217 |
+
[
|
218 |
+
nn.Dropout(),
|
219 |
+
]
|
220 |
+
if (use_dropout)
|
221 |
+
else []
|
222 |
+
)
|
223 |
+
layers += [
|
224 |
+
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
225 |
+
]
|
226 |
+
self.model = nn.Sequential(*layers)
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
return self.model(x)
|
230 |
+
|
231 |
+
|
232 |
+
class Dist2LogitLayer(nn.Module):
|
233 |
+
"""takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)"""
|
234 |
+
|
235 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
236 |
+
super(Dist2LogitLayer, self).__init__()
|
237 |
+
|
238 |
+
layers = [
|
239 |
+
nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
|
240 |
+
]
|
241 |
+
layers += [
|
242 |
+
nn.LeakyReLU(0.2, True),
|
243 |
+
]
|
244 |
+
layers += [
|
245 |
+
nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
|
246 |
+
]
|
247 |
+
layers += [
|
248 |
+
nn.LeakyReLU(0.2, True),
|
249 |
+
]
|
250 |
+
layers += [
|
251 |
+
nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
|
252 |
+
]
|
253 |
+
if use_sigmoid:
|
254 |
+
layers += [
|
255 |
+
nn.Sigmoid(),
|
256 |
+
]
|
257 |
+
self.model = nn.Sequential(*layers)
|
258 |
+
|
259 |
+
def forward(self, d0, d1, eps=0.1):
|
260 |
+
return self.model.forward(
|
261 |
+
torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
+
class BCERankingLoss(nn.Module):
|
266 |
+
def __init__(self, chn_mid=32):
|
267 |
+
super(BCERankingLoss, self).__init__()
|
268 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
269 |
+
# self.parameters = list(self.net.parameters())
|
270 |
+
self.loss = torch.nn.BCELoss()
|
271 |
+
|
272 |
+
def forward(self, d0, d1, judge):
|
273 |
+
per = (judge + 1.0) / 2.0
|
274 |
+
self.logit = self.net.forward(d0, d1)
|
275 |
+
return self.loss(self.logit, per)
|
276 |
+
|
277 |
+
|
278 |
+
# L2, DSSIM metrics
|
279 |
+
class FakeNet(nn.Module):
|
280 |
+
def __init__(self, use_gpu=True, colorspace="Lab"):
|
281 |
+
super(FakeNet, self).__init__()
|
282 |
+
self.use_gpu = use_gpu
|
283 |
+
self.colorspace = colorspace
|
284 |
+
|
285 |
+
|
286 |
+
class L2(FakeNet):
|
287 |
+
def forward(self, in0, in1, retPerLayer=None):
|
288 |
+
assert in0.size()[0] == 1 # currently only supports batchSize 1
|
289 |
+
|
290 |
+
if self.colorspace == "RGB":
|
291 |
+
(N, C, X, Y) = in0.size()
|
292 |
+
value = torch.mean(
|
293 |
+
torch.mean(
|
294 |
+
torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2
|
295 |
+
).view(N, 1, 1, Y),
|
296 |
+
dim=3,
|
297 |
+
).view(N)
|
298 |
+
return value
|
299 |
+
elif self.colorspace == "Lab":
|
300 |
+
value = l2(
|
301 |
+
tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
302 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
|
303 |
+
range=100.0,
|
304 |
+
).astype("float")
|
305 |
+
ret_var = Variable(torch.Tensor((value,)))
|
306 |
+
if self.use_gpu:
|
307 |
+
ret_var = ret_var.cuda()
|
308 |
+
return ret_var
|
309 |
+
|
310 |
+
|
311 |
+
class DSSIM(FakeNet):
|
312 |
+
def forward(self, in0, in1, retPerLayer=None):
|
313 |
+
assert in0.size()[0] == 1 # currently only supports batchSize 1
|
314 |
+
|
315 |
+
if self.colorspace == "RGB":
|
316 |
+
value = dssim(
|
317 |
+
1.0 * tensor2im(in0.data),
|
318 |
+
1.0 * tensor2im(in1.data),
|
319 |
+
range=255.0,
|
320 |
+
).astype("float")
|
321 |
+
elif self.colorspace == "Lab":
|
322 |
+
value = dssim(
|
323 |
+
tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
324 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
|
325 |
+
range=100.0,
|
326 |
+
).astype("float")
|
327 |
+
ret_var = Variable(torch.Tensor((value,)))
|
328 |
+
if self.use_gpu:
|
329 |
+
ret_var = ret_var.cuda()
|
330 |
+
return ret_var
|
331 |
+
|
332 |
+
|
333 |
+
def print_network(net):
|
334 |
+
num_params = 0
|
335 |
+
for param in net.parameters():
|
336 |
+
num_params += param.numel()
|
337 |
+
print("Network", net)
|
338 |
+
print("Total number of parameters: %d" % num_params)
|
kit/metrics/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
|
7 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
8 |
+
super(squeezenet, self).__init__()
|
9 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
10 |
+
self.slice1 = torch.nn.Sequential()
|
11 |
+
self.slice2 = torch.nn.Sequential()
|
12 |
+
self.slice3 = torch.nn.Sequential()
|
13 |
+
self.slice4 = torch.nn.Sequential()
|
14 |
+
self.slice5 = torch.nn.Sequential()
|
15 |
+
self.slice6 = torch.nn.Sequential()
|
16 |
+
self.slice7 = torch.nn.Sequential()
|
17 |
+
self.N_slices = 7
|
18 |
+
for x in range(2):
|
19 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
20 |
+
for x in range(2, 5):
|
21 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
22 |
+
for x in range(5, 8):
|
23 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
24 |
+
for x in range(8, 10):
|
25 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
26 |
+
for x in range(10, 11):
|
27 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
28 |
+
for x in range(11, 12):
|
29 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
30 |
+
for x in range(12, 13):
|
31 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
32 |
+
if not requires_grad:
|
33 |
+
for param in self.parameters():
|
34 |
+
param.requires_grad = False
|
35 |
+
|
36 |
+
def forward(self, X):
|
37 |
+
h = self.slice1(X)
|
38 |
+
h_relu1 = h
|
39 |
+
h = self.slice2(h)
|
40 |
+
h_relu2 = h
|
41 |
+
h = self.slice3(h)
|
42 |
+
h_relu3 = h
|
43 |
+
h = self.slice4(h)
|
44 |
+
h_relu4 = h
|
45 |
+
h = self.slice5(h)
|
46 |
+
h_relu5 = h
|
47 |
+
h = self.slice6(h)
|
48 |
+
h_relu6 = h
|
49 |
+
h = self.slice7(h)
|
50 |
+
h_relu7 = h
|
51 |
+
vgg_outputs = namedtuple(
|
52 |
+
"SqueezeOutputs",
|
53 |
+
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
|
54 |
+
)
|
55 |
+
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
56 |
+
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class alexnet(torch.nn.Module):
|
61 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
62 |
+
super(alexnet, self).__init__()
|
63 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
64 |
+
self.slice1 = torch.nn.Sequential()
|
65 |
+
self.slice2 = torch.nn.Sequential()
|
66 |
+
self.slice3 = torch.nn.Sequential()
|
67 |
+
self.slice4 = torch.nn.Sequential()
|
68 |
+
self.slice5 = torch.nn.Sequential()
|
69 |
+
self.N_slices = 5
|
70 |
+
for x in range(2):
|
71 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
72 |
+
for x in range(2, 5):
|
73 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
74 |
+
for x in range(5, 8):
|
75 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
76 |
+
for x in range(8, 10):
|
77 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
78 |
+
for x in range(10, 12):
|
79 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
80 |
+
if not requires_grad:
|
81 |
+
for param in self.parameters():
|
82 |
+
param.requires_grad = False
|
83 |
+
|
84 |
+
def forward(self, X):
|
85 |
+
h = self.slice1(X)
|
86 |
+
h_relu1 = h
|
87 |
+
h = self.slice2(h)
|
88 |
+
h_relu2 = h
|
89 |
+
h = self.slice3(h)
|
90 |
+
h_relu3 = h
|
91 |
+
h = self.slice4(h)
|
92 |
+
h_relu4 = h
|
93 |
+
h = self.slice5(h)
|
94 |
+
h_relu5 = h
|
95 |
+
alexnet_outputs = namedtuple(
|
96 |
+
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
|
97 |
+
)
|
98 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
99 |
+
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class vgg16(torch.nn.Module):
|
104 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
105 |
+
super(vgg16, self).__init__()
|
106 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
107 |
+
self.slice1 = torch.nn.Sequential()
|
108 |
+
self.slice2 = torch.nn.Sequential()
|
109 |
+
self.slice3 = torch.nn.Sequential()
|
110 |
+
self.slice4 = torch.nn.Sequential()
|
111 |
+
self.slice5 = torch.nn.Sequential()
|
112 |
+
self.N_slices = 5
|
113 |
+
for x in range(4):
|
114 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(4, 9):
|
116 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
for x in range(9, 16):
|
118 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
119 |
+
for x in range(16, 23):
|
120 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
121 |
+
for x in range(23, 30):
|
122 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
123 |
+
if not requires_grad:
|
124 |
+
for param in self.parameters():
|
125 |
+
param.requires_grad = False
|
126 |
+
|
127 |
+
def forward(self, X):
|
128 |
+
h = self.slice1(X)
|
129 |
+
h_relu1_2 = h
|
130 |
+
h = self.slice2(h)
|
131 |
+
h_relu2_2 = h
|
132 |
+
h = self.slice3(h)
|
133 |
+
h_relu3_3 = h
|
134 |
+
h = self.slice4(h)
|
135 |
+
h_relu4_3 = h
|
136 |
+
h = self.slice5(h)
|
137 |
+
h_relu5_3 = h
|
138 |
+
vgg_outputs = namedtuple(
|
139 |
+
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
140 |
+
)
|
141 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
142 |
+
|
143 |
+
return out
|
144 |
+
|
145 |
+
|
146 |
+
class resnet(torch.nn.Module):
|
147 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
148 |
+
super(resnet, self).__init__()
|
149 |
+
if num == 18:
|
150 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
151 |
+
elif num == 34:
|
152 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
153 |
+
elif num == 50:
|
154 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
155 |
+
elif num == 101:
|
156 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
157 |
+
elif num == 152:
|
158 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
159 |
+
self.N_slices = 5
|
160 |
+
|
161 |
+
self.conv1 = self.net.conv1
|
162 |
+
self.bn1 = self.net.bn1
|
163 |
+
self.relu = self.net.relu
|
164 |
+
self.maxpool = self.net.maxpool
|
165 |
+
self.layer1 = self.net.layer1
|
166 |
+
self.layer2 = self.net.layer2
|
167 |
+
self.layer3 = self.net.layer3
|
168 |
+
self.layer4 = self.net.layer4
|
169 |
+
|
170 |
+
def forward(self, X):
|
171 |
+
h = self.conv1(X)
|
172 |
+
h = self.bn1(h)
|
173 |
+
h = self.relu(h)
|
174 |
+
h_relu1 = h
|
175 |
+
h = self.maxpool(h)
|
176 |
+
h = self.layer1(h)
|
177 |
+
h_conv2 = h
|
178 |
+
h = self.layer2(h)
|
179 |
+
h_conv3 = h
|
180 |
+
h = self.layer3(h)
|
181 |
+
h_conv4 = h
|
182 |
+
h = self.layer4(h)
|
183 |
+
h_conv5 = h
|
184 |
+
|
185 |
+
outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
|
186 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
187 |
+
|
188 |
+
return out
|
kit/metrics/lpips/trainer.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
from torch.autograd import Variable
|
6 |
+
from scipy.ndimage import zoom
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
from .lpips import LPIPS, L2, DSSIM, BCERankingLoss
|
10 |
+
from .utils import tensor2im, voc_ap
|
11 |
+
|
12 |
+
|
13 |
+
class Trainer:
|
14 |
+
def name(self):
|
15 |
+
return self.model_name
|
16 |
+
|
17 |
+
def initialize(
|
18 |
+
self,
|
19 |
+
model="lpips",
|
20 |
+
net="alex",
|
21 |
+
colorspace="Lab",
|
22 |
+
pnet_rand=False,
|
23 |
+
pnet_tune=False,
|
24 |
+
model_path=None,
|
25 |
+
use_gpu=True,
|
26 |
+
printNet=False,
|
27 |
+
spatial=False,
|
28 |
+
is_train=False,
|
29 |
+
lr=0.0001,
|
30 |
+
beta1=0.5,
|
31 |
+
version="0.1",
|
32 |
+
gpu_ids=[0],
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
INPUTS
|
36 |
+
model - ['lpips'] for linearly calibrated network
|
37 |
+
['baseline'] for off-the-shelf network
|
38 |
+
['L2'] for L2 distance in Lab colorspace
|
39 |
+
['SSIM'] for ssim in RGB colorspace
|
40 |
+
net - ['squeeze','alex','vgg']
|
41 |
+
model_path - if None, will look in weights/[NET_NAME].pth
|
42 |
+
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
|
43 |
+
use_gpu - bool - whether or not to use a GPU
|
44 |
+
printNet - bool - whether or not to print network architecture out
|
45 |
+
spatial - bool - whether to output an array containing varying distances across spatial dimensions
|
46 |
+
is_train - bool - [True] for training mode
|
47 |
+
lr - float - initial learning rate
|
48 |
+
beta1 - float - initial momentum term for adam
|
49 |
+
version - 0.1 for latest, 0.0 was original (with a bug)
|
50 |
+
gpu_ids - int array - [0] by default, gpus to use
|
51 |
+
"""
|
52 |
+
self.use_gpu = use_gpu
|
53 |
+
self.gpu_ids = gpu_ids
|
54 |
+
self.model = model
|
55 |
+
self.net = net
|
56 |
+
self.is_train = is_train
|
57 |
+
self.spatial = spatial
|
58 |
+
self.model_name = "%s [%s]" % (model, net)
|
59 |
+
|
60 |
+
if self.model == "lpips": # pretrained net + linear layer
|
61 |
+
self.net = LPIPS(
|
62 |
+
pretrained=not is_train,
|
63 |
+
net=net,
|
64 |
+
version=version,
|
65 |
+
lpips=True,
|
66 |
+
spatial=spatial,
|
67 |
+
pnet_rand=pnet_rand,
|
68 |
+
pnet_tune=pnet_tune,
|
69 |
+
use_dropout=True,
|
70 |
+
model_path=model_path,
|
71 |
+
eval_mode=False,
|
72 |
+
)
|
73 |
+
elif self.model == "baseline": # pretrained network
|
74 |
+
self.net = LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
|
75 |
+
elif self.model in ["L2", "l2"]:
|
76 |
+
self.net = L2(
|
77 |
+
use_gpu=use_gpu, colorspace=colorspace
|
78 |
+
) # not really a network, only for testing
|
79 |
+
self.model_name = "L2"
|
80 |
+
elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]:
|
81 |
+
self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
|
82 |
+
self.model_name = "SSIM"
|
83 |
+
else:
|
84 |
+
raise ValueError("Model [%s] not recognized." % self.model)
|
85 |
+
|
86 |
+
self.parameters = list(self.net.parameters())
|
87 |
+
|
88 |
+
if self.is_train: # training mode
|
89 |
+
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
|
90 |
+
self.rankLoss = BCERankingLoss()
|
91 |
+
self.parameters += list(self.rankLoss.net.parameters())
|
92 |
+
self.lr = lr
|
93 |
+
self.old_lr = lr
|
94 |
+
self.optimizer_net = torch.optim.Adam(
|
95 |
+
self.parameters, lr=lr, betas=(beta1, 0.999)
|
96 |
+
)
|
97 |
+
else: # test mode
|
98 |
+
self.net.eval()
|
99 |
+
|
100 |
+
if use_gpu:
|
101 |
+
self.net.to(gpu_ids[0])
|
102 |
+
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
|
103 |
+
if self.is_train:
|
104 |
+
self.rankLoss = self.rankLoss.to(
|
105 |
+
device=gpu_ids[0]
|
106 |
+
) # just put this on GPU0
|
107 |
+
|
108 |
+
if printNet:
|
109 |
+
pass
|
110 |
+
|
111 |
+
def forward(self, in0, in1, retPerLayer=False):
|
112 |
+
"""Function computes the distance between image patches in0 and in1
|
113 |
+
INPUTS
|
114 |
+
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
|
115 |
+
OUTPUT
|
116 |
+
computed distances between in0 and in1
|
117 |
+
"""
|
118 |
+
|
119 |
+
return self.net.forward(in0, in1, retPerLayer=retPerLayer)
|
120 |
+
|
121 |
+
# ***** TRAINING FUNCTIONS *****
|
122 |
+
def optimize_parameters(self):
|
123 |
+
self.forward_train()
|
124 |
+
self.optimizer_net.zero_grad()
|
125 |
+
self.backward_train()
|
126 |
+
self.optimizer_net.step()
|
127 |
+
self.clamp_weights()
|
128 |
+
|
129 |
+
def clamp_weights(self):
|
130 |
+
for module in self.net.modules():
|
131 |
+
if hasattr(module, "weight") and module.kernel_size == (1, 1):
|
132 |
+
module.weight.data = torch.clamp(module.weight.data, min=0)
|
133 |
+
|
134 |
+
def set_input(self, data):
|
135 |
+
self.input_ref = data["ref"]
|
136 |
+
self.input_p0 = data["p0"]
|
137 |
+
self.input_p1 = data["p1"]
|
138 |
+
self.input_judge = data["judge"]
|
139 |
+
|
140 |
+
if self.use_gpu:
|
141 |
+
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
|
142 |
+
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
|
143 |
+
self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
|
144 |
+
self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
|
145 |
+
|
146 |
+
self.var_ref = Variable(self.input_ref, requires_grad=True)
|
147 |
+
self.var_p0 = Variable(self.input_p0, requires_grad=True)
|
148 |
+
self.var_p1 = Variable(self.input_p1, requires_grad=True)
|
149 |
+
|
150 |
+
def forward_train(self): # run forward pass
|
151 |
+
self.d0 = self.forward(self.var_ref, self.var_p0)
|
152 |
+
self.d1 = self.forward(self.var_ref, self.var_p1)
|
153 |
+
self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
|
154 |
+
|
155 |
+
self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size())
|
156 |
+
|
157 |
+
self.loss_total = self.rankLoss.forward(
|
158 |
+
self.d0, self.d1, self.var_judge * 2.0 - 1.0
|
159 |
+
)
|
160 |
+
|
161 |
+
return self.loss_total
|
162 |
+
|
163 |
+
def backward_train(self):
|
164 |
+
torch.mean(self.loss_total).backward()
|
165 |
+
|
166 |
+
def compute_accuracy(self, d0, d1, judge):
|
167 |
+
"""d0, d1 are Variables, judge is a Tensor"""
|
168 |
+
d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
|
169 |
+
judge_per = judge.cpu().numpy().flatten()
|
170 |
+
return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
|
171 |
+
|
172 |
+
def get_current_errors(self):
|
173 |
+
retDict = OrderedDict(
|
174 |
+
[("loss_total", self.loss_total.data.cpu().numpy()), ("acc_r", self.acc_r)]
|
175 |
+
)
|
176 |
+
|
177 |
+
for key in retDict.keys():
|
178 |
+
retDict[key] = np.mean(retDict[key])
|
179 |
+
|
180 |
+
return retDict
|
181 |
+
|
182 |
+
def get_current_visuals(self):
|
183 |
+
zoom_factor = 256 / self.var_ref.data.size()[2]
|
184 |
+
|
185 |
+
ref_img = tensor2im(self.var_ref.data)
|
186 |
+
p0_img = tensor2im(self.var_p0.data)
|
187 |
+
p1_img = tensor2im(self.var_p1.data)
|
188 |
+
|
189 |
+
ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
|
190 |
+
p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
|
191 |
+
p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
|
192 |
+
|
193 |
+
return OrderedDict(
|
194 |
+
[("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)]
|
195 |
+
)
|
196 |
+
|
197 |
+
def save(self, path, label):
|
198 |
+
if self.use_gpu:
|
199 |
+
self.save_network(self.net.module, path, "", label)
|
200 |
+
else:
|
201 |
+
self.save_network(self.net, path, "", label)
|
202 |
+
self.save_network(self.rankLoss.net, path, "rank", label)
|
203 |
+
|
204 |
+
# helper saving function that can be used by subclasses
|
205 |
+
def save_network(self, network, path, network_label, epoch_label):
|
206 |
+
save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
|
207 |
+
save_path = os.path.join(path, save_filename)
|
208 |
+
torch.save(network.state_dict(), save_path)
|
209 |
+
|
210 |
+
# helper loading function that can be used by subclasses
|
211 |
+
def load_network(self, network, network_label, epoch_label):
|
212 |
+
save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
|
213 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
214 |
+
print("Loading network from %s" % save_path)
|
215 |
+
network.load_state_dict(torch.load(save_path))
|
216 |
+
|
217 |
+
def update_learning_rate(self, nepoch_decay):
|
218 |
+
lrd = self.lr / nepoch_decay
|
219 |
+
lr = self.old_lr - lrd
|
220 |
+
|
221 |
+
for param_group in self.optimizer_net.param_groups:
|
222 |
+
param_group["lr"] = lr
|
223 |
+
|
224 |
+
print("update lr [%s] decay: %f -> %f" % (type, self.old_lr, lr))
|
225 |
+
self.old_lr = lr
|
226 |
+
|
227 |
+
def get_image_paths(self):
|
228 |
+
return self.image_paths
|
229 |
+
|
230 |
+
def save_done(self, flag=False):
|
231 |
+
np.save(os.path.join(self.save_dir, "done_flag"), flag)
|
232 |
+
np.savetxt(
|
233 |
+
os.path.join(self.save_dir, "done_flag"),
|
234 |
+
[
|
235 |
+
flag,
|
236 |
+
],
|
237 |
+
fmt="%i",
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
def score_2afc_dataset(data_loader, func, name=""):
|
242 |
+
"""Function computes Two Alternative Forced Choice (2AFC) score using
|
243 |
+
distance function 'func' in dataset 'data_loader'
|
244 |
+
INPUTS
|
245 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
246 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
247 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
248 |
+
OUTPUTS
|
249 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
250 |
+
[1] - dictionary with following elements
|
251 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
252 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
253 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
254 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
255 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
256 |
+
CONSTS
|
257 |
+
N - number of test triplets in data_loader
|
258 |
+
"""
|
259 |
+
|
260 |
+
d0s = []
|
261 |
+
d1s = []
|
262 |
+
gts = []
|
263 |
+
|
264 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
265 |
+
d0s += func(data["ref"], data["p0"]).data.cpu().numpy().flatten().tolist()
|
266 |
+
d1s += func(data["ref"], data["p1"]).data.cpu().numpy().flatten().tolist()
|
267 |
+
gts += data["judge"].cpu().numpy().flatten().tolist()
|
268 |
+
|
269 |
+
d0s = np.array(d0s)
|
270 |
+
d1s = np.array(d1s)
|
271 |
+
gts = np.array(gts)
|
272 |
+
scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5
|
273 |
+
|
274 |
+
return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
|
275 |
+
|
276 |
+
|
277 |
+
def score_jnd_dataset(data_loader, func, name=""):
|
278 |
+
"""Function computes JND score using distance function 'func' in dataset 'data_loader'
|
279 |
+
INPUTS
|
280 |
+
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
|
281 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
282 |
+
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
|
283 |
+
OUTPUTS
|
284 |
+
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
|
285 |
+
[1] - dictionary with following elements
|
286 |
+
ds - N array containing distances between two patches shown to human evaluator
|
287 |
+
sames - N array containing fraction of people who thought the two patches were identical
|
288 |
+
CONSTS
|
289 |
+
N - number of test triplets in data_loader
|
290 |
+
"""
|
291 |
+
|
292 |
+
ds = []
|
293 |
+
gts = []
|
294 |
+
|
295 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
296 |
+
ds += func(data["p0"], data["p1"]).data.cpu().numpy().tolist()
|
297 |
+
gts += data["same"].cpu().numpy().flatten().tolist()
|
298 |
+
|
299 |
+
sames = np.array(gts)
|
300 |
+
ds = np.array(ds)
|
301 |
+
|
302 |
+
sorted_inds = np.argsort(ds)
|
303 |
+
ds_sorted = ds[sorted_inds]
|
304 |
+
sames_sorted = sames[sorted_inds]
|
305 |
+
|
306 |
+
TPs = np.cumsum(sames_sorted)
|
307 |
+
FPs = np.cumsum(1 - sames_sorted)
|
308 |
+
FNs = np.sum(sames_sorted) - TPs
|
309 |
+
|
310 |
+
precs = TPs / (TPs + FPs)
|
311 |
+
recs = TPs / (TPs + FNs)
|
312 |
+
score = voc_ap(recs, precs)
|
313 |
+
|
314 |
+
return (score, dict(ds=ds, sames=sames))
|
kit/metrics/lpips/utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_tensor(in_feat, eps=1e-10):
|
9 |
+
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
|
10 |
+
return in_feat / (norm_factor + eps)
|
11 |
+
|
12 |
+
|
13 |
+
def l2(p0, p1, range=255.0):
|
14 |
+
return 0.5 * np.mean((p0 / range - p1 / range) ** 2)
|
15 |
+
|
16 |
+
|
17 |
+
def psnr(p0, p1, peak=255.0):
|
18 |
+
return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2))
|
19 |
+
|
20 |
+
|
21 |
+
def dssim(p0, p1, range=255.0):
|
22 |
+
from skimage.measure import compare_ssim
|
23 |
+
|
24 |
+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0
|
25 |
+
|
26 |
+
|
27 |
+
def tensor2np(tensor_obj):
|
28 |
+
# change dimension of a tensor object into a numpy array
|
29 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
|
30 |
+
|
31 |
+
|
32 |
+
def np2tensor(np_obj):
|
33 |
+
# change dimenion of np array into tensor array
|
34 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
35 |
+
|
36 |
+
|
37 |
+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
|
38 |
+
# image tensor to lab tensor
|
39 |
+
from skimage import color
|
40 |
+
|
41 |
+
img = tensor2im(image_tensor)
|
42 |
+
img_lab = color.rgb2lab(img)
|
43 |
+
if mc_only:
|
44 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
45 |
+
if to_norm and not mc_only:
|
46 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
47 |
+
img_lab = img_lab / 100.0
|
48 |
+
|
49 |
+
return np2tensor(img_lab)
|
50 |
+
|
51 |
+
|
52 |
+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
|
53 |
+
from skimage import color
|
54 |
+
import warnings
|
55 |
+
|
56 |
+
warnings.filterwarnings("ignore")
|
57 |
+
|
58 |
+
lab = tensor2np(lab_tensor) * 100.0
|
59 |
+
lab[:, :, 0] = lab[:, :, 0] + 50
|
60 |
+
|
61 |
+
rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1)
|
62 |
+
if return_inbnd:
|
63 |
+
# convert back to lab, see if we match
|
64 |
+
lab_back = color.rgb2lab(rgb_back.astype("uint8"))
|
65 |
+
mask = 1.0 * np.isclose(lab_back, lab, atol=2.0)
|
66 |
+
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
|
67 |
+
return (im2tensor(rgb_back), mask)
|
68 |
+
else:
|
69 |
+
return im2tensor(rgb_back)
|
70 |
+
|
71 |
+
|
72 |
+
def load_image(path):
|
73 |
+
if (
|
74 |
+
path[-3:] == "bmp"
|
75 |
+
or path[-3:] == "jpg"
|
76 |
+
or path[-3:] == "png"
|
77 |
+
or path[-4:] == "jpeg"
|
78 |
+
):
|
79 |
+
import cv2
|
80 |
+
|
81 |
+
return cv2.imread(path)[:, :, ::-1]
|
82 |
+
else:
|
83 |
+
import matplotlib.pyplot as plt
|
84 |
+
|
85 |
+
img = (255 * plt.imread(path)[:, :, :3]).astype("uint8")
|
86 |
+
|
87 |
+
return img
|
88 |
+
|
89 |
+
|
90 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
|
91 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
92 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
93 |
+
return image_numpy.astype(imtype)
|
94 |
+
|
95 |
+
|
96 |
+
def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
|
97 |
+
return torch.Tensor(
|
98 |
+
(image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def tensor2vec(vector_tensor):
|
103 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
104 |
+
|
105 |
+
|
106 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
107 |
+
"""ap = voc_ap(rec, prec, [use_07_metric])
|
108 |
+
Compute VOC AP given precision and recall.
|
109 |
+
If use_07_metric is true, uses the
|
110 |
+
VOC 07 11 point method (default:False).
|
111 |
+
"""
|
112 |
+
if use_07_metric:
|
113 |
+
# 11 point metric
|
114 |
+
ap = 0.0
|
115 |
+
for t in np.arange(0.0, 1.1, 0.1):
|
116 |
+
if np.sum(rec >= t) == 0:
|
117 |
+
p = 0
|
118 |
+
else:
|
119 |
+
p = np.max(prec[rec >= t])
|
120 |
+
ap = ap + p / 11.0
|
121 |
+
else:
|
122 |
+
# correct AP calculation
|
123 |
+
# first append sentinel values at the end
|
124 |
+
mrec = np.concatenate(([0.0], rec, [1.0]))
|
125 |
+
mpre = np.concatenate(([0.0], prec, [0.0]))
|
126 |
+
|
127 |
+
# compute the precision envelope
|
128 |
+
for i in range(mpre.size - 1, 0, -1):
|
129 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
130 |
+
|
131 |
+
# to calculate area under PR curve, look for points
|
132 |
+
# where X axis (recall) changes value
|
133 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
134 |
+
|
135 |
+
# and sum (\Delta recall) * prec
|
136 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
137 |
+
return ap
|
kit/metrics/lpips/weights/v0.0/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
|
3 |
+
size 5455
|
kit/metrics/lpips/weights/v0.0/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
|
3 |
+
size 10057
|
kit/metrics/lpips/weights/v0.0/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
|
3 |
+
size 6735
|
kit/metrics/lpips/weights/v0.1/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
|
3 |
+
size 6009
|
kit/metrics/lpips/weights/v0.1/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
|
3 |
+
size 10811
|
kit/metrics/lpips/weights/v0.1/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|
kit/metrics/perceptual.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from torchvision import transforms
|
4 |
+
from .lpips import LPIPS
|
5 |
+
|
6 |
+
|
7 |
+
# Normalize image tensors
|
8 |
+
def normalize_tensor(images, norm_type):
|
9 |
+
assert norm_type in ["imagenet", "naive"]
|
10 |
+
# Two possible normalization conventions
|
11 |
+
if norm_type == "imagenet":
|
12 |
+
mean = [0.485, 0.456, 0.406]
|
13 |
+
std = [0.229, 0.224, 0.225]
|
14 |
+
normalize = transforms.Normalize(mean, std)
|
15 |
+
elif norm_type == "naive":
|
16 |
+
mean = [0.5, 0.5, 0.5]
|
17 |
+
std = [0.5, 0.5, 0.5]
|
18 |
+
normalize = transforms.Normalize(mean, std)
|
19 |
+
else:
|
20 |
+
assert False
|
21 |
+
return torch.stack([normalize(image) for image in images])
|
22 |
+
|
23 |
+
|
24 |
+
def to_tensor(images, norm_type="naive"):
|
25 |
+
assert isinstance(images, list) and all(
|
26 |
+
[isinstance(image, Image.Image) for image in images]
|
27 |
+
)
|
28 |
+
images = torch.stack([transforms.ToTensor()(image) for image in images])
|
29 |
+
if norm_type is not None:
|
30 |
+
images = normalize_tensor(images, norm_type)
|
31 |
+
return images
|
32 |
+
|
33 |
+
|
34 |
+
def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
|
35 |
+
assert metric_name in ["lpips"]
|
36 |
+
if metric_name == "lpips":
|
37 |
+
assert mode in ["vgg", "alex"]
|
38 |
+
perceptual_model = LPIPS(net=mode).to(device)
|
39 |
+
else:
|
40 |
+
assert False
|
41 |
+
return perceptual_model
|
42 |
+
|
43 |
+
|
44 |
+
# Compute metric between two images
|
45 |
+
def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")):
|
46 |
+
assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image)
|
47 |
+
image1_tensor = to_tensor([image1]).to(device)
|
48 |
+
image2_tensor = to_tensor([image2]).to(device)
|
49 |
+
return perceptual_model(image1_tensor, image2_tensor).cpu().item()
|
50 |
+
|
51 |
+
|
52 |
+
# Compute LPIPS distance between two images
|
53 |
+
def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
|
54 |
+
perceptual_model = load_perceptual_models("lpips", mode, device)
|
55 |
+
return compute_metric(image1, image2, perceptual_model, device)
|
56 |
+
|
57 |
+
|
58 |
+
# Compute metrics between pairs of images
|
59 |
+
def compute_perceptual_metric_repeated(
|
60 |
+
images1,
|
61 |
+
images2,
|
62 |
+
metric_name,
|
63 |
+
mode,
|
64 |
+
model,
|
65 |
+
device,
|
66 |
+
):
|
67 |
+
# Accept list of PIL images
|
68 |
+
assert isinstance(images1, list) and isinstance(images1[0], Image.Image)
|
69 |
+
assert isinstance(images2, list) and isinstance(images2[0], Image.Image)
|
70 |
+
assert len(images1) == len(images2)
|
71 |
+
if model is None:
|
72 |
+
model = load_perceptual_models(metric_name, mode).to(device)
|
73 |
+
return (
|
74 |
+
model(to_tensor(images1).to(device), to_tensor(images2).to(device))
|
75 |
+
.detach()
|
76 |
+
.cpu()
|
77 |
+
.numpy()
|
78 |
+
.flatten()
|
79 |
+
.tolist()
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
# Compute LPIPS distance between pairs of images
|
84 |
+
def compute_lpips_repeated(
|
85 |
+
images1,
|
86 |
+
images2,
|
87 |
+
mode="alex",
|
88 |
+
model=None,
|
89 |
+
device=torch.device("cuda"),
|
90 |
+
):
|
91 |
+
return compute_perceptual_metric_repeated(
|
92 |
+
images1, images2, "lpips", mode, model, device
|
93 |
+
)
|
kit/metrics/prompt.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
|
3 |
+
|
4 |
+
|
5 |
+
# Load GPT-2 large model and tokenizer
|
6 |
+
def load_perplexity_model_and_tokenizer():
|
7 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
8 |
+
ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device)
|
9 |
+
ppl_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
|
10 |
+
return ppl_model, ppl_tokenizer
|
11 |
+
|
12 |
+
|
13 |
+
# Compute perplexity for a single prompt
|
14 |
+
def compute_prompt_perplexity(prompt, models, stride=512):
|
15 |
+
assert isinstance(prompt, str)
|
16 |
+
assert isinstance(models, tuple) and len(models) == 2
|
17 |
+
ppl_model, ppl_tokenizer = models
|
18 |
+
encodings = ppl_tokenizer(prompt, return_tensors="pt")
|
19 |
+
max_length = ppl_model.config.n_positions
|
20 |
+
seq_len = encodings.input_ids.size(1)
|
21 |
+
nlls = []
|
22 |
+
prev_end_loc = 0
|
23 |
+
for begin_loc in range(0, seq_len, stride):
|
24 |
+
end_loc = min(begin_loc + max_length, seq_len)
|
25 |
+
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
|
26 |
+
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(
|
27 |
+
next(ppl_model.parameters()).device
|
28 |
+
)
|
29 |
+
target_ids = input_ids.clone()
|
30 |
+
target_ids[:, :-trg_len] = -100
|
31 |
+
with torch.no_grad():
|
32 |
+
outputs = ppl_model(input_ids, labels=target_ids)
|
33 |
+
neg_log_likelihood = outputs.loss
|
34 |
+
nlls.append(neg_log_likelihood)
|
35 |
+
prev_end_loc = end_loc
|
36 |
+
if end_loc == seq_len:
|
37 |
+
break
|
38 |
+
ppl = torch.exp(torch.stack(nlls).mean()).item()
|
39 |
+
return ppl
|