Spaces:
Running
Running
#importing the libraries | |
import streamlit as st | |
from PIL import Image | |
import torch | |
from transformers import AutoModelForImageClassification, AutoImageProcessor | |
import numpy as np | |
import pandas as pd | |
import time | |
import os | |
model_repository_id = "Dusduo/Pokemon-classification-1stGen"


# Streamlit re-executes this script from top to bottom on every widget
# interaction. Without caching, the model weights and the CSV table would be
# loaded again on each rerun.
# show_spinner=False keeps the cached loaders from drawing anything on the
# page, since they run before st.set_page_config() is called further down.
@st.cache_resource(show_spinner=False)
def _load_classifier(repository_id):
    """Load the image processor and the classification model once and reuse them."""
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


@st.cache_data(show_spinner=False)
def _load_pokemon_info(csv_path):
    """Read the pokemon information table once and reuse it across reruns."""
    return pd.read_csv(csv_path)


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table
pokemon_info_df = _load_pokemon_info('pokemon_info.csv')
# Small pokeball icon shown on the pokedex card; resize() returns a new image,
# so the source file can be closed as soon as the copy exists.
with Image.open('pokeball.png') as _pokeball_source:
    pokeball_image = _pokeball_source.resize((20, 20))
#functions to predict image | |
def preprocess(processor: "AutoImageProcessor", image, size=(200, 200)):
    """Turn an image into the tensor inputs expected by the classifier.

    The image is converted to RGB, resized to ``size`` and then handed to
    ``processor`` with ``return_tensors="pt"``.

    Args:
        processor: Callable image processor, invoked as
            ``processor(image, return_tensors="pt")``. The annotation is
            quoted so it is not evaluated when the function is defined.
        image: Image object exposing ``convert`` and ``resize``
            (a ``PIL.Image.Image`` in this app).
        size: ``(width, height)`` the image is resized to before processing.
            Defaults to ``(200, 200)``, the value that used to be hard-coded.

    Returns:
        Whatever ``processor`` returns for the resized RGB image.
    """
    rgb_image = image.convert("RGB").resize(size)
    return processor(rgb_image, return_tensors="pt")
def predict(model: "AutoModelForImageClassification", inputs, k=5, variance_threshold=0.001):
    """Classify preprocessed inputs and reject images the model is unsure about.

    Args:
        model: Callable classifier. ``model(**inputs)`` must return an object
            with a ``logits`` tensor, and ``model.config.id2label`` must map
            class ids to label names. The annotation is quoted so it is not
            evaluated when the function is defined.
        inputs: Mapping of keyword arguments forwarded to ``model``.
        k: Currently unused. Kept so existing callers that pass it keep
            working (a top-k output was sketched but never returned).
        variance_threshold: Minimum variance of the class probabilities for
            the prediction to be accepted. Defaults to ``0.001``, the value
            that used to be hard-coded.

    Returns:
        A ``(label, probability)`` tuple. ``probability`` is a percentage
        rounded to 2 decimals. When the image is rejected the tuple is
        ``('not a pokemon', -1)``.
    """
    # Forward the image to the model and retrieve the logits
    with torch.no_grad():
        logits = model(**inputs).logits

    # Only the first item of the batch is classified.
    sample_logits = logits[0]
    # Convert the logits into a vector of probabilities, one per class
    probabilities = torch.softmax(sample_logits, dim=0).tolist()

    # Decide whether the image shows a known pokemon at all.
    # A confident prediction puts almost all the mass on one class, which
    # gives a high variance across the probability vector. A near-uniform
    # vector (low variance) means the model cannot tell the classes apart,
    # so the image most likely matches none of them.
    variance = np.var(probabilities)
    if variance < variance_threshold:
        return 'not a pokemon', -1

    # Retrieve the predicted class (pokemon)
    predicted_id = sample_logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_id]
    # Probability of the predicted class as a percentage with 2 decimals
    probability = round(probabilities[predicted_id] * 100, 2)
    return predicted_label, probability
# Designing the interface ------------------------------------------
# Use the full page instead of a narrow central column
st.set_page_config(layout="wide")
# Define the title
st.title("Gotta Classify 'Em All - 1st Generation Pokedex -")
# For newline
st.write('\n')
# Wide left column for the images, narrow right column for the pokedex card
col1, col2 = st.columns([3, 1])
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the script; confirm that the sample gallery is meant to live inside col1.
with col1:
    image = Image.open('base.jpg')
    # Keep a handle on the element so the uploaded picture can replace it later
    show = st.image(image, use_column_width=True)

    # Display Sample images ----
    st.subheader('Sample images')
    sample_imgs_dir = "sample_imgs/"
    # os.listdir() returns entries in arbitrary order; sort them so the
    # gallery looks the same on every run and every machine.
    sample_imgs = sorted(os.listdir(sample_imgs_dir))
    n_cols = 4
    # Split the file names into rows of at most n_cols images
    groups = [sample_imgs[i:i + n_cols] for i in range(0, len(sample_imgs), n_cols)]
    for group in groups:
        cols = st.columns(n_cols)
        for col, image_file in zip(cols, group):
            col.image(sample_imgs_dir + image_file)
# Sidebar work and model outputs ---------------
st.sidebar.title("Upload Image")
# Choose your own image.
# Widgets should not be given an empty label (accessibility); the label is
# provided and hidden instead, since the sidebar title already says it.
uploaded_file = st.sidebar.file_uploader(
    "Upload Image",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)
if uploaded_file is not None:
    try:
        u_img = Image.open(uploaded_file)
        # Preprocess the image for the model
        model_inputs = preprocess(image_processor, u_img)
    except OSError:
        # PIL raises OSError subclasses for unreadable or truncated images.
        st.sidebar.error("The uploaded file could not be read as an image.")
        # Treat the upload as missing so the classify step asks for a new one
        # instead of running on undefined inputs.
        uploaded_file = None
    else:
        # Replace the base picture with the uploaded one
        show.image(u_img, 'Uploaded Image', width=400)
# For newline
st.sidebar.write('\n')
def _type_badge(type_name, color):
    """Return the HTML of one centered, colored pokemon-type badge."""
    return (
        '<div style="display: flex; justify-content: center; align-items: center;">'
        '<div style="display: inline-block; padding: 5px; margin: 0 5px; '
        f'border-radius: 5px; background-color: {color}; color: white;">{type_name}</div>'
        '</div>'
    )


def _stat_line(label, value):
    """Return the HTML of one "<label>: <value>" line of the pokedex card."""
    # The style attribute must be quoted, otherwise the browser stops reading
    # it at the first space and the font size is never applied.
    return f'<div style="font-size: 1.4rem;"><b>{label}:</b> {value}</div>'


def _show_pokedex_card(info_row):
    """Render the pokedex card of one pokemon in the right-hand column.

    ``info_row`` is one row of the pokemon information table; its 15 values
    are unpacked positionally in the order used below.
    """
    (_, pokedex_number, english_name, romaji_name, katakana_name,
     weight_kg, height_m, type1, type2, color1, color2,
     classification, evolve_from, evolve_into, is_legendary) = info_row
    with col2:
        # pokedex box
        with st.container(border=True):
            # first row: icon, pokedex number, name
            with st.container():
                pokeball_image_col, pokedex_number_col, pokemon_name_col = st.columns([1, 1, 8])
                pokeball_image_col.image(pokeball_image)
                pokedex_number_col.markdown(f'<div style="text-align: left; font-size: 1.4rem;"><b>{pokedex_number}</b></div>', unsafe_allow_html=True)
                pokemon_name_col.markdown(f'<div style="text-align: right; font-size: 1.4rem;"><b>{english_name}</b></div>', unsafe_allow_html=True)
            # second row: classification
            with st.container():
                st.markdown(f'<div style="text-align: center; color: {color1}; font-size: 1.2rem;"><b>{classification}</b></div>', unsafe_allow_html=True)
            # 3rd row: one badge per type (the second type may be missing)
            with st.container():
                if pd.isna(type2):
                    st.write('\n')
                    st.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                else:
                    type1_col, type2_col = st.columns(2)
                    type1_col.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                    type2_col.markdown(_type_badge(type2, color2), unsafe_allow_html=True)
                st.write('\n')
            # 4th row: measurements and evolution chain
            with st.container():
                st.write(_stat_line('Height', f'{height_m}m'), unsafe_allow_html=True)
                st.write('\n')
                st.write(_stat_line('Weight', f'{weight_kg}kg'), unsafe_allow_html=True)
                st.write('\n')
                if not pd.isna(evolve_from):
                    st.markdown(_stat_line('Evolves from', evolve_from), unsafe_allow_html=True)
                    st.write('\n')
                if not pd.isna(evolve_into):
                    st.markdown(_stat_line('Evolves into', evolve_into), unsafe_allow_html=True)
                    st.write('\n')


# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the script; confirm in particular that the pause belongs inside the spinner.
if st.sidebar.button("Click Here to Classify"):
    if uploaded_file is None:
        st.sidebar.write("Please upload an Image to Classify")
    else:
        with st.spinner('Classifying ...'):
            # Get prediction
            prediction, probability = predict(model, model_inputs, 5)
            time.sleep(2)
        st.sidebar.success('Done!')
        st.sidebar.header("Model predicts: ")
        # Display prediction; predict() signals a rejected image with -1
        if probability == -1:
            st.sidebar.write("It seems like it is not a picture of a 1st Generation Pokemon alone.", '\n',
                             "There might be too many entities on the image.")
        else:
            st.sidebar.write(f" It's a(n) {prediction} picture.", '\n')
            st.sidebar.write('Probability:', probability, '%')
            # Retrieve predicted pokemon information
            matching_rows = pokemon_info_df[pokemon_info_df['name'] == prediction].values
            if len(matching_rows) == 0:
                # The model label has no row in the table: report it instead
                # of failing with an IndexError.
                st.sidebar.warning(f"No pokedex entry found for {prediction}.")
            else:
                _show_pokedex_card(matching_rows[0])