"""GeoGuessr-style Gradio game.

The user guesses the country shown in a random test image from the
marcelomoreno26/geoguessr dataset and competes against a fine-tuned ViT
classifier restricted to the same answer options.
"""

import random

import gradio as gr
import numpy as np
import pandas as pd  # DataFrame.to_markdown needs `tabulate` in the dependencies
import torch
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor

# Number of wrong answers offered alongside the correct country in every round.
NUM_DISTRACTORS = 4


def get_predictions(image):
    """Classify ``image`` and return ``{label: probability}`` for every model label.

    The dict is ordered from most to least probable and its values are plain
    Python floats.
    """
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is processed, so keep row 0 of the softmaxed logits.
    probabilities = torch.softmax(logits, dim=-1)[0]
    ranked = probabilities.argsort(descending=True)
    return {model.config.id2label[idx.item()]: probabilities[idx].item() for idx in ranked}


data = load_dataset("marcelomoreno26/geoguessr", split="test")
model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"
# Image processor paired with the fine-tuned checkpoint.
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
length = len(data)  # not used by the game; kept so the module-level name still exists

# One country name per line. Blank lines are skipped so "" is never offered as an answer.
# NOTE(review): assumes every name here is also a key of model.config.id2label.
with open("countries.txt", "r", encoding="utf-8") as file:
    countries = [line.strip() for line in file if line.strip()]


def _sample_options(correct_country, num_distractors=NUM_DISTRACTORS):
    """Return ``correct_country`` plus up to ``num_distractors`` distinct wrong countries, shuffled."""
    # Exclude the answer before sampling so every round has the same number of options.
    pool = [country for country in countries if country != correct_country]
    options = random.sample(pool, min(num_distractors, len(pool)))
    options.append(correct_country)
    random.shuffle(options)
    return options


def get_result(selection, round_state):
    """Score the user's guess for the current round.

    Args:
        selection: Country picked in the radio group.
        round_state: Per-session dict created by ``load`` with the keys
            ``correct_country``, ``model_prediction`` and ``filtered_predictions``.

    Returns:
        A ``(result_text, ai_confidence_markdown)`` tuple for the Text and
        Markdown components.
    """
    correct_country = round_state["correct_country"]
    model_prediction = round_state["model_prediction"]
    filtered_predictions = round_state["filtered_predictions"]

    user_right = selection == correct_country
    model_right = model_prediction == correct_country
    if user_right and model_right:
        result = "It's a draw!"
    elif user_right:
        result = "Congratulations! You won!"
    elif model_right:
        result = "Sorry, you lost. The AI guessed it right!"
    else:
        result = "Sorry, you both lost."

    # Renormalise over the offered options only; `or 1.0` avoids a ZeroDivisionError
    # if every option's probability underflowed to zero.
    total_prob = sum(filtered_predictions.values()) or 1.0
    prob_per_country = [
        (country, round(100 * prob / total_prob, 1))
        for country, prob in filtered_predictions.items()
    ]
    df = pd.DataFrame(
        prob_per_country, columns=["Country", "Model Confidence (%)"]
    ).sort_values(by="Model Confidence (%)", ascending=False)
    ai_confidence = (
        f"The AI's guess was {model_prediction}\n\nAI's Results:\n"
        + df.to_markdown(index=False)
    )
    return f"The correct country was: {correct_country}\n{result}", ai_confidence


def load():
    """Start a new round on a random test image.

    Returns:
        ``(image, options, round_state)``. ``options`` is the shuffled list of
        answer choices. ``round_state`` is the dict ``get_result`` reads.
    """
    sample = data[random.randrange(len(data))]  # fetch the row once
    image = sample["image"]
    correct_country = sample["label"]
    options = _sample_options(correct_country)

    predictions = get_predictions(image)
    # The AI competes only on the options shown to the user.
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)

    round_state = {
        "correct_country": correct_country,
        "model_prediction": model_prediction,
        "filtered_predictions": filtered_predictions,
    }
    return image, options, round_state


def reload():
    """Handle "Get New Image": new image and options, cleared outputs, fresh session state."""
    image, options, round_state = load()
    return (
        gr.Image(image),
        gr.Radio(choices=options, label="Select the country:"),
        "",
        "",
        round_state,
    )


with gr.Blocks() as demo:
    image, options, initial_state = load()
    # gr.State keeps a separate copy per browser session. Module globals would be
    # shared by every visitor, so one user's "Get New Image" would change another
    # user's answer.
    round_state = gr.State(initial_state)

    gr.Markdown("# GeoGuessr - Can You Beat the AI?")
    gr.Markdown("Try to guess the country in the image. Can you beat the AI?")
    img = gr.Image(image)
    radio = gr.Radio(choices=options, label="Select the country:")
    ai_pred = gr.Markdown()
    text = gr.Text(label="Result")
    radio.select(fn=get_result, inputs=[radio, round_state], outputs=[text, ai_pred])
    btn = gr.Button(value="Get New Image")
    btn.click(reload, None, outputs=[img, radio, text, ai_pred, round_state])

demo.launch()