# Author: fahmiaziz98 — commit 301cd46 ("first commit").
import os
import torch
import streamlit as st
from transformers import pipeline, AutoImageProcessor
from PIL import Image
from utils import download_model_from_s3
# Registry of deployable models: display name (shown in the selectbox) ->
# S3 prefix handed to download_model_from_s3 when that model is fetched.
model_paths = dict(
    [
        ("TinyBert Sentiment Analysis", "ml-models/tinybert-sentiment-analysis/"),
        ("TinyBert Disaster Classification", "ml-models/tinybert-disaster-tweet/"),
        ("VIT Pose Classification", "ml-models/vit-human-pose-classification/"),
    ]
)
st.title("Machine Learning Model Deployment")
# Offer exactly the models registered in `model_paths` (dicts keep insertion
# order, so the menu order is unchanged). Deriving the options from the
# registry keeps the later `model_paths[model_choice]` lookup from ever
# failing with a KeyError when a model is added or renamed in one place only.
model_choice = st.selectbox("Select Model:", list(model_paths))
# Local directory for the selected model, derived from its display name,
# e.g. "TinyBert Sentiment Analysis" -> "tinybert-sentiment-analysis".
local_path = model_choice.lower().replace(" ", "-")
s3_prefix = model_paths[model_choice]
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Track, per Streamlit session, which models have been downloaded so the
# download button is only offered for models not fetched yet.
if "downloaded_models" not in st.session_state:
    st.session_state.downloaded_models = set()

if model_choice not in st.session_state.downloaded_models:
    if st.button(f"Download {model_choice}"):
        with st.spinner(f"Downloading {model_choice}... Please wait!"):
            download_model_from_s3(local_path, s3_prefix)
        st.session_state.downloaded_models.add(model_choice)
        # `icon` must be a single emoji; the previous text was mis-encoded.
        st.toast(f"✅ {model_choice} successfully downloaded!", icon="🎉")

# Every prediction section below loads its model from `local_path`. Until the
# files exist there, building a pipeline would fail, so halt this run with a
# hint instead of surfacing a traceback.
if not os.path.isdir(local_path):
    st.info(f"Download {model_choice} to start making predictions.")
    st.stop()
# **1. Sentiment Analysis Model**
@st.cache_resource
def _load_sentiment_classifier(model_dir, device_name):
    """Build the sentiment text-classification pipeline once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    building the pipeline inline would reload the model weights on each rerun.

    Args:
        model_dir: Local directory holding the downloaded model.
        device_name: Torch device string such as "cuda" or "cpu" (a plain
            string so Streamlit can hash it as part of the cache key).

    Returns:
        A transformers text-classification pipeline.
    """
    return pipeline("text-classification", model=model_dir, device=device_name)


if model_choice == "TinyBert Sentiment Analysis":
    text = st.text_area(
        "Enter Text:",
        "This movie was horrible, the plot was really boring. acting was okay",
    )
    predict = st.button("Predict Sentiment")
    if predict:
        if not text.strip():
            st.warning("Please enter some text first.")
        else:
            with st.spinner("Predicting..."):
                classifier = _load_sentiment_classifier(local_path, str(device))
                output = classifier(text)
            st.write(output)
# **2. Disaster Classification**
@st.cache_resource
def _load_disaster_classifier(model_dir, device_name):
    """Build the disaster-tweet text-classification pipeline once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    building the pipeline inline would reload the model weights on each rerun.

    Args:
        model_dir: Local directory holding the downloaded model.
        device_name: Torch device string such as "cuda" or "cpu" (a plain
            string so Streamlit can hash it as part of the cache key).

    Returns:
        A transformers text-classification pipeline.
    """
    return pipeline("text-classification", model=model_dir, device=device_name)


if model_choice == "TinyBert Disaster Classification":
    text = st.text_area("Enter Text:", "There is a fire in the building")
    # Label previously said "Predict Sentiment" (copy-pasted from section 1).
    predict = st.button("Predict Disaster")
    if predict:
        if not text.strip():
            st.warning("Please enter some text first.")
        else:
            with st.spinner("Predicting..."):
                classifier = _load_disaster_classifier(local_path, str(device))
                output = classifier(text)
            st.write(output)
# **3. Image Classification**
@st.cache_resource
def _load_pose_classifier(model_dir, device_name):
    """Build the ViT pose image-classification pipeline once per process.

    Loading the image processor and model is expensive, and Streamlit reruns
    the script on every interaction, so the result is cached.

    Args:
        model_dir: Local directory holding the downloaded model and its
            image-processor config.
        device_name: Torch device string such as "cuda" or "cpu".

    Returns:
        A transformers image-classification pipeline.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_dir, use_fast=True)
    return pipeline(
        "image-classification",
        model=model_dir,
        image_processor=image_processor,
        device=device_name,
    )


if model_choice == "VIT Pose Classification":
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
    predict = st.button("Predict Image")
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # in favour of `use_container_width`; confirm against the pinned version.
        st.image(image, caption="Your Image", use_column_width=True)
        if predict:
            with st.spinner("Predicting..."):
                pipe = _load_pose_classifier(local_path, str(device))
                output = pipe(image)
            st.write(output)
    elif predict:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please upload an image first.")