#importing the libraries import streamlit as st from PIL import Image import torch from transformers import AutoModelForImageClassification, AutoImageProcessor import numpy as np import pandas as pd import time import os model_repository_id = "Dusduo/Pokemon-classification-1stGen" # Loading the pokemon classifier model and its processor image_processor = AutoImageProcessor.from_pretrained(model_repository_id) model = AutoModelForImageClassification.from_pretrained(model_repository_id) # Loading the pokemon information table pokemon_info_df = pd.read_csv('pokemon_info.csv') pokeball_image = Image.open('pokeball.png').resize((20,20)) #functions to predict image def preprocess(processor: AutoImageProcessor, image): return processor(image.convert("RGB").resize((200,200)), return_tensors="pt") def predict(model: AutoModelForImageClassification, inputs, k=5): # Forward the image to the model and retrieve the logits with torch.no_grad(): logits = model(**inputs).logits # Convert the retrieved logits into a vector of probabilities for each class probabilities = torch.softmax(logits[0], dim=0).tolist() # Discriminate wether or not the inputted image was an image of a Pokemon # Compute the variance of the vector of probabilities # The spread of the probability values is a good represent of the confusion of the model # Or in other words, its confidence => the greater the spread, the lower its confidence variance = np.var(probabilities) # Too great of a spread: it is likely the image provided did not correspond to any known classes if variance < 0.001: #not a pokemon predicted_label = 'not a pokemon' probability = -1 (top_k_labels, top_k_probability) = '_', '_' else: # it is a pokemon # Retrieve the predicted class (pokemon) predicted_id = logits.argmax(-1).item() predicted_label = model.config.id2label[predicted_id] # Retrieve the probability for the predicted class, and format it to 2 decimals probability = round(probabilities[predicted_id]*100,2) # Retrieve the top 5 classes and their probabilities #top_k_labels = [model.config.id2label[key] for key in np.argpartition(logits.numpy(), -k)[-k:]] #top_k_probability = [round(prob*100,2) for prob in np.sort(probabilities.numpy())[-k:]] return predicted_label, probability #, (top_k_labels, top_k_probability) # Designing the interface ------------------------------------------ # Use the full page instead of a narrow central column st.set_page_config(layout="wide") # Define the title st.title("Gotta Classify 'Em All") st.subheader("Image classifier for Pokemons from the 1st generation.") # For newline st.write('\n') image = Image.open('base.jpg') col1, col2 = st.columns([1,2]) # [3,1] with col1: image = Image.open('base.jpg') show = st.image(image, use_column_width=True) # Display Sample images ---- st.subheader('Sample images') sample_imgs_dir = "sample_imgs/" sample_imgs = os.listdir(sample_imgs_dir) # get the list of all sample images img_idx = 0 n_cols = 4 groups = [] for i in range(0, len(sample_imgs), n_cols): groups.append(sample_imgs[i:i+n_cols]) for group in groups: cols = st.columns(n_cols) for i,image_file in enumerate(group): cols[i].image(sample_imgs_dir+image_file) # Sidebar work and model outputs --------------- st.sidebar.title("Upload Image") #Disabling warning #st.set_option('deprecation.showfileUploaderEncoding', False) #Choose your own image uploaded_file = st.sidebar.file_uploader("",type=['png', 'jpg', 'jpeg'], accept_multiple_files=False ) if uploaded_file is not None: u_img = Image.open(uploaded_file) show.image(u_img, 'Uploaded Image',use_column_width=True) #, width=400 )# # Preprocess the image for the model model_inputs = preprocess(image_processor, u_img) # For newline st.sidebar.write('\n') if st.sidebar.button("Click Here to Classify"): if uploaded_file is None: st.sidebar.write("Please upload an Image to Classify") else: with st.spinner('Classifying ...'): # Get prediction prediction, probability = predict(model, model_inputs,5) #, (top_k_labels, top_k_probability) time.sleep(2) st.sidebar.success('Done!') st.sidebar.header("Model response: ") # Display prediction if probability==-1: st.sidebar.write("""I am sorry I am having trouble finding a matching pokemon.
Potential explanations:
- The image provided is a Pokemon but not from the 1st Generation.
- The image provided is not a Pokemon.
- There are too many entities on the image.
""", unsafe_allow_html=True) else: st.sidebar.write(f" It's a(n) {prediction} picture.",'\n', unsafe_allow_html=True) st.sidebar.write(f'Probability:',probability,'%', unsafe_allow_html=True) # Retrieve predicted pokemon information _, pokedex_number, english_name, romaji_name, katakana_name, weight_kg, height_m, type1, type2, color1, color2, classification, evolve_from, evolve_into, is_legendary = pokemon_info_df[pokemon_info_df['name']==prediction].values[0] with col2: # pokedex box with st.container(border=True ): # first row with st.container(): pokeball_image_col,pokedex_number_col, pokemon_name_col = st.columns([1,1,8]) pokeball_image_col.image(pokeball_image) pokedex_number_col.markdown(f'
Pokedex n°{pokedex_number}
', unsafe_allow_html=True) pokemon_name_col.markdown(f'
{english_name}
{katakana_name}
', unsafe_allow_html=True) # second row with st.container(): st.markdown(f'
{classification}
', unsafe_allow_html=True) # 3rd row with st.container(): if pd.isna(type2): st.write('\n') st.markdown(f'
{type1}
', unsafe_allow_html=True) else: type1_col, type2_col = st.columns(2) type1_col.markdown(f'
{type1}
', unsafe_allow_html=True) type2_col.markdown(f'
{type2}
', unsafe_allow_html=True) st.write('\n') # 4th row with st.container(): st.write(f'
Height: {height_m}m', unsafe_allow_html=True) st.write('\n') st.write(f'
Weight: {weight_kg}kg', unsafe_allow_html=True) st.write('\n') if not pd.isna(evolve_from): st.markdown(f'
Evolves from: {evolve_from}', unsafe_allow_html=True) #st.write(f'Evolves from: {evolve_from}') st.write('\n') if not pd.isna(evolve_into): st.markdown(f'
Evolves into: {evolve_into}', unsafe_allow_html=True) #st.write(f'Evolves into: {evolve_into}') st.write('\n') st.sidebar.write('\n') st.sidebar.info( """ - Web App URL: [url](https://huggingface.co/spaces/Dusduo/GottaClassifyEmAll) - GitHub repository: [repository](https://github.com/A-Duss/GottaClassifyEmAll.git) """ ) st.sidebar.title("Contact") st.sidebar.info( """ Antoine Dussolle: [LinkedIn](https://www.linkedin.com/in/antoine-dussolle/) | [GitHub](https://github.com/A-Duss) """ )