Commit
·
0bd8f65
1
Parent(s):
8444c60
Add image support
Browse files- .gitattributes +5 -0
- app.py +98 -22
- examples/46657164_p1.jpg +0 -0
- examples/60378883_p0.jpg +0 -0
- examples/DaRlExxUwAAcUOS-orig.jpg +0 -0
- requirements.txt +2 -0
.gitattributes
CHANGED
@@ -34,3 +34,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.index filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.index filter=lfs diff=lfs merge=lfs -text
|
37 |
+
|
38 |
+
# Byte-compiled / optimized / DLL files
|
39 |
+
__pycache__/
|
40 |
+
*.py[cod]
|
41 |
+
*$py.class
|
app.py
CHANGED
@@ -7,10 +7,20 @@ import jax
|
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
9 |
import requests
|
|
|
10 |
|
11 |
from Models.CLIP import CLIP
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
|
15 |
headers = {"User-Agent": "image_similarity_tool"}
|
16 |
ratings_to_letters = {
|
@@ -56,6 +66,8 @@ class Predictor:
|
|
56 |
|
57 |
def predict(
|
58 |
self,
|
|
|
|
|
59 |
positive_tags,
|
60 |
negative_tags,
|
61 |
selected_ratings,
|
@@ -68,38 +80,68 @@ class Predictor:
|
|
68 |
|
69 |
num_classes = len(tags_df)
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
positive_tags = positive_tags.split(",")
|
72 |
negative_tags = negative_tags.split(",")
|
73 |
|
74 |
positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
|
75 |
negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
if len(negative_tags_idxs) > 0:
|
88 |
tags = np.zeros((1, num_classes), dtype=np.float32)
|
89 |
tags[0][negative_tags_idxs] = 1
|
90 |
|
91 |
-
|
92 |
{"params": self.params},
|
93 |
tags,
|
94 |
method=model.encode_text,
|
95 |
)
|
96 |
-
|
97 |
-
faiss.normalize_L2(
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
dists, indexes = self.knn_index.search(
|
103 |
neighbours_ids = self.images_ids[indexes][0]
|
104 |
neighbours_ids = [int(x) for x in neighbours_ids]
|
105 |
|
@@ -122,10 +164,19 @@ def main():
|
|
122 |
predictor = Predictor()
|
123 |
|
124 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
125 |
with gr.Row():
|
126 |
with gr.Column():
|
127 |
positive_tags = gr.Textbox(label="Positive tags")
|
128 |
negative_tags = gr.Textbox(label="Negative tags")
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
n_neighbours = gr.Slider(
|
130 |
minimum=1,
|
131 |
maximum=20,
|
@@ -133,15 +184,10 @@ def main():
|
|
133 |
step=1,
|
134 |
label="# of images",
|
135 |
)
|
136 |
-
|
137 |
with gr.Column():
|
138 |
api_username = gr.Textbox(label="Danbooru API Username")
|
139 |
api_key = gr.Textbox(label="Danbooru API Key")
|
140 |
-
|
141 |
-
choices=["General", "Sensitive", "Questionable", "Explicit"],
|
142 |
-
value=["General", "Sensitive"],
|
143 |
-
label="Ratings",
|
144 |
-
)
|
145 |
find_btn = gr.Button("Find similar images")
|
146 |
|
147 |
similar_images = gr.Gallery(label="Similar images", columns=[5])
|
@@ -149,6 +195,8 @@ def main():
|
|
149 |
examples = gr.Examples(
|
150 |
[
|
151 |
[
|
|
|
|
|
152 |
"marcille_donato",
|
153 |
"",
|
154 |
["General", "Sensitive"],
|
@@ -157,6 +205,8 @@ def main():
|
|
157 |
"",
|
158 |
],
|
159 |
[
|
|
|
|
|
160 |
"yellow_eyes,red_horns",
|
161 |
"",
|
162 |
["General", "Sensitive"],
|
@@ -165,6 +215,8 @@ def main():
|
|
165 |
"",
|
166 |
],
|
167 |
[
|
|
|
|
|
168 |
"artoria_pendragon_(fate),solo",
|
169 |
"excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
|
170 |
["General", "Sensitive"],
|
@@ -172,8 +224,30 @@ def main():
|
|
172 |
"",
|
173 |
"",
|
174 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
],
|
176 |
inputs=[
|
|
|
|
|
177 |
positive_tags,
|
178 |
negative_tags,
|
179 |
selected_ratings,
|
@@ -190,6 +264,8 @@ def main():
|
|
190 |
find_btn.click(
|
191 |
fn=predictor.predict,
|
192 |
inputs=[
|
|
|
|
|
193 |
positive_tags,
|
194 |
negative_tags,
|
195 |
selected_ratings,
|
|
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
9 |
import requests
|
10 |
+
from imgutils.tagging import wd14
|
11 |
|
12 |
from Models.CLIP import CLIP
|
13 |
|
14 |
|
15 |
+
def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
|
16 |
+
pos = pos_img_embs + pos_tags_embs
|
17 |
+
|
18 |
+
neg = neg_img_embs + neg_tags_embs
|
19 |
+
|
20 |
+
result = pos - neg
|
21 |
+
return result
|
22 |
+
|
23 |
+
|
24 |
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
|
25 |
headers = {"User-Agent": "image_similarity_tool"}
|
26 |
ratings_to_letters = {
|
|
|
66 |
|
67 |
def predict(
|
68 |
self,
|
69 |
+
pos_img_input,
|
70 |
+
neg_img_input,
|
71 |
positive_tags,
|
72 |
negative_tags,
|
73 |
selected_ratings,
|
|
|
80 |
|
81 |
num_classes = len(tags_df)
|
82 |
|
83 |
+
output_shape = model.out_units
|
84 |
+
pos_img_embs = np.zeros((1, output_shape), dtype=np.float32)
|
85 |
+
neg_img_embs = np.zeros((1, output_shape), dtype=np.float32)
|
86 |
+
pos_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
|
87 |
+
neg_tags_embs = np.zeros((1, output_shape), dtype=np.float32)
|
88 |
+
|
89 |
positive_tags = positive_tags.split(",")
|
90 |
negative_tags = negative_tags.split(",")
|
91 |
|
92 |
positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
|
93 |
negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
|
94 |
|
95 |
+
if pos_img_input is not None:
|
96 |
+
pos_img_embs = wd14.get_wd14_tags(
|
97 |
+
pos_img_input,
|
98 |
+
model_name="ConvNext",
|
99 |
+
fmt=("embedding"),
|
100 |
+
)
|
101 |
+
pos_img_embs = np.expand_dims(pos_img_embs, 0)
|
102 |
+
faiss.normalize_L2(pos_img_embs)
|
103 |
+
|
104 |
+
if neg_img_input is not None:
|
105 |
+
neg_img_embs = wd14.get_wd14_tags(
|
106 |
+
neg_img_input,
|
107 |
+
model_name="ConvNext",
|
108 |
+
fmt=("embedding"),
|
109 |
+
)
|
110 |
+
neg_img_embs = np.expand_dims(neg_img_embs, 0)
|
111 |
+
faiss.normalize_L2(neg_img_embs)
|
112 |
+
|
113 |
+
if len(positive_tags_idxs) > 0:
|
114 |
+
tags = np.zeros((1, num_classes), dtype=np.float32)
|
115 |
+
tags[0][positive_tags_idxs] = 1
|
116 |
+
|
117 |
+
pos_tags_embs = model.apply(
|
118 |
+
{"params": self.params},
|
119 |
+
tags,
|
120 |
+
method=model.encode_text,
|
121 |
+
)
|
122 |
+
pos_tags_embs = jax.device_get(pos_tags_embs)
|
123 |
+
faiss.normalize_L2(pos_tags_embs)
|
124 |
|
125 |
if len(negative_tags_idxs) > 0:
|
126 |
tags = np.zeros((1, num_classes), dtype=np.float32)
|
127 |
tags[0][negative_tags_idxs] = 1
|
128 |
|
129 |
+
neg_tags_embs = model.apply(
|
130 |
{"params": self.params},
|
131 |
tags,
|
132 |
method=model.encode_text,
|
133 |
)
|
134 |
+
neg_tags_embs = jax.device_get(neg_tags_embs)
|
135 |
+
faiss.normalize_L2(neg_tags_embs)
|
136 |
+
|
137 |
+
embeddings = combine_embeddings(
|
138 |
+
pos_img_embs,
|
139 |
+
pos_tags_embs,
|
140 |
+
neg_img_embs,
|
141 |
+
neg_tags_embs,
|
142 |
+
)
|
143 |
|
144 |
+
dists, indexes = self.knn_index.search(embeddings, k=n_neighbours)
|
145 |
neighbours_ids = self.images_ids[indexes][0]
|
146 |
neighbours_ids = [int(x) for x in neighbours_ids]
|
147 |
|
|
|
164 |
predictor = Predictor()
|
165 |
|
166 |
with gr.Blocks() as demo:
|
167 |
+
with gr.Row():
|
168 |
+
pos_img_input = gr.Image(type="pil", label="Positive input")
|
169 |
+
neg_img_input = gr.Image(type="pil", label="Negative input")
|
170 |
with gr.Row():
|
171 |
with gr.Column():
|
172 |
positive_tags = gr.Textbox(label="Positive tags")
|
173 |
negative_tags = gr.Textbox(label="Negative tags")
|
174 |
+
with gr.Column():
|
175 |
+
selected_ratings = gr.CheckboxGroup(
|
176 |
+
choices=["General", "Sensitive", "Questionable", "Explicit"],
|
177 |
+
value=["General", "Sensitive"],
|
178 |
+
label="Ratings",
|
179 |
+
)
|
180 |
n_neighbours = gr.Slider(
|
181 |
minimum=1,
|
182 |
maximum=20,
|
|
|
184 |
step=1,
|
185 |
label="# of images",
|
186 |
)
|
|
|
187 |
with gr.Column():
|
188 |
api_username = gr.Textbox(label="Danbooru API Username")
|
189 |
api_key = gr.Textbox(label="Danbooru API Key")
|
190 |
+
|
|
|
|
|
|
|
|
|
191 |
find_btn = gr.Button("Find similar images")
|
192 |
|
193 |
similar_images = gr.Gallery(label="Similar images", columns=[5])
|
|
|
195 |
examples = gr.Examples(
|
196 |
[
|
197 |
[
|
198 |
+
None,
|
199 |
+
None,
|
200 |
"marcille_donato",
|
201 |
"",
|
202 |
["General", "Sensitive"],
|
|
|
205 |
"",
|
206 |
],
|
207 |
[
|
208 |
+
None,
|
209 |
+
None,
|
210 |
"yellow_eyes,red_horns",
|
211 |
"",
|
212 |
["General", "Sensitive"],
|
|
|
215 |
"",
|
216 |
],
|
217 |
[
|
218 |
+
None,
|
219 |
+
None,
|
220 |
"artoria_pendragon_(fate),solo",
|
221 |
"excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
|
222 |
["General", "Sensitive"],
|
|
|
224 |
"",
|
225 |
"",
|
226 |
],
|
227 |
+
[
|
228 |
+
"examples/60378883_p0.jpg",
|
229 |
+
None,
|
230 |
+
"fujimaru_ritsuka_(female)",
|
231 |
+
"solo",
|
232 |
+
["General", "Sensitive"],
|
233 |
+
5,
|
234 |
+
"",
|
235 |
+
"",
|
236 |
+
],
|
237 |
+
[
|
238 |
+
"examples/DaRlExxUwAAcUOS-orig.jpg",
|
239 |
+
"examples/46657164_p1.jpg",
|
240 |
+
"",
|
241 |
+
"",
|
242 |
+
["General", "Sensitive"],
|
243 |
+
5,
|
244 |
+
"",
|
245 |
+
"",
|
246 |
+
],
|
247 |
],
|
248 |
inputs=[
|
249 |
+
pos_img_input,
|
250 |
+
neg_img_input,
|
251 |
positive_tags,
|
252 |
negative_tags,
|
253 |
selected_ratings,
|
|
|
264 |
find_btn.click(
|
265 |
fn=predictor.predict,
|
266 |
inputs=[
|
267 |
+
pos_img_input,
|
268 |
+
neg_img_input,
|
269 |
positive_tags,
|
270 |
negative_tags,
|
271 |
selected_ratings,
|
examples/46657164_p1.jpg
ADDED
![]() |
examples/60378883_p0.jpg
ADDED
![]() |
examples/DaRlExxUwAAcUOS-orig.jpg
ADDED
![]() |
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
faiss-cpu
|
2 |
jax[cpu]
|
3 |
flax
|
|
|
|
|
|
1 |
faiss-cpu
|
2 |
jax[cpu]
|
3 |
flax
|
4 |
+
imgutils
|
5 |
+
onnxruntime
|