PuristanLabs1 commited on
Commit
1510baf
·
verified ·
1 Parent(s): e8b1380

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import clip
6
+ from PIL import Image
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+
9
+ # Load CLIP model
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model, preprocess = clip.load("ViT-B/32", device=device)
12
+
13
+ # Configuration
14
+ FLAG_IMAGE_DIR = "/content/drive/MyDrive/Flags/flags"
15
+
16
+ # Function to search flags with specific queries using CLIP
17
+ def search_by_query(query, top_n=10):
18
+ """Search flags based on a text query using CLIP."""
19
+ # Encode the text query
20
+ with torch.no_grad():
21
+ text_embedding = model.encode_text(clip.tokenize([query]).to(device))
22
+
23
+ # Compare the query embedding with all flag embeddings
24
+ similarities = {}
25
+ for flag, embedding in flag_embeddings.items():
26
+ similarity = cosine_similarity(text_embedding.cpu().numpy(), embedding)[0][0]
27
+ similarities[flag] = similarity
28
+
29
+ # Sort and return the top_n results
30
+ sorted_flags = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
31
+ results = []
32
+ for flag_file, similarity in sorted_flags[:top_n]:
33
+ flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
34
+ if os.path.exists(flag_path):
35
+ results.append((flag_path, f"{get_country_name(flag_file)} (Similarity: {similarity:.3f})"))
36
+ else:
37
+ print(f"File not found: {flag_file}")
38
+ return results
39
+
40
+ # Get all image paths
41
+ image_paths = [
42
+ os.path.join(FLAG_IMAGE_DIR, img)
43
+ for img in os.listdir(FLAG_IMAGE_DIR)
44
+ if img.endswith((".png", ".jpg", ".jpeg"))
45
+ ]
46
+
47
+ # Load precomputed embeddings
48
+ FLAG_EMBEDDINGS_PATH = "/content/drive/MyDrive/flag_embeddings_1.npy"
49
+ flag_embeddings = np.load(FLAG_EMBEDDINGS_PATH, allow_pickle=True).item()
50
+
51
+ def get_country_name(image_filename):
52
+ """Extract country name from image filename."""
53
+ return os.path.splitext(os.path.basename(image_filename))[0].upper()
54
+
55
+ def get_image_embedding(image_path):
56
+ """Get embedding for an input image."""
57
+ image = Image.open(image_path).convert("RGB")
58
+ image_input = preprocess(image).unsqueeze(0).to(device)
59
+ with torch.no_grad():
60
+ embedding = model.encode_image(image_input)
61
+ return embedding.cpu().numpy()
62
+
63
+ def find_similar_flags(image_path, top_n=10):
64
+ """Find similar flags based on cosine similarity."""
65
+ query_embedding = get_image_embedding(image_path)
66
+
67
+ similarities = {}
68
+ for flag, embedding in flag_embeddings.items():
69
+ similarity = cosine_similarity(query_embedding, embedding)[0][0]
70
+ similarities[flag] = similarity
71
+
72
+ sorted_flags = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
73
+ return sorted_flags[1:top_n + 1] # Skip the first one as it's the same flag
74
+
75
+ def search_flags(query):
76
+ """Search flags based on country name."""
77
+ if not query:
78
+ return image_paths
79
+ return [img for img in image_paths if query.lower() in get_country_name(img).lower()]
80
+
81
+ def analyze_and_display(selected_flag):
82
+ """Main function to analyze flag similarity and prepare display."""
83
+ try:
84
+ if selected_flag is None:
85
+ return None
86
+
87
+ similar_flags = find_similar_flags(selected_flag)
88
+ output_images = []
89
+
90
+ for flag_file, similarity in similar_flags:
91
+ flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
92
+ country_name = get_country_name(flag_file)
93
+ output_images.append((flag_path, f"{country_name} (Similarity: {similarity:.3f})"))
94
+
95
+ return output_images
96
+ except Exception as e:
97
+ return gr.Error(f"Error processing image: {str(e)}")
98
+
99
+ # Create Gradio interface
100
+ with gr.Blocks() as demo:
101
+ gr.Markdown("# Flag Similarity Analysis")
102
+ gr.Markdown("Select a flag from the gallery to find similar flags based on visual features or search using text queries.")
103
+
104
+ with gr.Tabs():
105
+ with gr.Tab("Similarity Search"):
106
+ with gr.Row():
107
+ with gr.Column(scale=1):
108
+ # Search and input gallery
109
+ search_box = gr.Textbox(label="Search Flags", placeholder="Enter country name...")
110
+ #query_box = gr.Textbox(label="Search by Query", placeholder="e.g., 'crescent in the center'")
111
+ input_gallery = gr.Gallery(
112
+ label="Available Flags",
113
+ show_label=True,
114
+ elem_id="gallery",
115
+ columns=4,
116
+ height="auto"
117
+ )
118
+
119
+ with gr.Column(scale=1):
120
+ # Output gallery
121
+ output_gallery = gr.Gallery(
122
+ label="Similar Flags",
123
+ show_label=True,
124
+ elem_id="output",
125
+ columns=2,
126
+ height="auto"
127
+ )
128
+
129
+ # Event handlers
130
+ def update_gallery(query):
131
+ matching_flags = search_flags(query)
132
+ return [(path, get_country_name(path)) for path in matching_flags]
133
+
134
+ def on_select(evt: gr.SelectData, gallery):
135
+ """Handle flag selection from gallery"""
136
+ selected_flag_path = gallery[evt.index][0]
137
+ return analyze_and_display(selected_flag_path)
138
+
139
+ # Connect event handlers
140
+ search_box.change(
141
+ update_gallery,
142
+ inputs=[search_box],
143
+ outputs=[input_gallery]
144
+ )
145
+
146
+
147
+ input_gallery.select(
148
+ on_select,
149
+ inputs=[input_gallery],
150
+ outputs=[output_gallery]
151
+ )
152
+
153
+ with gr.Tab("Advanced Search"):
154
+ gr.Markdown("### Search Flags with Nuanced Queries")
155
+ nuanced_query_box = gr.Textbox(label="Enter Advanced Query", placeholder="e.g., 'Find flags with crescent' or 'flags with animals'")
156
+ advanced_output_gallery = gr.Gallery(
157
+ label="Matching Flags",
158
+ show_label=True,
159
+ elem_id="advanced_output",
160
+ columns=3,
161
+ height="auto"
162
+ )
163
+
164
+ def advanced_search(query):
165
+ return search_by_query(query)
166
+
167
+ nuanced_query_box.change(
168
+ advanced_search,
169
+ inputs=[nuanced_query_box],
170
+ outputs=[advanced_output_gallery]
171
+ )
172
+
173
+ # Initialize gallery with all flags
174
+ def init_gallery():
175
+ return [(path, get_country_name(path)) for path in image_paths]
176
+
177
+ demo.load(init_gallery, outputs=[input_gallery])
178
+
179
+ # Launch the app
180
+ demo.launch()