# topshelf-poc / src/Hockey_Breeds.py
# Author: Dan Biagini
# Commit 0200926: "add try it inference for hockey breeds"
import streamlit as st
from streamlit_image_select import image_select
import zip_files
import random
import logging
from huggingface_hub import from_pretrained_fastai
@st.cache_resource
def get_model():
    """Load the Hockey Breeds fastai learner from the Hugging Face Hub.

    The ``st.cache_resource`` decorator lets Streamlit keep the returned
    learner instead of downloading and rebuilding it on every script rerun.

    Returns:
        Whatever ``from_pretrained_fastai`` returns for the
        ``danbiagini/hockey_breeds`` repository (used elsewhere in this file
        as an object exposing ``predict``).
    """
    hub_repo = "danbiagini/hockey_breeds"
    learner = from_pretrained_fastai(hub_repo)
    return learner
def classify_image(learn, img, categories=('Hockey Goalie', 'Hockey Player', "Hockey Referee")):
    """Classify one image and map each class label to its probability.

    Args:
        learn: The model; it must expose ``predict(img)`` returning a 3-tuple
            whose last element is an iterable of per-class probabilities
            (fastai ``Learner.predict`` has this shape).
        img: The image to classify, in any form ``learn.predict`` accepts.
        categories: Class labels, in the same order as the probabilities the
            model returns. The default assumes the model's vocabulary is
            ordered Goalie, Player, Referee -- NOTE(review): confirm against
            the learner's ``dls.vocab``.

    Returns:
        dict mapping each label (in ``categories`` order) to its probability
        as a plain Python ``float``.

    Raises:
        ValueError: if the number of probabilities does not match the number
            of labels; ``zip`` would otherwise silently drop or mislabel
            classes.
    """
    # Only the probabilities are needed; the decoded label and index are ignored.
    _, _, raw_probs = learn.predict(img)
    # float() turns tensor elements (or any numeric) into plain floats for display.
    probs = [float(p) for p in raw_probs]
    if len(probs) != len(categories):
        raise ValueError(
            f"Model returned {len(probs)} probabilities but {len(categories)} "
            f"categories were given: {tuple(categories)!r}"
        )
    return dict(zip(categories, probs))
def _pick_random_sample(files, label):
    """Return one randomly chosen value from the ``files`` mapping.

    Args:
        files: Mapping returned by ``zip_files.extract_files_from_zip``; its
            values are what gets shown in the image picker (presumably image
            paths or image data -- confirm against ``zip_files``).
        label: Class name, used only in the error message.

    Raises:
        ValueError: if ``files`` is empty (e.g. an empty sample archive).
    """
    if not files:
        raise ValueError(f"No sample images found for {label!r}")
    return random.choice(list(files.values()))


def reroll_sample_images():
    """Pick one random sample image per class and store them in session state.

    Extracts the player, goalie and referee sample archives, then stores one
    randomly chosen entry from each in ``st.session_state.sample`` under the
    keys "player", "goalie", "referee", in that order.

    Raises:
        ValueError: if any sample archive yields no files.
    """
    archives = (
        ("player", "src/images/samples/player-samples.zip"),
        ("goalie", "src/images/samples/goalie-samples.zip"),
        ("referee", "src/images/samples/referee-samples.zip"),
    )
    # Extract every archive before choosing any samples.
    extracted = [(label, zip_files.extract_files_from_zip(path)) for label, path in archives]
    # Build the complete dict first and assign it once. If a pick fails, no
    # half-filled sample is left in session state, where the
    # "'sample' not in st.session_state" check would never replace it.
    st.session_state.sample = {
        label: _pick_random_sample(files, label) for label, files in extracted
    }
# Streamlit documents st.set_page_config as needing to be the first Streamlit
# command in the script, so configure the page before touching session state
# or running the sample-loading code (which may itself call into Streamlit).
st.set_page_config(page_title='Hockey Breeds', layout="wide",
                   page_icon=":frame_with_picture:")

# Choose the initial sample images only when this session has none yet;
# afterwards they change only through the "Re-roll Samples" button.
if 'sample' not in st.session_state:
    reroll_sample_images()
# Static page copy: an introduction to image classification, a description of
# how the model was built, and its validation results.
intro_md = '''Image Classification in Computer Vision is the act of determining the most appropriate label for an entire image from a set of fixed labels.
A popular topic of image classification in Computer Vision introductions and courses is to use an example problem of training a model to label images of various pet breeds.
*Hockey Breeds* is a hockey slant on this common theme in Computer Vision educational materials.'''

model_md = '''This "Hockey Breeds" model was built using 50 hockey related images found on the web and in my own collection. I started with a pretrained *ResNet18* model (resnet18 is trained on *ImageNet*, a very large dataset with millions of images). I fine tuned the model by labeling the hockey photos, then training using python (*Fast.ai* & *PyTorch* libraries).
The total training time for this was approximately 5 minutes running on a low-end GPU. It’s impressive how accurate this quick / small model can be!'''

validation_md = "Validation of the model's performance was done using 26 images not included in the training set. The model performed fairly well against the validation dataset, with only 1 misclassified image."

st.title('Hockey Breeds - Hello Computer Vision')

st.subheader('Image Classification')
st.markdown(intro_md)

st.subheader("Hockey Image Classification")
st.markdown(model_md)
# NOTE(review): "sampl_batch.png" may be a typo for "sample_batch.png"; the
# path is kept as-is because it must match the file on disk.
st.image("src/images/samples/sampl_batch.png")

st.subheader("Validation Results")
st.markdown(validation_md)
st.image("src/images/artifacts/confusion_matrix.png",
         caption="Confusion Matrix for Hockey Breeds ")
# Interactive section: the visitor picks one of the current sample images.
st.subheader("Try It Out")
sample_images = list(st.session_state.sample.values())
img = image_select(
    label="Select an image and hockey breeds will guess a label",
    images=sample_images,
)
# Clicking the button calls reroll_sample_images through the on_click callback.
st.button("Re-roll Samples", on_click=reroll_sample_images)
# get_model is cached by Streamlit, so this does not reload the model on every rerun.
model = get_model()
# Show the most likely label for the selected image.
if img:
    res = classify_image(model, img)
    # Pick the label with the highest probability. No sorting is needed, and
    # on a tie max() keeps the first label in classify_image's order.
    best_label, best_prob = max(res.items(), key=lambda item: item[1])
    # Format as a percentage so the metric shows its unit (e.g. "71.43%").
    st.metric(label=best_label, value=f"{best_prob:.2%}")