marcelomoreno26 commited on
Commit
17f6c62
1 Parent(s): 69c6042

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTImageProcessor, ViTForImageClassification
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ import torch
5
+ import random
6
+ import numpy as np
7
+ import pandas as pd
8
+ # tabulate in dependencies
9
+
10
+
11
+
12
+ def get_predictions(image):
13
+ inputs = processor(image, return_tensors="pt")
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ logits = outputs.logits
17
+
18
+ # Get top n predictions
19
+ top_indices = logits[0].argsort(dim=-1, descending=True)
20
+ probabilities = torch.softmax(logits, dim=-1)[0, top_indices]
21
+ labels = [model.config.id2label[idx.item()] for idx in top_indices]
22
+
23
+ predictions = {}
24
+ for i, label in enumerate(labels):
25
+ predictions[label] = probabilities[i]
26
+
27
+ return predictions
28
+
29
+ data = load_dataset("marcelomoreno26/geoguessr",split="test")
30
+ model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"
31
+
32
+ # Initialize the feature extractor
33
+ processor = ViTImageProcessor.from_pretrained(model_name)
34
+ # Initialize the model
35
+ model = ViTForImageClassification.from_pretrained(model_name)
36
+
37
+
38
+ length = len(data)
39
+ countries = []
40
+
41
+ with open("countries.txt", "r") as file:
42
+ for line in file:
43
+ countries.append(line.strip())
44
+
45
+
46
+ def get_result(selection):
47
+ global correct_country
48
+ global model_prediction
49
+ global filtered_predictions
50
+ if selection == correct_country and correct_country == model_prediction:
51
+ result = "It's a draw!"
52
+ elif selection == correct_country:
53
+ result = "Congratulations! You won!"
54
+ elif correct_country == model_prediction:
55
+ result = "Sorry, you lost. The AI guessed it right!"
56
+ else:
57
+ result = "Sorry, you both lost."
58
+
59
+ total_prob = sum([(float(value)) for value in filtered_predictions.values()])
60
+ prob_per_country = [(key,np.round(float(value)/total_prob,3)*100) for key,value in filtered_predictions.items()]
61
+ df = pd.DataFrame(prob_per_country,columns=["Country","Model Confidence (%)"]).sort_values(by="Model Confidence (%)",ascending=False)
62
+ ai_confidence = f"The AI's guess was {model_prediction}\n\nAI's Results:\n"+ df.to_markdown(index=False)
63
+
64
+
65
+
66
+ return f"The correct country was: {correct_country}\n{result}", ai_confidence
67
+
68
+
69
+ def load():
70
+ global filtered_predictions
71
+ # Randomly select an image
72
+ i = random.randint(0, len(data) - 1)
73
+ image = data[i]['image']
74
+ correct_country = data[i]['label']
75
+
76
+ # Randomly sample 4 countries as options
77
+ options = [country for country in random.sample(countries, 4) if country != correct_country]
78
+ options.append(correct_country)
79
+ random.shuffle(options)
80
+
81
+ # Get model predictions
82
+ predictions = get_predictions(image)
83
+ filtered_predictions = {country: predictions[country] for country in options}
84
+ model_prediction = max(filtered_predictions, key=filtered_predictions.get)
85
+
86
+ return image, options, correct_country, model_prediction
87
+
88
+
89
+ def reload():
90
+ global correct_country
91
+ global model_prediction
92
+ global filtered_predictions
93
+ # Randomly select an image
94
+ i = random.randint(0, len(data) - 1)
95
+ image = data[i]['image']
96
+ correct_country = data[i]['label']
97
+
98
+ # Randomly sample 4 countries as options
99
+ options = [country for country in random.sample(countries, 4) if country != correct_country]
100
+ options.append(correct_country)
101
+ random.shuffle(options)
102
+
103
+ # Get model predictions
104
+ predictions = get_predictions(image)
105
+ filtered_predictions = {country: predictions[country] for country in options}
106
+ model_prediction = max(filtered_predictions, key=filtered_predictions.get)
107
+
108
+
109
+ return gr.Image(image), gr.Radio(choices=options, label ="Select the country:"), "", ""
110
+
111
+
112
+
113
+ with gr.Blocks() as demo:
114
+
115
+ image, options, correct_country, model_prediction = load()
116
+
117
+ gr.Markdown("# GeoGuessr - Can You Beat the AI?")
118
+ gr.Markdown("Try to guess the country in the image. Can you beat the AI?")
119
+ img = gr.Image(image)
120
+ radio = gr.Radio(choices=options, label ="Select the country:")
121
+ ai_pred = gr.Markdown()
122
+ text = gr.Text(label="Result")
123
+ radio.select(fn=get_result, inputs=radio, outputs=[text,ai_pred])
124
+
125
+ btn = gr.Button(value="Get New Image")
126
+ btn.click(reload, None,outputs=[img,radio,text,ai_pred])
127
+
128
+ demo.launch()