|
from transformers import ViTImageProcessor, ViTForImageClassification |
|
import gradio as gr |
|
from datasets import load_dataset |
|
import torch |
|
import random |
|
import numpy as np |
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
def get_predictions(image):
    """Classify ``image`` and return every label's softmax probability.

    Args:
        image: an image accepted by the module-level ``processor``.

    Returns:
        dict mapping label name (from ``model.config.id2label``) to its
        probability as a 0-d tensor, inserted from most to least probable.
    """
    encoded = processor(image, return_tensors="pt")
    with torch.no_grad():
        scores = model(**encoded).logits

    # Rank the labels of the first row of the batch (the single image passed in).
    order = scores[0].argsort(dim=-1, descending=True)
    ranked_probs = torch.softmax(scores, dim=-1)[0, order]

    return {
        model.config.id2label[index.item()]: prob
        for index, prob in zip(order, ranked_probs)
    }
|
|
|
# Test split of the GeoGuessr dataset; rounds draw random rows from it.
data = load_dataset("marcelomoreno26/geoguessr", split="test")

model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"

# Image preprocessing and the fine-tuned ViT classifier come from the same checkpoint.
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Number of rows in the test split.
length = len(data)

# Candidate country names, one per line. Blank lines (e.g. a trailing newline
# at end of file) are skipped so "" can never be offered as an option; it is
# presumably not a model label, so looking it up in the predictions would fail.
with open("countries.txt", "r", encoding="utf-8") as file:
    countries = [name for name in (line.strip() for line in file) if name]
|
|
|
|
|
def get_result(selection):
    """Score the player's guess against the correct answer and the AI's guess.

    The current round is read from the module-level globals
    ``correct_country``, ``model_prediction`` and ``filtered_predictions``
    (set by ``load`` / ``reload``); they are only read here, so no ``global``
    statement is needed.

    Args:
        selection: the country the player picked.

    Returns:
        tuple ``(result_text, ai_markdown)``: the outcome message, and a
        markdown report of the AI's guess with its confidence per option,
        renormalised over the offered options.
    """
    if selection == correct_country and correct_country == model_prediction:
        result = "It's a draw!"
    elif selection == correct_country:
        result = "Congratulations! You won!"
    elif correct_country == model_prediction:
        result = "Sorry, you lost. The AI guessed it right!"
    else:
        result = "Sorry, you both lost."

    scores = {country: float(value) for country, value in filtered_predictions.items()}
    total_prob = sum(scores.values())
    # Round the percentage itself (not the fraction before scaling by 100),
    # so the table shows e.g. 33.3 rather than artifacts like 33.300000000000004.
    # Guard against every option's probability underflowing to zero.
    prob_per_country = [
        (country, round(score / total_prob * 100, 1) if total_prob > 0 else 0.0)
        for country, score in scores.items()
    ]

    df = pd.DataFrame(
        prob_per_country, columns=["Country", "Model Confidence (%)"]
    ).sort_values(by="Model Confidence (%)", ascending=False)
    ai_confidence = (
        f"The AI's guess was {model_prediction}\n\nAI's Results:\n"
        + df.to_markdown(index=False)
    )

    return f"The correct country was: {correct_country}\n{result}", ai_confidence
|
|
|
|
|
def _new_round(num_distractors=4):
    """Draw a random test image and build one multiple-choice round for it.

    Args:
        num_distractors: how many wrong countries to offer alongside the
            answer (capped at the number of other countries available).

    Returns:
        tuple ``(image, options, correct_country, filtered_predictions,
        model_prediction)``.
    """
    row = data[random.randint(0, len(data) - 1)]  # fetch the row once
    image = row["image"]
    correct_country = row["label"]

    # Sample distractors only from countries other than the answer, so every
    # round offers the same number of distinct options.
    pool = [country for country in countries if country != correct_country]
    options = random.sample(pool, min(num_distractors, len(pool)))
    options.append(correct_country)
    random.shuffle(options)

    predictions = get_predictions(image)
    # Restrict the model's distribution to the offered options; the AI's guess
    # is the most probable of them.
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)

    return image, options, correct_country, filtered_predictions, model_prediction


def load():
    """Create the initial round.

    Stores the option probabilities in the global ``filtered_predictions``.

    Returns:
        tuple ``(image, options, correct_country, model_prediction)``.
    """
    global filtered_predictions

    image, options, correct_country, filtered_predictions, model_prediction = _new_round()
    return image, options, correct_country, model_prediction


def reload():
    """Start a new round and refresh the UI components.

    Overwrites the globals ``correct_country``, ``model_prediction`` and
    ``filtered_predictions`` that ``get_result`` reads.

    Returns:
        tuple of (new image component, new radio component, cleared result
        text, cleared AI-confidence markdown).
    """
    global correct_country
    global model_prediction
    global filtered_predictions

    image, options, correct_country, filtered_predictions, model_prediction = _new_round()
    return gr.Image(image), gr.Radio(choices=options, label="Select the country:"), "", ""
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # The first round is drawn once when the script runs. These four names
    # become module globals; get_result() reads them and reload() overwrites them.
    image, options, correct_country, model_prediction = load()

    gr.Markdown("# GeoGuessr - Can You Beat the AI?")
    gr.Markdown("Try to guess the country in the image. Can you beat the AI?")

    img = gr.Image(image)
    radio = gr.Radio(choices=options, label="Select the country:")
    ai_pred = gr.Markdown()
    text = gr.Text(label="Result")

    # Picking a country scores the round immediately.
    radio.select(fn=get_result, inputs=radio, outputs=[text, ai_pred])

    btn = gr.Button(value="Get New Image")
    # Draw a fresh round and clear the previous result and AI report.
    btn.click(fn=reload, inputs=None, outputs=[img, radio, text, ai_pred])

# NOTE(review): round state lives in module globals, so concurrent visitors
# share (and overwrite) the same round; consider per-session gr.State.
demo.launch()