Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import numpy as np | |
import tensorflow as tf | |
from PIL import Image | |
import zipfile | |
import gdown | |
# Local path the zipped model archive is downloaded to.
ZIP_MODEL_PATH = '/app/your_trained_model_resnet50.keras.zip'
# Path the archive is expected to produce once it has been extracted.
UNZIPPED_MODEL_PATH = '/app/your_trained_model_resnet50.keras'
# Google Drive link to the model file
MODEL_URL = 'https://drive.google.com/uc?export=download&id=1-4p6AZBkooWL1rhN9WwIrfd9fJbhzY0e'


@st.cache_resource(show_spinner=False)
def _load_model():
    """Download, unzip and load the Keras model.

    Streamlit re-executes the whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource`` to keep one loaded model per
    server process instead of re-reading it on each rerun. The spinner is
    disabled so that no UI element is emitted before the page is configured.

    Returns:
        The object returned by ``tf.keras.models.load_model``.

    Raises:
        FileNotFoundError: If the download did not produce the archive, or the
            archive did not contain the expected model path.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    if not os.path.exists(UNZIPPED_MODEL_PATH):
        # Only fetch the archive when the extracted model is actually missing.
        if not os.path.exists(ZIP_MODEL_PATH):
            gdown.download(MODEL_URL, ZIP_MODEL_PATH, quiet=False)
            if not os.path.exists(ZIP_MODEL_PATH):
                raise FileNotFoundError(
                    f"Model download failed; {ZIP_MODEL_PATH} was not created"
                )
            print(f"Model downloaded to {ZIP_MODEL_PATH}")

        try:
            with zipfile.ZipFile(ZIP_MODEL_PATH, 'r') as zip_ref:
                zip_ref.extractall(os.path.dirname(UNZIPPED_MODEL_PATH))
        except zipfile.BadZipFile:
            # Drop the corrupt archive so the next run downloads it again
            # instead of failing on the same file forever.
            os.remove(ZIP_MODEL_PATH)
            raise

        if not os.path.exists(UNZIPPED_MODEL_PATH):
            raise FileNotFoundError(
                f"Archive did not contain the expected model at {UNZIPPED_MODEL_PATH}"
            )
        print(f"Model unzipped to {UNZIPPED_MODEL_PATH}")

    loaded = tf.keras.models.load_model(UNZIPPED_MODEL_PATH)
    print("Model loaded successfully!")
    return loaded


# Load the model. This is the top-level boundary of the app, so any failure is
# reported to the user and the script is halted; continuing would only fail
# later with a NameError when a prediction is requested.
try:
    model = _load_model()
except Exception as e:
    print(f"Error loading model: {e}")
    st.error(f"Error loading model: {e}")
    st.stop()
# Define the function to predict decoration
def predict_decoration(image: Image.Image, threshold: float = 0.5) -> str:
    """Classify a tree photo as decorated or undecorated.

    Args:
        image: The picture to classify. Any PIL mode is accepted; it is
            converted to 3-channel RGB before being fed to the model.
        threshold: Score above which the image is labelled "Decorated".

    Returns:
        "Decorated" if the model's score exceeds ``threshold``, otherwise
        "Undecorated".
    """
    # PNG uploads are often RGBA or palette-based and some JPEGs are
    # greyscale; without this conversion the array would not have the three
    # channels the 224x224 model input is resized for.
    rgb_image = image.convert("RGB").resize((224, 224))

    # Scale pixel values to [0, 1] and add the leading batch dimension.
    pixels = np.asarray(rgb_image, dtype=np.float32) / 255.0
    batch = np.expand_dims(pixels, axis=0)

    prediction = model.predict(batch)
    # Assumes the model emits a single score per image — TODO confirm against
    # the trained model's output layer. Flattening makes the comparison a
    # plain scalar one regardless of whether the output is (1,) or (1, 1).
    score = float(np.ravel(prediction)[0])
    return "Decorated" if score > threshold else "Undecorated"
# Christmas-themed stylesheet injected into the page as raw HTML.
# NOTE(review): `.css-18e3th9` and `.css-1lcbm2e` are presumably
# Streamlit-generated class names tied to a specific Streamlit version —
# verify they still match anything in the deployed version.
_CHRISTMAS_CSS = """
<style>
body {
background-color: #fae1dc; /* Soft pink background */
color: #1b5e20; /* Deep green text */
font-family: 'Comic Sans MS', cursive, sans-serif;
}
.css-18e3th9 {
background-color: #d32f2f; /* Christmas red button */
color: white;
}
.css-1lcbm2e {
background-color: #388e3c; /* Christmas green button */
color: white;
}
.stButton>button {
background-color: #f44336; /* Red button color */
color: white;
border-radius: 12px;
padding: 10px;
font-size: 16px;
}
.stButton>button:hover {
background-color: #c62828; /* Darker red on hover */
}
.stMarkdown {
font-size: 18px;
}
.stTab {
font-size: 20px;
font-weight: bold;
color: #388e3c; /* Christmas green */
}
.stImage {
border: 2px solid #388e3c; /* Green border around images */
}
</style>
"""

# Browser-tab title and icon for the app.
# NOTE(review): "π" looks like a mis-decoded emoji (possibly a Christmas
# tree) — confirm the intended character and restore it.
st.set_page_config(
    page_title="Tree Decoration Predictor",
    page_icon="π",
)

# Apply the theme; raw HTML must be explicitly allowed for the <style> tag.
st.markdown(_CHRISTMAS_CSS, unsafe_allow_html=True)
# Title of the page
# NOTE(review): "π" looks like a mis-decoded emoji (possibly a Christmas
# tree) — confirm the intended character and restore it.
st.title("π Tree Decoration Predictor π")
# Create tabs for better organization
tab1, tab2 = st.tabs(["Upload Image", "Tree Image URLs"])
# Upload Image Tab: let the user pick a picture, preview it, and classify it
# on demand.
with tab1:
    uploaded_image = st.file_uploader("Upload an image of a tree", type=["jpg", "jpeg", "png"])
    if uploaded_image is not None:
        try:
            image = Image.open(uploaded_image)
            # Image.open is lazy; force the decode now so a truncated or
            # corrupt file is reported here rather than during display or
            # prediction.
            image.load()
        except OSError as e:
            # Pillow's "cannot identify image" error is an OSError subclass,
            # so this covers both unrecognised and truncated uploads.
            st.error(f"Could not read the uploaded image: {e}")
        else:
            st.image(image, caption="Uploaded Tree Image", use_container_width=True)
            if st.button("Predict Decoration"):
                prediction = predict_decoration(image)
                st.write(f"Prediction: {prediction}")
# Tree Image URLs Tab: static links to sample pictures for trying the model.
with tab2:
    # NOTE(review): "π" looks like a mis-decoded emoji (possibly a Christmas
    # tree) — confirm the intended character and restore it.
    st.subheader("π Tree Image Samples π")

    # Links for viewing the samples online and for downloading them.
    samples_markdown = """
View some of my decorated and undecorated tree samples for the Model here:
[View Trees](https://www.dropbox.com/scl/fo/cuzo12z39cxv6joz7gz2o/ACf5xSjT7nHqMRdgh21GYlc?raw=1)
Download the tree samples pictures to test them on the model yourself here:
[Download Trees](https://www.dropbox.com/scl/fo/cuzo12z39cxv6joz7gz2o/ACf5xSjT7nHqMRdgh21GYlc?raw=1&dl=1)
"""
    st.markdown(samples_markdown)

    # Plain-text list of the sample image URLs, hosted on GitHub.
    image_list_link = "[Download the image list](https://raw.githubusercontent.com/willco-afk/tree-samples/main/tree_images.txt)"
    st.markdown(image_list_link)