YagmurCA committed
Commit ca2bc6f · Parent: 3de6fa5

environment setup version 1

app.py ADDED
import gradio as gr
import torch
import clip
from PIL import Image
import os
import numpy as np
from matplotlib import pyplot as plt
from io import BytesIO

# Load the CLIP model and its preprocessing pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Examples for zero-shot classification
classification_examples = [
    ["example_images/turkish_tea.jpeg", "turkish tea, coffee, water"]
]

# Examples for image retrieval: every image in the gallery/ folder
gallery_image_paths_list = ["gallery/" + path for path in os.listdir("gallery")]
retrieval_examples = [
    [gallery_image_paths_list, "example_images/banana_73.jpg"]
]

# Debug output: record which gallery files were found
with open("debug.txt", "w") as f:
    f.write(str(gallery_image_paths_list))

# Zero-shot classification: score the image against comma-separated labels
def zero_shot_classification(image, classnames):
    classnames = [cls.strip() for cls in classnames.split(",")]
    text_inputs = clip.tokenize(classnames).to(device)

    img_processed = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(img_processed, text_inputs)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    return {classnames[i]: float(probs[i]) for i in range(len(classnames))}

# Image retrieval: rank the gallery images by cosine similarity to the query
def image_retrieval(gallery, query):
    with torch.no_grad():
        query_processed = preprocess(query).unsqueeze(0).to(device)
        query_embedding = model.encode_image(query_processed)
        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)

        rank_list = []
        for img in gallery:
            # Gradio gallery items arrive as (image, caption) pairs
            img_processed = preprocess(img[0]).unsqueeze(0).to(device)
            embedding = model.encode_image(img_processed)
            embedding /= embedding.norm(dim=-1, keepdim=True)

            similarity_score = (100.0 * query_embedding @ embedding.t()).item()
            rank_list.append([round(similarity_score, 3), img[0]])

    rank_list.sort(key=lambda x: x[0], reverse=True)

    # Plot the top-10 matches (fewer if the gallery is smaller),
    # each subplot titled with its similarity score
    top_k = min(10, len(rank_list))
    fig = plt.figure(figsize=(5, 20))
    for i in range(top_k):
        gallery_ax = fig.add_subplot(top_k, 1, i + 1)
        gallery_ax.imshow(rank_list[i][1])
        gallery_ax.set_title("%.3f" % rank_list[i][0], fontsize=10)
        gallery_ax.axis("off")

    # Render the figure to a PNG buffer and return it as a PIL image
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# Gradio interface for zero-shot classification
classification_interface = gr.Interface(
    fn=zero_shot_classification,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload"]),
        gr.Textbox(lines=3, placeholder="Enter labels separated by commas, e.g., dog, cat, car", label="Class Labels"),
    ],
    outputs=gr.Label(label="Classification Probabilities"),
    examples=classification_examples,
)

# Gradio interface for image retrieval
retrieval_interface = gr.Interface(
    fn=image_retrieval,
    inputs=[
        gr.Gallery(label="Gallery Folder", type="pil", columns=[3], rows=[10], object_fit="contain", height="auto"),
        gr.Image(type="pil", label="Query Image"),
    ],
    outputs=gr.Image(type="pil", label="Top-10 Retrieved Images"),
    examples=retrieval_examples,
)

# Combine the two interfaces into a single Blocks app
app = gr.Blocks()

with app:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Zero-shot Classification")
            classification_interface.render()
        with gr.Column():
            gr.Markdown("## Image Retrieval")
            retrieval_interface.render()

# Launch the app
app.launch()
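A quick offline sanity check, as a sketch (not part of the commit): run it at the bottom of app.py in place of app.launch(), or in an interactive session. It assumes the example files listed in this commit are present locally, and it hands image_retrieval a list of (image, caption) pairs, mirroring what the Gradio Gallery delivers with type="pil"; the output filename is arbitrary.

# Illustrative sanity check; assumes the repo's example images exist locally
query = Image.open("example_images/banana_73.jpg")
print(zero_shot_classification(query, "banana, apple, orange"))

gallery_items = [(Image.open(p), p) for p in gallery_image_paths_list]  # (image, caption) pairs
result_plot = image_retrieval(gallery_items, query)
result_plot.save("retrieval_result.png")  # arbitrary output path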
example_images/banana_73.jpg ADDED
example_images/turkish_tea.jpeg ADDED
gallery/apple_11.jpg ADDED
gallery/apple_14.jpg ADDED
gallery/apple_21.jpg ADDED
gallery/apple_4.jpg ADDED
gallery/apple_46.jpg ADDED
gallery/apple_5.jpg ADDED
gallery/apple_52.jpg ADDED
gallery/apple_55.jpg ADDED
gallery/apple_60.jpg ADDED
gallery/apple_71.jpg ADDED
gallery/banana_27.jpg ADDED
gallery/banana_40.jpg ADDED
gallery/banana_48.jpg ADDED
gallery/banana_50.jpg ADDED
gallery/banana_52.jpg ADDED
gallery/banana_59.jpg ADDED
gallery/banana_61.jpg ADDED
gallery/banana_70.jpg ADDED
gallery/mixed_1.jpg ADDED
gallery/mixed_17.jpg ADDED
gallery/mixed_18.jpg ADDED
gallery/mixed_19.jpg ADDED
gallery/mixed_20.jpg ADDED
gallery/mixed_3.jpg ADDED
gallery/mixed_4.jpg ADDED
gallery/orange_1.jpg ADDED
gallery/orange_10.jpg ADDED
gallery/orange_2.jpg ADDED
gallery/orange_3.jpg ADDED
gallery/orange_4.jpg ADDED
gallery/orange_5.jpg ADDED
gallery/orange_6.jpg ADDED
gallery/orange_7.jpg ADDED
gallery/orange_8.jpg ADDED
gallery/orange_9.jpg ADDED
requirements.txt ADDED
torch==1.13.1
torchvision==0.14.1
torchaudio==0.13.1
gradio
clip-by-openai
pillow
matplotlib
ftfy
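The torch/torchvision/torchaudio pins are a mutually compatible set; a minimal smoke test for the resolved environment might look like the sketch below (the version asserts simply restate the pins above, and clip.load downloads the ViT-B/32 weights on first run).

# Environment smoke test (a sketch): verify the pins and that CLIP loads
import torch, torchvision, clip
assert torch.__version__.startswith("1.13.1"), torch.__version__
assert torchvision.__version__.startswith("0.14.1"), torchvision.__version__
model, preprocess = clip.load("ViT-B/32", device="cpu")
print("CUDA available:", torch.cuda.is_available())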