environment setup version 1
- app.py +119 -0
- example_images/banana_73.jpg +0 -0
- example_images/turkish_tea.jpeg +0 -0
- gallery/apple_11.jpg +0 -0
- gallery/apple_14.jpg +0 -0
- gallery/apple_21.jpg +0 -0
- gallery/apple_4.jpg +0 -0
- gallery/apple_46.jpg +0 -0
- gallery/apple_5.jpg +0 -0
- gallery/apple_52.jpg +0 -0
- gallery/apple_55.jpg +0 -0
- gallery/apple_60.jpg +0 -0
- gallery/apple_71.jpg +0 -0
- gallery/banana_27.jpg +0 -0
- gallery/banana_40.jpg +0 -0
- gallery/banana_48.jpg +0 -0
- gallery/banana_50.jpg +0 -0
- gallery/banana_52.jpg +0 -0
- gallery/banana_59.jpg +0 -0
- gallery/banana_61.jpg +0 -0
- gallery/banana_70.jpg +0 -0
- gallery/mixed_1.jpg +0 -0
- gallery/mixed_17.jpg +0 -0
- gallery/mixed_18.jpg +0 -0
- gallery/mixed_19.jpg +0 -0
- gallery/mixed_20.jpg +0 -0
- gallery/mixed_3.jpg +0 -0
- gallery/mixed_4.jpg +0 -0
- gallery/orange_1.jpg +0 -0
- gallery/orange_10.jpg +0 -0
- gallery/orange_2.jpg +0 -0
- gallery/orange_3.jpg +0 -0
- gallery/orange_4.jpg +0 -0
- gallery/orange_5.jpg +0 -0
- gallery/orange_6.jpg +0 -0
- gallery/orange_7.jpg +0 -0
- gallery/orange_8.jpg +0 -0
- gallery/orange_9.jpg +0 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,119 @@
import gradio as gr
import torch
import clip
from PIL import Image
import os
import numpy as np
from matplotlib import pyplot as plt
from io import BytesIO

# Load the CLIP model and its preprocessing transform
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Debug log of the gallery paths discovered at startup
f = open("debug.txt", "w+")

# Examples for zero-shot classification
classification_examples = [
    ["example_images/turkish_tea.jpeg", "turkish tea, coffee, water"]
]

# Examples for image retrieval
gallery_image_paths_list = os.listdir("gallery")
gallery_image_paths_list = ["gallery/" + path for path in gallery_image_paths_list]
f.write(str(gallery_image_paths_list))
f.close()
retrieval_examples = [
    [gallery_image_paths_list, "example_images/banana_73.jpg"]
]

# Zero-shot classification: score one image against comma-separated class names
def zero_shot_classification(image, classnames):
    classnames = [cls.strip() for cls in classnames.split(",")]
    text_inputs = clip.tokenize(classnames).to(device)

    img_processed = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(img_processed, text_inputs)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    results = {classnames[i]: float(probs[i]) for i in range(len(classnames))}
    return results

# Image retrieval: rank every gallery image by cosine similarity to the query
def image_retrieval(gallery, query):
    with torch.no_grad():
        query_processed = preprocess(query).unsqueeze(0).to(device)
        query_embedding = model.encode_image(query_processed)
        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)

        rank_list = []
        for img in gallery:
            # Gallery items arrive as (PIL image, caption) tuples
            img_processed = preprocess(img[0]).unsqueeze(0).to(device)
            embedding = model.encode_image(img_processed)
            embedding /= embedding.norm(dim=-1, keepdim=True)

            similarity_score = (100.0 * query_embedding @ embedding.t()).item()
            rank_list.append([round(similarity_score, 3), img[0]])

    rank_list.sort(key=lambda x: x[0], reverse=True)

    # Plot the top-10 matches, best match first, with similarity scores as titles
    top_k = min(10, len(rank_list))
    fig = plt.figure(figsize=(5, 20))
    for i in range(top_k):
        gallery_ax = fig.add_subplot(top_k, 1, i + 1)
        gallery_ax.imshow(rank_list[i][1])
        gallery_ax.set_title("%.3f" % rank_list[i][0], fontsize=10)
        gallery_ax.axis("off")

    # Render the figure to a PIL image for the output component
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# Gradio interface for zero-shot classification
classification_interface = gr.Interface(
    fn=zero_shot_classification,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload"]),
        gr.Textbox(lines=3, placeholder="Enter labels separated by commas, e.g., dog, cat, car", label="Class Labels"),
    ],
    examples=classification_examples,
    outputs=gr.Label(label="Classification Probabilities"),
)

# Gradio interface for image retrieval
retrieval_interface = gr.Interface(
    fn=image_retrieval,
    inputs=[
        gr.Gallery(label="Gallery Folder", type="pil", columns=[3], rows=[10], object_fit="contain", height="auto"),
        gr.Image(type="pil", label="Query Image"),
    ],
    outputs=gr.Image(type="pil", label="Top-10 Retrieved Images"),
    examples=retrieval_examples,
)

# Combine the two interfaces into a single Gradio app
app = gr.Blocks()

with app:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Zero-shot Classification")
            classification_interface.render()
        with gr.Column():
            gr.Markdown("## Image Retrieval")
            retrieval_interface.render()

# Launch the app
app.launch()
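A side note on the retrieval loop above: every query re-encodes the entire gallery one image at a time. A minimal sketch of precomputing normalized gallery embeddings in a single batched forward pass is shown below. It assumes the same `model`, `preprocess`, and `device` globals as app.py; the helper names `encode_gallery` and `rank_gallery` are illustrative, not part of this commit.

import torch

def encode_gallery(pil_images):
    # Preprocess every gallery image and encode the whole stack in one forward pass
    batch = torch.stack([preprocess(img) for img in pil_images]).to(device)
    with torch.no_grad():
        emb = model.encode_image(batch)
        emb /= emb.norm(dim=-1, keepdim=True)  # L2-normalize so dot product = cosine similarity
    return emb  # shape (N, 512) for ViT-B/32

def rank_gallery(query_image, gallery_emb):
    # Encode the query once, then score all gallery images with one matrix product
    q = preprocess(query_image).unsqueeze(0).to(device)
    with torch.no_grad():
        q_emb = model.encode_image(q)
        q_emb /= q_emb.norm(dim=-1, keepdim=True)
        scores = (100.0 * q_emb @ gallery_emb.t()).squeeze(0)
    order = scores.argsort(descending=True)
    return order.tolist(), scores.tolist()

Cached this way, each query costs one CLIP forward pass instead of N+1.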
example_images/ and gallery/ image files ADDED (37 binary images; inline previews not reproduced here, see the file list above)
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==1.13.1
torchvision==0.14.1
torchaudio==0.13.1
gradio
clip-by-openai
pillow
matplotlib
ftfy
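On the `clip-by-openai` pin: `import clip` in app.py refers to OpenAI's CLIP, and `clip-by-openai` appears to be an unofficial PyPI repackaging of it (the official install route is `pip install git+https://github.com/openai/CLIP.git`). A quick import-time sanity check, purely illustrative:

import clip
# OpenAI's CLIP exposes available_models(); if this raises AttributeError,
# a different package named `clip` was installed instead.
print(clip.available_models())  # expect a list including "ViT-B/32"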