tree-test / app.py
willco-afk's picture
Update app.py
059669c verified
raw
history blame
1.23 kB
import os
import streamlit as st
from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
# Load the model and feature extractor.
# The model id is resolved once (MODEL_ID env var, with a default) and used
# for both loads, rather than reading it back from a private config attribute.
MODEL_ID = os.environ.get("MODEL_ID", "willco-afk/tree-test-x")


@st.cache_resource
def _load_model_and_extractor(model_id):
    """Return the ``(model, feature_extractor)`` pair for *model_id*.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes this
    script on every user interaction; without caching, the weights would be
    loaded again on each rerun.
    """
    loaded_model = TFAutoModelForImageClassification.from_pretrained(model_id)
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    return loaded_model, extractor


model, feature_extractor = _load_model_and_extractor(MODEL_ID)
# Streamlit UI: let the user upload an image, run it through the classifier,
# and show the predicted label.
st.title("Christmas Tree Classifier")
st.write("Upload an image of a Christmas tree to classify it:")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display the uploaded image exactly as provided.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image.", use_column_width=True)

    # Preprocess the image.
    # PNG uploads can be RGBA, palette, or grayscale; the model is fed an RGB
    # copy so the channel count is always 3.
    # NOTE(review): assumes the extractor expects 3-channel input - confirm
    # against the model's preprocessing config.
    rgb_image = image.convert("RGB")
    inputs = feature_extractor(images=rgb_image, return_tensors="tf")

    # Make prediction.
    logits = model(**inputs).logits

    # `tf` is not imported in this file, so the argmax is taken with the
    # already-imported numpy. int() turns the result into a plain Python int
    # so it can be used as a key into the id2label mapping below.
    predicted_class_idx = int(np.argmax(logits.numpy(), axis=-1)[0])

    # Map class index to label using the class names from the model config.
    class_names = model.config.id2label
    predicted_class = class_names[predicted_class_idx]

    # Display the prediction.
    st.write(f"Prediction: **{predicted_class}**")