willco-afk commited on
Commit
34bd185
·
verified ·
1 Parent(s): 6568577

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -22
app.py CHANGED
@@ -1,20 +1,11 @@
1
  import streamlit as st
2
- import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
- import zipfile
6
- import os
7
 
8
- # Function to load the model from the zip file
9
- def load_model_from_zip(zip_file_path):
10
- with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
11
- zip_ref.extractall('.') # Extract to the current directory
12
- # Load the SavedModel directly from the current directory (root)
13
- model = tf.keras.models.load_model('.')
14
- return model
15
-
16
- # Load the model
17
- model = load_model_from_zip('my_christmas_tree_model.zip')
18
 
19
  # Streamlit UI
20
  st.title("Christmas Tree Classifier")
@@ -28,17 +19,15 @@ if uploaded_file is not None:
28
  st.image(image, caption="Uploaded Image.", use_column_width=True)
29
 
30
  # Preprocess the image
31
- image = image.resize((150, 150)) # Resize to match your model's input shape
32
- image_array = np.array(image) / 255.0 # Normalize
33
- image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
34
 
35
  # Make prediction
36
- prediction = model.predict(image_array)
37
-
38
- # Interpret prediction (assuming binary classification)
39
- class_names = ['Undecorated', 'Decorated'] # Update with your actual class names
40
- predicted_class_index = 1 if prediction[0][0] >= 0.5 else 0 # Adjust threshold if needed
41
- predicted_class = class_names[predicted_class_index]
42
 
43
  # Display the prediction
44
  st.write(f"Prediction: **{predicted_class}**")
 
1
  import streamlit as st
2
+ from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor
3
  from PIL import Image
4
  import numpy as np
 
 
5
 
6
+ # Load the model and feature extractor
7
+ model = TFAutoModelForImageClassification.from_pretrained("your-username/your-repo-name") # Replace with your repo
8
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model.config._name_or_path)
 
 
 
 
 
 
 
9
 
10
  # Streamlit UI
11
  st.title("Christmas Tree Classifier")
 
19
  st.image(image, caption="Uploaded Image.", use_column_width=True)
20
 
21
  # Preprocess the image
22
+ inputs = feature_extractor(images=image, return_tensors="tf")
 
 
23
 
24
  # Make prediction
25
+ logits = model(**inputs).logits
26
+ predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
27
+
28
+ # Map class index to label
29
+ class_names = model.config.id2label # Get class names from model config
30
+ predicted_class = class_names[predicted_class_idx]
31
 
32
  # Display the prediction
33
  st.write(f"Prediction: **{predicted_class}**")