File size: 2,835 Bytes
6f3af4f
 
eb1008c
6f3af4f
e069d9a
6f3af4f
 
f417379
6f3af4f
 
 
 
 
eb1008c
e069d9a
6f3af4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301cd46
 
6f3af4f
 
 
 
 
 
 
 
 
 
 
301cd46
6f3af4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301cd46
6f3af4f
 
 
 
 
f417379
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import torch
import streamlit as st
from transformers import pipeline, AutoImageProcessor
from PIL import Image
from utils import download_model_from_s3

# S3 prefix for each selectable model. The dict's key order is also the order
# in which the models appear in the dropdown below.
model_paths = {
    "TinyBert Sentiment Analysis": "ml-models/tinybert-sentiment-analysis/",
    "TinyBert Disaster Classification": "ml-models/tinybert-disaster-tweet/",
    "VIT Pose Classification": "ml-models/vit-human-pose-classification/",
}


st.title("Machine Learning Model Deployment")

# Build the options from ``model_paths`` so the dropdown and the S3 mapping
# cannot drift apart. An option missing from the dict would otherwise raise
# KeyError at the ``model_paths[model_choice]`` lookup below.
model_choice = st.selectbox("Select Model:", list(model_paths))

# Local path derived from the display name, e.g. "tinybert-sentiment-analysis".
# The same path is handed to the S3 download helper and later to pipeline().
local_path = model_choice.lower().replace(" ", "-")
s3_prefix = model_paths[model_choice]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Names of the models downloaded during this browser session. The set lives in
# st.session_state so it survives the reruns Streamlit performs on each interaction.
if "downloaded_models" not in st.session_state:
    st.session_state.downloaded_models = set()


# Offer a download button only for models not yet fetched in this session.
if model_choice not in st.session_state.downloaded_models:
    if st.button(f"Download {model_choice}"):
        with st.spinner(f"Downloading {model_choice}... Please wait!"):
            download_model_from_s3(local_path, s3_prefix)
        st.session_state.downloaded_models.add(model_choice)
        st.toast(f"✅ {model_choice} successfully downloaded!", icon="🎉")

@st.cache_resource
def _load_text_classifier(model_dir):
    """Return a text-classification pipeline loaded from ``model_dir``.

    Cached with ``st.cache_resource`` (keyed on ``model_dir``). The weights are
    therefore loaded once, not on every script rerun Streamlit triggers.
    """
    return pipeline("text-classification", model=model_dir, device=device)


@st.cache_resource
def _load_image_classifier(model_dir):
    """Return an image-classification pipeline (with fast image processor) loaded from ``model_dir``.

    Cached with ``st.cache_resource`` (keyed on ``model_dir``) so the processor
    and weights are loaded once rather than on every rerun.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_dir, use_fast=True)
    return pipeline(
        "image-classification",
        model=model_dir,
        image_processor=image_processor,
        device=device,
    )


def _model_ready():
    """Return True when the selected model can be loaded from ``local_path``.

    A model counts as ready if it was downloaded earlier in this session or its
    directory already exists on disk (e.g. left over from a previous session).
    Without this check, pipeline() would be pointed at a missing local path
    before the user has had a chance to click the download button.
    """
    return (
        model_choice in st.session_state.downloaded_models
        or os.path.isdir(local_path)
    )


def _run_text_classifier(default_text, button_label):
    """Render a text box plus a predict button, and classify the text on click.

    Args:
        default_text: Text pre-filled in the input box.
        button_label: Label of the predict button.
    """
    text = st.text_area("Enter Text:", default_text)
    if not st.button(button_label):
        return
    if not _model_ready():
        st.warning(f"Download {model_choice} before predicting.")
        return
    with st.spinner("Predicting..."):
        # Load lazily: the pipeline is only built once a prediction is requested.
        output = _load_text_classifier(local_path)(text)
        st.write(output)


# **1. Sentiment Analysis Model**
if model_choice == "TinyBert Sentiment Analysis":
    _run_text_classifier(
        "This movie was horrible, the plot was really boring. acting was okay",
        "Predict Sentiment",
    )

# **2. Disaster Classification**
if model_choice == "TinyBert Disaster Classification":
    _run_text_classifier("There is a fire in the building", "Predict Disaster")

# **3. Image Classification**
if model_choice == "VIT Pose Classification":
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
    predict = st.button("Predict Image")

    image = None
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        # NOTE(review): newer Streamlit releases deprecate ``use_column_width``
        # in favour of ``use_container_width`` — confirm against the pinned version.
        st.image(image, caption="Your Image", use_column_width=True)

    if predict:
        if image is None:
            st.warning("Upload an image before predicting.")
        elif not _model_ready():
            st.warning(f"Download {model_choice} before predicting.")
        else:
            with st.spinner("Predicting..."):
                output = _load_image_classifier(local_path)(image)
                st.write(output)