import os
from glob import glob

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image

import ModelClass
@st.cache_resource
def load_model():
    """Return the project classifier, built once and then reused.

    ``st.cache_resource`` memoises the returned object, so Streamlit reruns
    hand back the same model instead of rebuilding it. Construction itself is
    delegated entirely to ``ModelClass.get_model``.
    """
    model = ModelClass.get_model()
    return model
@st.cache_data
def get_images(pattern='./inputs/*'):
    """Map bundled test-image file names to their paths.

    Args:
        pattern: glob pattern selecting candidate images. The default keeps
            the original behaviour of scanning ``./inputs``.

    Returns:
        dict mapping each file's base name (shown in the selectbox) to the
        path passed to ``Image.open``. Entries are sorted by path, so the
        selectbox order is stable across runs and platforms. Directories
        matched by the pattern are skipped because they cannot be opened
        as images.
    """
    # os.path.basename understands the platform separator; splitting on '/'
    # produced names like 'inputs\\x.jpg' on Windows.
    paths = sorted(path for path in glob(pattern) if os.path.isfile(path))
    return {os.path.basename(path): path for path in paths}
def infer(img):
    """Run the cached model on a PIL image and return class probabilities.

    Args:
        img: PIL image. It is converted to RGB and passed through
            ``ModelClass.get_transform()`` before inference.

    Returns:
        Tensor of softmax probabilities with all size-1 dimensions squeezed
        out. For a classifier returning ``(batch, num_classes)`` logits this
        is a 1-D ``(num_classes,)`` tensor. That output shape is an
        assumption; confirm it against ``ModelClass``.
    """
    image = ModelClass.get_transform()(img.convert('RGB'))
    batch = image.unsqueeze(dim=0)  # add a leading batch dimension of 1
    model = load_model()
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        # nn.Softmax() without ``dim`` is deprecated and warns. For 2-D
        # (batch, classes) logits its implicit choice was dim=1, so state
        # that explicitly.
        probs = torch.softmax(logits, dim=1)
    return probs.squeeze()
# Text for the "About" entry of Streamlit's hamburger menu (rendered as Markdown).
# NOTE(review): this text and the help/bug URLs below look like placeholder
# values from Streamlit's documentation; replace them with project-specific ones.
_ABOUT_TEXT = (
    "\n"
    "# This is a header. This is an *extremely* cool app!\n"
    "How how are you doin.\n"
    "---\n"
    "I am fine\n"
)

_MENU_ITEMS = {
    'Get Help': 'https://www.extremelycoolapp.com/help',
    'Report a bug': "https://www.extremelycoolapp.com/bug",
    'About': _ABOUT_TEXT,
}

# Page-wide settings. Streamlit's docs have required this to be the first
# Streamlit command on the page, so keep it above any call that renders output.
st.set_page_config(
    page_title="Whale Identification",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items=_MENU_ITEMS,
)
# Custom-CSS hooks. Both payloads are just a newline, so nothing meaningful
# is injected. Presumably a <style> block belongs here; restore it or remove
# these lines.
st.markdown("\n", unsafe_allow_html=True)  # "fix sidebar" slot (empty payload)

hide_st_style = "\n"
# To hide Streamlit's default chrome, fill hide_st_style with CSS and enable:
# st.markdown(hide_st_style, unsafe_allow_html=True)
def predict(image):
    """Return placeholder class probabilities; *image* is ignored.

    Draws one uniform random weight per label from NumPy's global RNG and
    normalises the weights to sum to 1.

    Returns:
        dict mapping 'cat' and 'dog' (in that order) to NumPy floats.
    """
    labels = ('cat', 'dog')
    weights = np.random.rand(len(labels))
    weights = weights / weights.sum()
    return {label: weight for label, weight in zip(labels, weights)}
def _open_selected_image(uploaded_file, test_images, test_image):
    """Return the PIL image to show and classify, or None if none is available.

    An uploaded file takes precedence over the test-image selectbox.
    ``test_image`` is None when ``test_images`` is empty (the selectbox has
    no options); indexing with it used to raise KeyError.
    """
    if uploaded_file is not None:
        return Image.open(uploaded_file)
    if test_image is None:
        return None
    return Image.open(test_images[test_image])


def _show_prediction(image, results_column, top_k=5):
    """Classify *image* and render the model's own probabilities.

    Writes the top-1 class name into the current container, then lists up to
    ``top_k`` highest-probability classes with progress bars in
    ``results_column``.
    """
    probs = infer(image)
    # Top-1 label, chosen exactly as before (argmax over the probabilities).
    st.write(f'{ModelClass.get_class(torch.argmax(probs))}')

    flat = probs.flatten()
    top = torch.topk(flat, k=min(top_k, flat.numel()))
    results_column.subheader('Results')
    for class_probability, class_index in zip(top.values.tolist(), top.indices):
        # class_index is a 0-d index tensor, the same kind argmax returns.
        class_name = ModelClass.get_class(class_index)
        results_column.write(f'{class_name}: {class_probability:.2%}')
        # st.progress expects a float in [0, 1]; clamp float32 rounding.
        results_column.progress(min(max(class_probability, 0.0), 1.0))


def app():
    """Render the identification page.

    Layout:
    - title, links and description;
    - an uploader plus a selectbox of bundled test images;
    - the chosen image and a prediction button in the left column;
    - the model's top classes in the right column;
    - a footer.
    """
    # NOTE(review): set_page_config titles the page "Whale Identification"
    # while this heading says 'ActionNet'; confirm which one is intended.
    st.title('ActionNet')
    st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai//?workspace=user-)")
    st.markdown('This project aims to identify whales and dolphins by their unique characteristics. It can help researchers understand their behavior, population dynamics, and migration patterns. This project can aid researchers in identifying these marine mammals, providing valuable data for conservation efforts. [[Source Code]](https://kaggle.com/)')

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    test_images = get_images()
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))

    st.subheader('Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")
    with left_column:
        image = _open_selected_image(uploaded_file, test_images, test_image)
        if image is None:
            st.info('No image available: upload one or add images to ./inputs.')
        else:
            st.image(image, use_column_width=True)
            if st.button('✨ Get prediction from AI', type='primary'):
                # Show the model's real probabilities rather than the random
                # cat/dog placeholder returned by predict().
                _show_prediction(image, right_column)

    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://your-portfolio-website.com/). Data provided by [Kaggle](https://www.kaggle.com/c/)")
# Streamlit re-executes this script top to bottom on every rerun, so the page
# is rendered by calling app() at module level.
# NOTE(review): importing this module from elsewhere would also render the
# page; consider an `if __name__ == "__main__":` guard if that ever happens.
app()