Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import clip
|
6 |
+
from PIL import Image
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
+
|
9 |
+
# Load CLIP model
|
10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
FLAG_IMAGE_DIR = "/content/drive/MyDrive/Flags/flags"
|
15 |
+
|
16 |
+
# Function to search flags with specific queries using CLIP
|
17 |
+
def search_by_query(query, top_n=10):
|
18 |
+
"""Search flags based on a text query using CLIP."""
|
19 |
+
# Encode the text query
|
20 |
+
with torch.no_grad():
|
21 |
+
text_embedding = model.encode_text(clip.tokenize([query]).to(device))
|
22 |
+
|
23 |
+
# Compare the query embedding with all flag embeddings
|
24 |
+
similarities = {}
|
25 |
+
for flag, embedding in flag_embeddings.items():
|
26 |
+
similarity = cosine_similarity(text_embedding.cpu().numpy(), embedding)[0][0]
|
27 |
+
similarities[flag] = similarity
|
28 |
+
|
29 |
+
# Sort and return the top_n results
|
30 |
+
sorted_flags = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
|
31 |
+
results = []
|
32 |
+
for flag_file, similarity in sorted_flags[:top_n]:
|
33 |
+
flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
|
34 |
+
if os.path.exists(flag_path):
|
35 |
+
results.append((flag_path, f"{get_country_name(flag_file)} (Similarity: {similarity:.3f})"))
|
36 |
+
else:
|
37 |
+
print(f"File not found: {flag_file}")
|
38 |
+
return results
|
39 |
+
|
40 |
+
# Get all image paths
|
41 |
+
image_paths = [
|
42 |
+
os.path.join(FLAG_IMAGE_DIR, img)
|
43 |
+
for img in os.listdir(FLAG_IMAGE_DIR)
|
44 |
+
if img.endswith((".png", ".jpg", ".jpeg"))
|
45 |
+
]
|
46 |
+
|
47 |
+
# Load precomputed embeddings
|
48 |
+
FLAG_EMBEDDINGS_PATH = "/content/drive/MyDrive/flag_embeddings_1.npy"
|
49 |
+
flag_embeddings = np.load(FLAG_EMBEDDINGS_PATH, allow_pickle=True).item()
|
50 |
+
|
51 |
+
def get_country_name(image_filename):
|
52 |
+
"""Extract country name from image filename."""
|
53 |
+
return os.path.splitext(os.path.basename(image_filename))[0].upper()
|
54 |
+
|
55 |
+
def get_image_embedding(image_path):
|
56 |
+
"""Get embedding for an input image."""
|
57 |
+
image = Image.open(image_path).convert("RGB")
|
58 |
+
image_input = preprocess(image).unsqueeze(0).to(device)
|
59 |
+
with torch.no_grad():
|
60 |
+
embedding = model.encode_image(image_input)
|
61 |
+
return embedding.cpu().numpy()
|
62 |
+
|
63 |
+
def find_similar_flags(image_path, top_n=10):
|
64 |
+
"""Find similar flags based on cosine similarity."""
|
65 |
+
query_embedding = get_image_embedding(image_path)
|
66 |
+
|
67 |
+
similarities = {}
|
68 |
+
for flag, embedding in flag_embeddings.items():
|
69 |
+
similarity = cosine_similarity(query_embedding, embedding)[0][0]
|
70 |
+
similarities[flag] = similarity
|
71 |
+
|
72 |
+
sorted_flags = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
|
73 |
+
return sorted_flags[1:top_n + 1] # Skip the first one as it's the same flag
|
74 |
+
|
75 |
+
def search_flags(query):
|
76 |
+
"""Search flags based on country name."""
|
77 |
+
if not query:
|
78 |
+
return image_paths
|
79 |
+
return [img for img in image_paths if query.lower() in get_country_name(img).lower()]
|
80 |
+
|
81 |
+
def analyze_and_display(selected_flag):
|
82 |
+
"""Main function to analyze flag similarity and prepare display."""
|
83 |
+
try:
|
84 |
+
if selected_flag is None:
|
85 |
+
return None
|
86 |
+
|
87 |
+
similar_flags = find_similar_flags(selected_flag)
|
88 |
+
output_images = []
|
89 |
+
|
90 |
+
for flag_file, similarity in similar_flags:
|
91 |
+
flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
|
92 |
+
country_name = get_country_name(flag_file)
|
93 |
+
output_images.append((flag_path, f"{country_name} (Similarity: {similarity:.3f})"))
|
94 |
+
|
95 |
+
return output_images
|
96 |
+
except Exception as e:
|
97 |
+
return gr.Error(f"Error processing image: {str(e)}")
|
98 |
+
|
99 |
+
# Create Gradio interface
|
100 |
+
with gr.Blocks() as demo:
|
101 |
+
gr.Markdown("# Flag Similarity Analysis")
|
102 |
+
gr.Markdown("Select a flag from the gallery to find similar flags based on visual features or search using text queries.")
|
103 |
+
|
104 |
+
with gr.Tabs():
|
105 |
+
with gr.Tab("Similarity Search"):
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column(scale=1):
|
108 |
+
# Search and input gallery
|
109 |
+
search_box = gr.Textbox(label="Search Flags", placeholder="Enter country name...")
|
110 |
+
#query_box = gr.Textbox(label="Search by Query", placeholder="e.g., 'crescent in the center'")
|
111 |
+
input_gallery = gr.Gallery(
|
112 |
+
label="Available Flags",
|
113 |
+
show_label=True,
|
114 |
+
elem_id="gallery",
|
115 |
+
columns=4,
|
116 |
+
height="auto"
|
117 |
+
)
|
118 |
+
|
119 |
+
with gr.Column(scale=1):
|
120 |
+
# Output gallery
|
121 |
+
output_gallery = gr.Gallery(
|
122 |
+
label="Similar Flags",
|
123 |
+
show_label=True,
|
124 |
+
elem_id="output",
|
125 |
+
columns=2,
|
126 |
+
height="auto"
|
127 |
+
)
|
128 |
+
|
129 |
+
# Event handlers
|
130 |
+
def update_gallery(query):
|
131 |
+
matching_flags = search_flags(query)
|
132 |
+
return [(path, get_country_name(path)) for path in matching_flags]
|
133 |
+
|
134 |
+
def on_select(evt: gr.SelectData, gallery):
|
135 |
+
"""Handle flag selection from gallery"""
|
136 |
+
selected_flag_path = gallery[evt.index][0]
|
137 |
+
return analyze_and_display(selected_flag_path)
|
138 |
+
|
139 |
+
# Connect event handlers
|
140 |
+
search_box.change(
|
141 |
+
update_gallery,
|
142 |
+
inputs=[search_box],
|
143 |
+
outputs=[input_gallery]
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
input_gallery.select(
|
148 |
+
on_select,
|
149 |
+
inputs=[input_gallery],
|
150 |
+
outputs=[output_gallery]
|
151 |
+
)
|
152 |
+
|
153 |
+
with gr.Tab("Advanced Search"):
|
154 |
+
gr.Markdown("### Search Flags with Nuanced Queries")
|
155 |
+
nuanced_query_box = gr.Textbox(label="Enter Advanced Query", placeholder="e.g., 'Find flags with crescent' or 'flags with animals'")
|
156 |
+
advanced_output_gallery = gr.Gallery(
|
157 |
+
label="Matching Flags",
|
158 |
+
show_label=True,
|
159 |
+
elem_id="advanced_output",
|
160 |
+
columns=3,
|
161 |
+
height="auto"
|
162 |
+
)
|
163 |
+
|
164 |
+
def advanced_search(query):
|
165 |
+
return search_by_query(query)
|
166 |
+
|
167 |
+
nuanced_query_box.change(
|
168 |
+
advanced_search,
|
169 |
+
inputs=[nuanced_query_box],
|
170 |
+
outputs=[advanced_output_gallery]
|
171 |
+
)
|
172 |
+
|
173 |
+
# Initialize gallery with all flags
|
174 |
+
def init_gallery():
|
175 |
+
return [(path, get_country_name(path)) for path in image_paths]
|
176 |
+
|
177 |
+
demo.load(init_gallery, outputs=[input_gallery])
|
178 |
+
|
179 |
+
# Launch the app
|
180 |
+
demo.launch()
|