Spaces:
Sleeping
Sleeping
File size: 2,982 Bytes
e3e1e9b 3864676 e3e1e9b 4cfbed9 e3e1e9b 7aebb3f f143c99 80da5e3 939fac3 3138e19 939fac3 80da5e3 939fac3 80da5e3 f143c99 2202f73 7aebb3f 0c9e4f4 e3e1e9b 2202f73 e3e1e9b 2202f73 e3e1e9b 3864676 e3e1e9b 0a3512b e3e1e9b 3be0091 7b69f6b e3e1e9b 3be0091 7b69f6b 3138e19 939fac3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
from transformers import AutoModel
import torch.nn as nn
from PIL import Image
import numpy as np
import streamlit as st
# Run on the GPU when one is available; the model and every input tensor are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model():
    """Load the trained model from the Hugging Face Hub and move it to ``device``.

    Streamlit re-executes this whole script on every widget interaction, so the
    load is cached with ``st.cache_resource``: the model is built once per server
    process instead of on every rerun. The returned object is shared across
    reruns and sessions, so it should not be mutated.
    """
    return AutoModel.from_pretrained('dhhd255/parkinsons_20epochs').to(device)


model = _load_model()
# Inject page-wide CSS:
# - the Inter font;
# - bold, coloured classes for the healthy / Parkinson's result labels;
# - footer caption layout and colours;
# - light-blue button hover/focus states.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
body {
font-family: 'Inter', sans-serif;
}
.result {
font-size: 24px;
font-weight: bold;
}
.healthy {
color: #007E3F;
}
.parkinsons {
color: #C30000;
}
.caption_c{
position: relative;
display: flex;
flex-direction: column;
align-items: center;
top: calc(99vh - 370px);
}
.caption {
text-align: center;
color: #646464;
font-size: 14px;
}
button:hover {
background-color: lightblue !important;
outline-color: lightblue !important;
}
button:focus {
background-color: lightblue !important;
outline-color: lightblue !important;
}
</style>
""", unsafe_allow_html=True)
# Page heading and the spiral-drawing uploader; prediction only runs once a file is uploaded.
st.title("Parkinson's Disease Prediction")
uploaded_file = st.file_uploader("Upload your :blue[Spiral] drawing here", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # The model is fed a fixed 224x224 RGB image.
    image_size = (224, 224)
    try:
        # Decode the upload, convert to 3-channel RGB, then resize.
        new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for files it cannot
        # decode, and OSError for truncated ones; report it instead of showing a traceback.
        st.error("Could not read the uploaded file as an image. Please upload a valid PNG or JPEG.")
    else:
        col1, col2 = st.columns(2)
        col1.image(new_image, use_column_width=True)

        # PIL -> numpy gives an (H, W, 3) uint8 array; transpose(0, 2) turns it into (3, W, H),
        # and unsqueeze(0) adds the batch dimension.
        # NOTE(review): transpose(0, 2) also swaps height and width; permute(2, 0, 1) would give
        # the conventional (C, H, W) layout. Pixel values are also left un-normalised (0-255).
        # Both are kept as-is because the training-time preprocessing is not visible here.
        # Confirm before changing.
        input_tensor = torch.from_numpy(np.array(new_image)).transpose(0, 2).float().unsqueeze(0)
        input_tensor = input_tensor.to(device)

        num_classes = 2
        # Keep the whole forward pass, projection head included, out of autograd.
        with torch.no_grad():
            hidden = model(input_tensor).last_hidden_state
            # Flatten everything except the batch dimension.
            features = hidden.flatten(start_dim=1)
            # NOTE(review): this nn.Linear gets fresh random weights on every pass through this
            # branch, i.e. on every Streamlit rerun with a file uploaded. The class and confidence
            # below therefore come from an untrained layer and can differ between reruns for the
            # same image.
            # TODO: load the trained classification head instead of creating one here.
            feature_reducer = nn.Linear(features.shape[1], num_classes).to(device)
            logits = feature_reducer(features)
            probabilities = torch.softmax(logits, dim=1)

        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = probabilities[0][predicted_class].item()

        # Class index 0 is labelled Parkinson's; any other index is labelled Healthy.
        if predicted_class == 0:
            col2.markdown('<span class="result parkinsons">Predicted class: Parkinson\'s</span>', unsafe_allow_html=True)
        else:
            col2.markdown('<span class="result healthy">Predicted class: Healthy</span>', unsafe_allow_html=True)
        col2.caption(f'{confidence*100:.0f}% sure')
# Footer credit line, laid out by the caption_c / caption CSS classes.
footer_html = (
    '<div class="caption_c">'
    '<p class="caption">Made with love by Jayant</p>'
    '</div>'
)
st.markdown(footer_html, unsafe_allow_html=True)
|