# Importing the libraries
import os
import time

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Use the full page instead of a narrow central column.
# This is issued before anything else touches Streamlit (including the cached
# loaders below) because set_page_config must be the first Streamlit command.
st.set_page_config(layout="wide")

model_repository_id = "Dusduo/Pokemon-classification-1stGen"


@st.cache_resource
def _load_classifier(repository_id):
    """Load the image processor and the classifier for *repository_id*.

    Streamlit re-executes this script on every widget interaction; caching the
    result keeps the model from being re-instantiated on each rerun.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


@st.cache_data
def _load_pokemon_info(csv_path):
    """Read the pokemon information table from *csv_path* (cached across reruns)."""
    return pd.read_csv(csv_path)


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table
pokemon_info_df = _load_pokemon_info('pokemon_info.csv')
pokeball_image = Image.open('pokeball.png').resize((20, 20))


# Functions to predict image
def preprocess(processor: AutoImageProcessor, image):
    """Convert *image* to RGB, resize it to 200x200 and run it through *processor*.

    Returns:
        The processor output, requested as PyTorch tensors
        (``return_tensors="pt"``).
    """
    return processor(image.convert("RGB").resize((200, 200)), return_tensors="pt")


def predict(model: AutoModelForImageClassification, inputs, k=5,
            variance_threshold=0.001):
    """Classify the preprocessed image and return its label and confidence.

    Args:
        model: Classifier called as ``model(**inputs)``; its ``.logits`` and
            ``.config.id2label`` are used.
        inputs: Keyword arguments forwarded to the model (the output of
            ``preprocess``). Only the first item of the batch is considered.
        k: Accepted for backward compatibility; currently unused because the
            top-k output is not part of the return value.
        variance_threshold: Minimum variance of the class probabilities for
            the image to be accepted as a known class.

    Returns:
        ``(predicted_label, probability)`` where ``probability`` is a
        percentage rounded to 2 decimals, or ``('not a pokemon', -1)`` when
        the image is rejected.
    """
    # Forward the image to the model and retrieve the logits
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert the retrieved logits into a vector of probabilities for each class
    probabilities = torch.softmax(logits[0], dim=0).tolist()

    # Discriminate whether or not the inputted image was an image of a Pokemon.
    # The probabilities sum to 1, so their variance is near zero when the mass
    # is smeared evenly over the classes (the model is unsure) and grows as
    # the mass concentrates on one class (the model is confident).
    variance = np.var(probabilities)

    if variance < variance_threshold:
        # Probabilities are too flat: the image likely matches no known class
        return 'not a pokemon', -1

    # Retrieve the predicted class (pokemon)
    predicted_id = logits[0].argmax(-1).item()
    predicted_label = model.config.id2label[predicted_id]
    # Retrieve the probability for the predicted class, formatted to 2 decimals
    probability = round(probabilities[predicted_id] * 100, 2)
    return predicted_label, probability


def _render_pokedex_entry(column, entry):
    """Draw the pokedex box for one row of the information table inside *column*.

    Args:
        column: Streamlit container (column) that receives the box.
        entry: One row of ``pokemon_info_df`` as a sequence of 15 values, in
            the table's column order.
    """
    # NOTE(review): every call below passes unsafe_allow_html=True, but the
    # strings visible in the source carry no markup; they are reproduced here
    # exactly as seen (newline + value). Confirm against the original file.
    (_, pokedex_number, english_name, romaji_name, katakana_name, weight_kg,
     height_m, type1, type2, color1, color2, classification, evolve_from,
     evolve_into, is_legendary) = entry

    with column:
        # pokedex box
        with st.container(border=True):
            # first row
            with st.container():
                pokeball_image_col, pokedex_number_col, pokemon_name_col = st.columns([1, 1, 8])
                pokeball_image_col.image(pokeball_image)
                pokedex_number_col.markdown(f'\n{pokedex_number}\n', unsafe_allow_html=True)
                pokemon_name_col.markdown(f'\n{english_name}\n', unsafe_allow_html=True)
            # second row
            with st.container():
                st.markdown(f'\n{classification}\n', unsafe_allow_html=True)
            # 3rd row: one type centred, or two types side by side
            with st.container():
                if pd.isna(type2):
                    st.write('\n')
                    st.markdown(f'\n{type1}\n', unsafe_allow_html=True)
                else:
                    type1_col, type2_col = st.columns(2)
                    type1_col.markdown(f'\n{type1}\n', unsafe_allow_html=True)
                    type2_col.markdown(f'\n{type2}\n', unsafe_allow_html=True)
                # NOTE(review): assumed to follow both branches; the original
                # indentation is not recoverable from the source.
                st.write('\n')
            # 4th row
            with st.container():
                st.write(f'\nHeight: {height_m}m', unsafe_allow_html=True)
                st.write('\n')
                st.write(f'\nWeight: {weight_kg}kg', unsafe_allow_html=True)
                st.write('\n')
                # Evolution fields are empty (NaN) when there is no such stage
                if not pd.isna(evolve_from):
                    st.markdown(f'\nEvolves from: {evolve_from}', unsafe_allow_html=True)
                    st.write('\n')
                if not pd.isna(evolve_into):
                    st.markdown(f'\nEvolves into: {evolve_into}', unsafe_allow_html=True)
                    st.write('\n')


# Designing the interface ------------------------------------------
# Define the title
st.title("Gotta Classify 'Em All - 1st Generation Pokedex -")
# For newline
st.write('\n')

col1, col2 = st.columns([3, 1])

with col1:
    image = Image.open('base.jpg')
    # Placeholder image; replaced by the uploaded picture further down
    show = st.image(image, use_column_width=True)

    # Display Sample images ----
    # NOTE(review): assumed to live inside col1; the original indentation is
    # not recoverable from the source.
    st.subheader('Sample images')
    sample_imgs_dir = "sample_imgs/"
    # Sorted so the gallery order does not depend on the filesystem
    sample_imgs = sorted(os.listdir(sample_imgs_dir))
    n_cols = 4
    # Lay the samples out in rows of n_cols images
    groups = [sample_imgs[start:start + n_cols]
              for start in range(0, len(sample_imgs), n_cols)]
    for group in groups:
        cols = st.columns(n_cols)
        for i, image_file in enumerate(group):
            cols[i].image(sample_imgs_dir + image_file)

# Sidebar work and model outputs ---------------
st.sidebar.title("Upload Image")

# Choose your own image.
# A non-empty label is supplied (and hidden) because Streamlit discourages
# empty widget labels.
uploaded_file = st.sidebar.file_uploader(
    "Upload Image",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)

if uploaded_file is not None:
    u_img = Image.open(uploaded_file)
    show.image(u_img, 'Uploaded Image', width=400)
    # Preprocess the image for the model
    model_inputs = preprocess(image_processor, u_img)

# For newline
st.sidebar.write('\n')

if st.sidebar.button("Click Here to Classify"):
    if uploaded_file is None:
        st.sidebar.write("Please upload an Image to Classify")
    else:
        with st.spinner('Classifying ...'):
            # Get prediction
            prediction, probability = predict(model, model_inputs, 5)
            # NOTE(review): artificial pause, presumably to keep the spinner
            # visible; consider removing.
            time.sleep(2)
        st.sidebar.success('Done!')

        st.sidebar.header("Model predicts: ")

        # Display prediction
        if probability == -1:
            st.sidebar.write("It seems like it is not a picture of a 1st Generation Pokemon alone.", '\n',
                             "There might be too many entities on the image.")
        else:
            st.sidebar.write(f" It's a(n) {prediction} picture.", '\n')
            st.sidebar.write('Probability:', probability, '%')

            # Retrieve predicted pokemon information
            matching_rows = pokemon_info_df[pokemon_info_df['name'] == prediction]
            if matching_rows.empty:
                # The label has no row in the table: report it instead of
                # crashing on an out-of-range index.
                st.sidebar.warning(f"No Pokedex entry found for {prediction}.")
            else:
                _render_pokedex_entry(col2, matching_rows.values[0])