Update app.py
Browse files
app.py
CHANGED
@@ -257,23 +257,22 @@ def process_images(images, threshold):
|
|
257 |
|
258 |
with torch.no_grad():
|
259 |
for batch, filenames in dataloader:
|
260 |
-
|
261 |
batch = batch.to(device)
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
|
278 |
return all_results
|
279 |
|
|
|
257 |
|
258 |
with torch.no_grad():
|
259 |
for batch, filenames in dataloader:
|
|
|
260 |
batch = batch.to(device)
|
261 |
+
probabilities = model(batch)
|
262 |
+
for i, prob in enumerate(probabilities):
|
263 |
+
indices = torch.where(prob > threshold)[0]
|
264 |
+
values = prob[indices]
|
265 |
+
|
266 |
+
temp = []
|
267 |
+
tag_score = dict()
|
268 |
+
for j in range(indices.size(0)):
|
269 |
+
tag = allowed_tags[indices[j]]
|
270 |
+
score = values[j].item()
|
271 |
+
temp.append([tag, score])
|
272 |
+
tag_score[tag] = score
|
273 |
+
|
274 |
+
tags = ", ".join([t[0] for t in temp])
|
275 |
+
all_results.append((filenames[i], tags, tag_score))
|
276 |
|
277 |
return all_results
|
278 |
|