tree-test / app.py
willco-afk's picture
Update app.py
34bd185 verified
raw
history blame
1.22 kB
import streamlit as st
from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
# Load the model and feature extractor
MODEL_ID = "your-username/your-repo-name"  # Replace with your repo


@st.cache_resource
def _load_model_and_extractor(model_id):
    """Load the image classifier and its feature extractor a single time.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, the model would be re-initialised on each
    rerun; ``st.cache_resource`` keeps one shared instance alive instead.

    Args:
        model_id: Hub repository id (or local path) of the classifier.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    loaded_model = TFAutoModelForImageClassification.from_pretrained(model_id)
    # Resolve the extractor from the identifier recorded on the loaded config,
    # exactly as the script did before caching was introduced.
    loaded_extractor = AutoFeatureExtractor.from_pretrained(
        loaded_model.config._name_or_path
    )
    return loaded_model, loaded_extractor


# Module-level names used by the UI code below.
model, feature_extractor = _load_model_and_extractor(MODEL_ID)
# Streamlit UI: upload an image, run it through the classifier, show the label.
st.title("Christmas Tree Classifier")
st.write("Upload an image of a Christmas tree to classify it:")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display the uploaded image exactly as the user provided it.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image.", use_column_width=True)

    # Preprocess the image. PNG uploads may be RGBA, palette or greyscale;
    # normalise to three-channel RGB for the model input only.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="tf")

    # Make prediction. NumPy (already imported) is used for the argmax because
    # this file never imports TensorFlow under the name `tf`.
    logits = model(**inputs).logits
    # Convert to a plain int: a tensor is not a usable key for the label lookup.
    predicted_class_idx = int(np.argmax(logits.numpy(), axis=-1)[0])

    # Map class index to label using the names stored in the model config.
    class_names = model.config.id2label
    predicted_class = class_names[predicted_class_idx]

    # Display the prediction
    st.write(f"Prediction: **{predicted_class}**")