Spaces:
Running
Running
File size: 4,294 Bytes
0e7de7e 1556762 845eb37 1556762 6197f1f 1556762 0e7de7e 1556762 0e7de7e 6197f1f 0e7de7e 1556762 0e7de7e 1556762 845eb37 1556762 0e7de7e 1556762 845eb37 1556762 0e7de7e 1556762 0e7de7e 1556762 845eb37 6197f1f 845eb37 1556762 845eb37 6197f1f 845eb37 1556762 845eb37 1556762 845eb37 0e7de7e 1556762 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import streamlit as st
import numpy as np
from PIL import Image
import requests
import ModelClass
from glob import glob
import torch
import torch.nn as nn
import numpy as np
@st.cache_resource
def load_model():
    """Return the ActionNet model, building it only on the first call.

    ``st.cache_resource`` memoizes the returned object, so later calls (and
    script reruns) reuse the same model instead of calling
    ``ModelClass.get_model()`` again.
    """
    model = ModelClass.get_model()
    return model
@st.cache_data
def get_images():
    """Map each bundled test image's file name to its path under ``./inputs``.

    Returns:
        dict[str, str]: ``{file_name: path}``, ordered by path so the
        selectbox lists images in a stable order (``glob`` itself returns
        entries in arbitrary order).
    """
    import os  # only needed here

    paths = sorted(glob('./inputs/*'))
    # os.path.basename understands the platform separator; splitting on '/'
    # returns 'inputs\\name.jpg' on Windows, where glob yields backslashes.
    return {os.path.basename(path): path for path in paths}
def infer(img):
    """Run the cached model on a single PIL image and return class probabilities.

    Args:
        img: PIL image; converted to RGB before ``ModelClass.get_transform()``
            is applied.

    Returns:
        torch.Tensor: 1-D tensor of softmax probabilities, one entry per class
        (assumes the model returns logits shaped ``(1, num_classes)`` — TODO confirm).
    """
    image = ModelClass.get_transform()(img.convert('RGB'))
    batch = image.unsqueeze(dim=0)  # add the batch dimension of size 1
    model = load_model()
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        # Normalize over the class axis explicitly: nn.Softmax() without `dim`
        # relies on deprecated implicit-dim inference and emits a UserWarning.
        probs = torch.softmax(logits, dim=1)
    # Drop only the batch axis we added above.
    return probs.squeeze(0)
# Page metadata; Streamlit requires this to run before any command that
# renders output.
st.set_page_config(
    page_title="ActionNet",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items={
        # NOTE(review): these two URLs are the placeholder domain from
        # Streamlit's set_page_config example — TODO point them at real pages.
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        # Markdown shown in the "About" dialog (replaces leftover placeholder text).
        'About': """
# ActionNet
Human Action Recognition using CNN: a ResNet model that classifies
human activities (15 classes) from still images.
""",
    },
)
# Custom CSS overrides for the sidebar and layout: white background, padding
# and margins for the listed elements, and a 58rem max-width container.
# NOTE(review): the ``.css-*`` selectors look like auto-generated class names
# that can change between Streamlit releases — confirm they still match.
_SIDEBAR_CSS = """
<style>
.css-vk3wp9 {
background-color: rgb(255 255 255);
}
.css-18l0hbk {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
}
.css-nziaof {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
background-color: rgb(181 197 227 / 18%) !important;
}
.css-1y4p8pa {
padding: 3rem 1rem 10rem;
max-width: 58rem;
}
</style>
"""
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)
# Hide Streamlit's built-in menu, footer and header by injecting raw CSS.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(
    hide_st_style,
    unsafe_allow_html=True,  # required so the <style> tag is not escaped
)
def predict(image):
    """Return a random placeholder distribution over ``'cat'`` and ``'dog'``.

    ``image`` is ignored; the values are uniform random draws normalized to
    sum to 1. NOTE(review): ``app()`` in this file uses ``infer`` instead, so
    this looks like leftover scaffolding.
    """
    labels = ['cat', 'dog']
    weights = np.random.rand(len(labels))
    probabilities = weights / weights.sum()
    return {label: p for label, p in zip(labels, probabilities)}
def _load_selected_image(uploaded_file, test_images, test_image):
    """Return the image to classify, or None when there is nothing to show.

    An uploaded file takes precedence over the selectbox choice. None is
    returned when nothing was uploaded and ``test_image`` is not a key of
    ``test_images`` (e.g. ``./inputs`` is empty, so the selectbox yields None).
    """
    if uploaded_file is not None:
        return Image.open(uploaded_file)
    path = test_images.get(test_image)
    return None if path is None else Image.open(path)


def _render_predictions(column, prob, top_k=6):
    """Write the ``top_k`` most probable classes into ``column``, highest first.

    Args:
        column: Streamlit container to draw into.
        prob: 1-D array of per-class probabilities indexed by class id.
        top_k: Maximum number of classes to list; slicing naturally clips it
            to ``len(prob)``, so models with fewer classes do not error.
    """
    column.markdown('#### Results')
    # Class ids ordered by descending probability, truncated to top_k.
    ranked = np.argsort(prob)[::-1][:top_k]
    for class_id in ranked:
        class_name = ModelClass.get_class(class_id).replace('_', ' ').capitalize()
        class_probability = float(prob[class_id])
        column.write(f'{class_name}: {class_probability:.2%}')
        column.progress(class_probability)


def app():
    """Render the ActionNet page.

    The user uploads an image or picks one from ``./inputs``; the image is
    shown in the left column and, on button click, ``infer`` runs and the top
    predicted activities are listed with probability bars in the right column.
    """
    st.title('ActionNet')
    st.markdown('Human Action Recognition using CNN: A Computer Vision project that trains a ResNet model to classify human activities. The dataset contains 15 activity classes, and the model predicts the activity from input images.')
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    test_images = get_images()
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
    st.markdown('#### Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")
    with left_column:
        image = _load_selected_image(uploaded_file, test_images, test_image)
        if image is None:
            # No upload and no bundled images: the selectbox returned None.
            st.info('Upload an image or choose a test image to get started.')
        else:
            # NOTE(review): newer Streamlit releases deprecate use_column_width
            # in favour of use_container_width — confirm against the pinned version.
            st.image(image, use_column_width=True)
            if st.button('✨ Get prediction from AI', type='primary'):
                prob = infer(image).numpy()
                _render_predictions(right_column, prob)
    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [aiplanet](https://aiplanet.com/challenges/data-sprint-76-human-activity-recognition/233/overview/about)")
# Streamlit executes this script top to bottom on every rerun; render the page.
app()