shreyas2509 commited on
Commit
f828007
·
verified ·
1 Parent(s): 03618ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from pathlib import Path
4
+ import gradio as gr
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from torchvision import transforms
7
+ import reverse_geocoder as rg
8
+ import folium
9
+ from geopy.exc import GeocoderTimedOut
10
+ from geopy.geocoders import Nominatim
11
+
12
+ # streetclip_model = CLIPModel.from_pretrained("E:/github projects/Country Classification/GeolocationCountryClassification/")
13
+ model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
14
+ processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
15
+ labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
16
+
17
+ def create_map(lat, lon):
18
+ m = folium.Map(location=[lat, lon], zoom_start=4)
19
+ folium.Marker([lat, lon]).add_to(m)
20
+ map_html = m._repr_html_()
21
+ return map_html
22
+
23
+ geolocator = Nominatim(user_agent="predictGeolocforImage")
24
+
25
+ def get_country_coordinates(country_name):
26
+ try:
27
+ location = geolocator.geocode(country_name, timeout=10)
28
+ if location:
29
+ return location.latitude, location.longitude
30
+ except GeocoderTimedOut:
31
+ return None
32
+ return None
33
+
34
+
35
+ def classify_streetclip(image):
36
+ inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ logits_per_image = outputs.logits_per_image
40
+ prediction = logits_per_image.softmax(dim=1)
41
+ confidences = {labels[i]: float(prediction[0][i].item()) for i in range(len(labels))}
42
+
43
+ sorted_confidences = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
44
+ top_label, top_confidence = sorted_confidences[0]
45
+ coords = get_country_coordinates(top_label)
46
+ map_html = create_map(*coords) if coords else "Map not available"
47
+ return f"Country: {top_label}", map_html, confidences
48
+
49
+ text = '''
50
+ <b style="color: #F36912;">List of countries supported</b>: Albania, Andorra, Argentina, Australia, Austria, Bangladesh, Belgium, Bermuda, Bhutan, Bolivia, Botswana, Brazil, Bulgaria, Cambodia, Canada, Chile, China, Colombia, Croatia, Czech Republic, Denmark, Dominican Republic, Ecuador, Estonia, Finland, France, Germany, Ghana, Greece, Greenland, Guam, Guatemala, Hungary, Iceland, India, Indonesia, Ireland, Israel, Italy, Japan, Jordan, Kenya, Kyrgyzstan, Laos, Latvia, Lesotho, Lithuania, Luxembourg, Macedonia, Madagascar, Malaysia, Malta, Mexico, Monaco, Mongolia, Montenegro, Netherlands, New Zealand, Nigeria, Norway, Pakistan, Palestine, Peru, Philippines, Poland, Portugal, Puerto Rico, Romania, Russia, Rwanda, Senegal, Serbia, Singapore, Slovakia, Slovenia, South Africa, South Korea, Spain, Sri Lanka, Swaziland, Sweden, Switzerland, Taiwan, Thailand, Tunisia, Turkey, Uganda, Ukraine, United Arab Emirates, United Kingdom, United States, Uruguay
51
+ </p>
52
+ ---<br>
53
+ <span style="color: #F24F13;">You may choose to use the images provided below, or feel free to upload your own images.</span>
54
+ '''
55
+
56
+ interface = gr.Interface(
57
+ fn=classify_streetclip,
58
+ inputs=gr.Image(type="pil", label="Upload Image", elem_id="image_input"),
59
+ outputs=[gr.Textbox(label="Prediction", elem_id="output"), gr.HTML(label="Map", elem_id="map_output"), gr.Label(num_top_classes=10,label="Top 10 countries")],
60
+ title="COUNTRY GUESSER",
61
+ description=text,
62
+ article="<span style='color: #F24F13;'>Model is not running on a GPU, so the interpretation takes some time. Thank you for your patience🙏🏻</span>",
63
+ examples=["taj.jpg","stockholm.jpeg","palace-square-saint-petersburg.jpg","monument.jpg"],
64
+ allow_flagging="never",
65
+ )
66
+
67
+ interface.launch()