willco-afk commited on
Commit
dae7f18
·
verified ·
1 Parent(s): 7db490a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -1,21 +1,11 @@
1
  import os
2
  import streamlit as st
3
- from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
4
  from PIL import Image
5
  import numpy as np
6
- from huggingface_hub import login
7
-
8
-
9
- # Authenticate with Hugging Face token (if available)
10
- hf_token = os.environ.get("HF_TOKEN")
11
- if hf_token:
12
- login(token=hf_token)
13
-
14
-
15
- # Load the model and feature extractor
16
- model = TFAutoModelForImageClassification.from_pretrained(os.environ.get("MODEL_ID", "willco-afk/tree-test-x"))
17
- feature_extractor = AutoFeatureExtractor.from_pretrained(model.config._name_or_path)
18
 
 
 
19
 
20
  # Streamlit UI
21
  st.title("Christmas Tree Classifier")
@@ -27,17 +17,19 @@ if uploaded_file is not None:
27
  # Display the uploaded image
28
  image = Image.open(uploaded_file)
29
  st.image(image, caption="Uploaded Image.", use_column_width=True)
 
 
30
 
31
  # Preprocess the image
32
- inputs = feature_extractor(images=image, return_tensors="tf")
 
 
33
 
34
  # Make prediction
35
- logits = model(**inputs).logits
36
- predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
37
 
38
- # Map class index to label
39
- class_names = model.config.id2label # Get class names from model config
40
- predicted_class = class_names[predicted_class_idx]
41
 
42
  # Display the prediction
43
- st.write(f"Prediction: **{predicted_class}**")
 
1
  import os
2
  import streamlit as st
3
+ import tensorflow as tf
4
  from PIL import Image
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # Load your Keras model from Google Drive
8
+ model = tf.keras.models.load_model('/content/drive/MyDrive/your_trained_model.keras')
9
 
10
  # Streamlit UI
11
  st.title("Christmas Tree Classifier")
 
17
  # Display the uploaded image
18
  image = Image.open(uploaded_file)
19
  st.image(image, caption="Uploaded Image.", use_column_width=True)
20
+ st.write("")
21
+ st.write("Classifying...")
22
 
23
  # Preprocess the image
24
+ image = image.resize((224, 224)) # Resize to match your model's input size
25
+ image_array = np.array(image) / 255.0 # Normalize pixel values
26
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
27
 
28
  # Make prediction
29
+ prediction = model.predict(image_array)
 
30
 
31
+ # Get predicted class
32
+ predicted_class = "Decorated" if prediction[0][0] >= 0.5 else "Undecorated"
 
33
 
34
  # Display the prediction
35
+ st.write(f"Prediction: {predicted_class}")