# File size: 4,232 Bytes
# 17f6c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from transformers import ViTImageProcessor, ViTForImageClassification
import gradio as gr
from datasets import load_dataset
import torch
import random
import numpy as np
import pandas as pd
# tabulate in dependencies



def get_predictions(image):
    """Classify *image* and return every label's probability, best first.

    Uses the module-level ``processor`` and ``model`` globals.

    Args:
        image: An image accepted by ``processor`` (in this app, an entry of
            the dataset's ``'image'`` column).

    Returns:
        dict mapping each label name from ``model.config.id2label`` to its
        softmax probability as a plain Python ``float``. Keys are inserted in
        descending order of probability.
    """
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was passed, so row 0 holds its class scores.
    probabilities = torch.softmax(logits[0], dim=-1)
    # Sort once, then convert to plain Python numbers with .tolist() so
    # callers can compare, sum and format values without handling tensors.
    sorted_probs, sorted_indices = probabilities.sort(descending=True)
    return {
        model.config.id2label[index]: prob
        for index, prob in zip(sorted_indices.tolist(), sorted_probs.tolist())
    }

# Test split of the GeoGuessr dataset; rows are indexed as data[i]['image']
# and data[i]['label'] below.
data = load_dataset("marcelomoreno26/geoguessr",split="test")
model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"

# Initialize the feature extractor
processor = ViTImageProcessor.from_pretrained(model_name)
# Initialize the model
model = ViTForImageClassification.from_pretrained(model_name)


# Number of examples in the test split.
length = len(data)

# Candidate answer countries, one per line. Read as UTF-8 so accented names
# do not depend on the platform's default encoding. Blank lines are skipped so
# an empty string is never offered as a quiz option.
with open("countries.txt", "r", encoding="utf-8") as file:
    countries = [name for name in (line.strip() for line in file) if name]


def get_result(selection):
    """Score the player's pick against the answer and the AI's pick.

    Reads, but never modifies, the module-level ``correct_country``,
    ``model_prediction`` and ``filtered_predictions`` of the current round.

    Args:
        selection: The country chosen in the radio button.

    Returns:
        A ``(result_text, ai_confidence_markdown)`` tuple.
        ``result_text`` names the correct country and the outcome.
        ``ai_confidence_markdown`` states the AI's guess and gives a markdown
        table of its confidence per option, renormalised over the options
        only and rounded to one decimal place.
    """
    player_correct = selection == correct_country
    ai_correct = correct_country == model_prediction
    if player_correct and ai_correct:
        result = "It's a draw!"
    elif player_correct:
        result = "Congratulations! You won!"
    elif ai_correct:
        result = "Sorry, you lost. The AI guessed it right!"
    else:
        result = "Sorry, you both lost."

    total_prob = sum(float(value) for value in filtered_predictions.values())
    # Guard against a zero total (e.g. every option underflowing to 0.0).
    # Scale to percent first, then round, so the values carry no float
    # artifacts such as 0.7000000000000001.
    scale = 100.0 / total_prob if total_prob > 0 else 0.0
    prob_per_country = [
        (country, round(float(prob) * scale, 1))
        for country, prob in filtered_predictions.items()
    ]
    df = pd.DataFrame(prob_per_country,columns=["Country","Model Confidence (%)"]).sort_values(by="Model Confidence (%)",ascending=False)
    ai_confidence = f"The AI's guess was {model_prediction}\n\nAI's Results:\n"+ df.to_markdown(index=False)

    return f"The correct country was: {correct_country}\n{result}", ai_confidence


# Number of wrong answers offered alongside the correct country each round.
NUM_DISTRACTORS = 4


def _new_round():
    """Pick a random test image and build the answer options for it.

    Returns:
        ``(image, options, correct_country, model_prediction,
        filtered_predictions)``:
        - ``options`` is the correct country plus up to ``NUM_DISTRACTORS``
          other countries, in random order.
        - ``filtered_predictions`` maps each option to the model's
          probability for it.
        - ``model_prediction`` is the option with the highest probability.
    """
    sample = data[random.randrange(len(data))]
    image = sample['image']
    correct_country = sample['label']

    # Draw distractors only from countries other than the answer. Every round
    # then offers the same number of options; sampling first and filtering
    # afterwards lost an option whenever the answer itself was drawn.
    pool = [country for country in countries if country != correct_country]
    options = random.sample(pool, min(NUM_DISTRACTORS, len(pool)))
    options.append(correct_country)
    random.shuffle(options)

    predictions = get_predictions(image)
    # NOTE(review): raises KeyError if an option is not one of the model's
    # label names — confirm countries.txt matches model.config.id2label.
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)
    return image, options, correct_country, model_prediction, filtered_predictions


def load():
    """Build the first round.

    Stores the round's option probabilities in the module-level
    ``filtered_predictions``. ``correct_country`` and ``model_prediction``
    are only returned, not stored globally.

    Returns:
        ``(image, options, correct_country, model_prediction)``.
    """
    global filtered_predictions
    image, options, correct_country, model_prediction, filtered_predictions = _new_round()
    return image, options, correct_country, model_prediction


def reload():
    """Start a new round and return refreshed UI values.

    Updates the module-level ``correct_country``, ``model_prediction`` and
    ``filtered_predictions`` that ``get_result`` reads.

    Returns:
        A new ``gr.Image``, a new ``gr.Radio`` holding the round's options,
        and two empty strings that clear the result and AI-confidence outputs.
    """
    global correct_country, model_prediction, filtered_predictions
    image, options, correct_country, model_prediction, filtered_predictions = _new_round()
    return gr.Image(image), gr.Radio(choices=options, label="Select the country:"), "", ""



# Build the Gradio UI. Inside a gr.Blocks context, components are laid out
# in the order they are created, so the statement order below is the layout.
with gr.Blocks() as demo:
    
    # Build the first round. A `with` block does not open a new scope, so
    # `correct_country` and `model_prediction` become module globals here;
    # get_result() reads them. load() itself stores `filtered_predictions`.
    image, options, correct_country, model_prediction = load()

    gr.Markdown("# GeoGuessr - Can You Beat the AI?")
    gr.Markdown("Try to guess the country in the image. Can you beat the AI?")
    img = gr.Image(image)
    radio = gr.Radio(choices=options, label ="Select the country:")
    # Filled by get_result() with the AI's guess and its confidence table.
    ai_pred = gr.Markdown()
    # Filled by get_result() with the correct answer and the outcome.
    text = gr.Text(label="Result")
    # Score as soon as the player selects an option.
    radio.select(fn=get_result, inputs=radio, outputs=[text,ai_pred])

    btn = gr.Button(value="Get New Image")
    # reload() takes no inputs. It returns a replacement image and radio and
    # clears both result outputs.
    btn.click(reload, None,outputs=[img,radio,text,ai_pred])

# Starts the server unconditionally when this module is executed or imported.
demo.launch()