marcelomoreno26's picture
Create app.py
17f6c62 verified
raw
history blame
4.23 kB
from transformers import ViTImageProcessor, ViTForImageClassification
import gradio as gr
from datasets import load_dataset
import torch
import random
import numpy as np
import pandas as pd
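# NOTE: `tabulate` must be listed in the app's dependencies; pandas'
# DataFrame.to_markdown (used in get_result) requires it at runtime.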
def get_predictions(image):
    """Classify ``image`` and return a probability for every label, best first.

    Feeds the image through the module-level ``processor`` and ``model`` with
    gradient tracking disabled, then softmaxes the logits.

    Args:
        image: Image handed unchanged to ``processor``.

    Returns:
        dict: Label name (from ``model.config.id2label``) mapped to its
        probability as a scalar tensor, inserted in descending logit order.
    """
    encoded = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded).logits

    # Rank every class index from the highest logit to the lowest.
    ranking = logits[0].argsort(dim=-1, descending=True)
    ranked_probs = torch.softmax(logits, dim=-1)[0, ranking]

    return {
        model.config.id2label[class_idx.item()]: prob
        for class_idx, prob in zip(ranking, ranked_probs)
    }
# Hub identifier of the fine-tuned ViT checkpoint used by the "AI" player.
model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"

# Test split that the game images are drawn from.
data = load_dataset("marcelomoreno26/geoguessr", split="test")
length = len(data)

# The image processor and the classifier are loaded from the same checkpoint.
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Candidate answers: one country name per line of countries.txt.
with open("countries.txt", "r") as file:
    countries = [line.strip() for line in file]
def get_result(selection):
    """Score the player's guess against the truth and the model's guess.

    Reads the module-level ``correct_country``, ``model_prediction`` and
    ``filtered_predictions`` that describe the current round.

    Args:
        selection: Country the player picked in the radio group.

    Returns:
        tuple[str, str]: The result message for the player, and a Markdown
        report of the model's guess with its confidence per offered country
        (renormalised over the offered options, highest first).
    """
    player_right = selection == correct_country
    model_right = correct_country == model_prediction

    # Outcome keyed by (player was right, model was right).
    outcomes = {
        (True, True): "It's a draw!",
        (True, False): "Congratulations! You won!",
        (False, True): "Sorry, you lost. The AI guessed it right!",
        (False, False): "Sorry, you both lost.",
    }
    result = outcomes[(player_right, model_right)]

    # Renormalise so the confidences of the offered options sum to 100%.
    scores = {
        country: float(score)
        for country, score in filtered_predictions.items()
    }
    total_prob = sum(scores.values())
    rows = [
        (country, np.round(score / total_prob, 3) * 100)
        for country, score in scores.items()
    ]

    table = pd.DataFrame(rows, columns=["Country", "Model Confidence (%)"])
    table = table.sort_values(by="Model Confidence (%)", ascending=False)

    ai_confidence = (
        f"The AI's guess was {model_prediction}\n\nAI's Results:\n"
        + table.to_markdown(index=False)
    )
    return f"The correct country was: {correct_country}\n{result}", ai_confidence
def load(num_distractors=4):
    """Set up a round: pick a random image, build options, query the model.

    Stores the model's probabilities for the offered options in the
    module-level ``filtered_predictions``.

    Args:
        num_distractors: Number of wrong countries offered alongside the
            correct one. Capped at the number of other countries available.

    Returns:
        tuple: ``(image, options, correct_country, model_prediction)`` where
        ``options`` is the shuffled list of choices and ``model_prediction``
        is the option the model gives the highest probability.
    """
    global filtered_predictions

    # Index the dataset once; both fields come from the same example.
    example = data[random.randrange(len(data))]
    image = example['image']
    correct_country = example['label']

    # Draw distractors from a pool that already excludes the answer, so the
    # number of options is the same every round. Sampling first and filtering
    # afterwards would drop an option whenever the answer was drawn.
    distractor_pool = [
        country for country in countries if country != correct_country
    ]
    options = random.sample(
        distractor_pool, min(num_distractors, len(distractor_pool))
    )
    options.append(correct_country)
    random.shuffle(options)

    # Restrict the model's predictions to the countries actually offered.
    predictions = get_predictions(image)
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)
    return image, options, correct_country, model_prediction
def reload():
    """Start a new round and return the UI updates for it.

    Delegates the round setup to ``load()`` so the two entry points cannot
    drift apart. ``load()`` stores ``filtered_predictions`` itself; this
    function publishes ``correct_country`` and ``model_prediction`` as
    module-level globals for ``get_result``.

    Returns:
        tuple: A ``gr.Image`` showing the new image, a ``gr.Radio`` with the
        new options, and two empty strings that clear the result text and the
        AI report.
    """
    global correct_country
    global model_prediction

    image, options, correct_country, model_prediction = load()
    return (
        gr.Image(image),
        gr.Radio(choices=options, label="Select the country:"),
        "",
        "",
    )
with gr.Blocks() as demo:
    # Prepare the first round up front. The names bound here are module-level,
    # so get_result can read correct_country and model_prediction.
    image, options, correct_country, model_prediction = load()

    gr.Markdown("# GeoGuessr - Can You Beat the AI?")
    gr.Markdown("Try to guess the country in the image. Can you beat the AI?")

    img = gr.Image(value=image)
    radio = gr.Radio(choices=options, label="Select the country:")
    ai_pred = gr.Markdown()
    text = gr.Text(label="Result")

    # Picking an option scores the round and shows the AI's confidences.
    radio.select(fn=get_result, inputs=radio, outputs=[text, ai_pred])

    # The button starts a new round and clears both result widgets.
    btn = gr.Button("Get New Image")
    btn.click(fn=reload, inputs=None, outputs=[img, radio, text, ai_pred])

demo.launch()