Update README.md
Browse files
README.md
CHANGED
@@ -5,4 +5,102 @@ language:
|
|
5 |
metrics:
|
6 |
- accuracy
|
7 |
pipeline_tag: image-classification
|
8 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
metrics:
|
6 |
- accuracy
|
7 |
pipeline_tag: image-classification
|
8 |
+
---
|
9 |
+
|
10 |
+
# Mushroom Classification Model - JarviSpore
|
11 |
+
|
12 |
+
This repository contains **JarviSpore**, a mushroom image classification model trained on a multi-class dataset with 23 different types of mushrooms. Developed from scratch with TensorFlow and Keras, this model aims to provide accurate mushroom identification using advanced deep learning techniques, including *Grad-CAM* for interpreting predictions. This project explores the performance of from-scratch models compared to transfer learning.
|
13 |
+
|
14 |
+
## Model Details
|
15 |
+
|
16 |
+
- **Architecture**: Custom CNN (Convolutional Neural Network)
|
17 |
+
- **Number of Classes**: 23 mushroom classes
|
18 |
+
- **Input Format**: RGB images resized to 224x224 pixels
|
19 |
+
- **Framework**: TensorFlow & Keras
|
20 |
+
- **Training**: Conducted on a machine with an i9 14900k processor, 192GB RAM, and an RTX 3090 GPU
|
21 |
+
|
22 |
+
## Key Features
|
23 |
+
|
24 |
+
1. **Multi-Class Classification**: The model can predict among 23 mushroom species.
|
25 |
+
2. **Regularization**: Includes L2 regularization and Dropout to prevent overfitting.
|
26 |
+
3. **Class Weighting**: Manages dataset imbalances by applying specific weights for each class.
|
27 |
+
4. **Grad-CAM Visualization**: Utilizes Grad-CAM to generate heatmaps, allowing visualization of the regions influencing the model's predictions.
|
28 |
+
|
29 |
+
## Model Training
|
30 |
+
|
31 |
+
The model was trained using a structured dataset directory with data split as follows:
|
32 |
+
- train: Balanced training dataset
|
33 |
+
- validation: Validation set to monitor performance
|
34 |
+
- test: Test set to evaluate final accuracy
|
35 |
+
|
36 |
+
Main training hyperparameters include:
|
37 |
+
- **Batch Size**: 32
|
38 |
+
- **Epochs**: 20 with Early Stopping
|
39 |
+
- **Learning Rate**: 0.0001
|
40 |
+
|
41 |
+
Training was tracked and logged via MLflow, including accuracy and loss curves, as well as the best model weights saved automatically.
|
42 |
+
|
43 |
+
## Model Usage
|
44 |
+
|
45 |
+
### Prerequisites
|
46 |
+
|
47 |
+
Ensure the following libraries are installed:
|
48 |
+
bash
|
49 |
+
pip install tensorflow pillow matplotlib numpy
|
50 |
+
|
51 |
+
|
52 |
+
### Loading the Model
|
53 |
+
|
54 |
+
To load and use the model for predictions:
|
55 |
+
|
56 |
+
python
|
57 |
+
import tensorflow as tf
|
58 |
+
from PIL import Image
|
59 |
+
import numpy as np
|
60 |
+
|
61 |
+
# Load the model
|
62 |
+
model = tf.keras.models.load_model("path_to_model.h5")
|
63 |
+
|
64 |
+
# Prepare an image for prediction
|
65 |
+
def prepare_image(image_path):
|
66 |
+
img = Image.open(image_path).convert("RGB")
|
67 |
+
img = img.resize((224, 224))
|
68 |
+
img_array = tf.keras.preprocessing.image.img_to_array(img)
|
69 |
+
img_array = np.expand_dims(img_array, axis=0)
|
70 |
+
return img_array
|
71 |
+
|
72 |
+
# Prediction
|
73 |
+
image_path = "path_to_image.jpg"
|
74 |
+
img_array = prepare_image(image_path)
|
75 |
+
predictions = model.predict(img_array)
|
76 |
+
predicted_class = np.argmax(predictions[0])
|
77 |
+
|
78 |
+
print(f"Predicted Class: {predicted_class}")
|
79 |
+
|
80 |
+
|
81 |
+
### Grad-CAM Visualization
|
82 |
+
|
83 |
+
The integrated *Grad-CAM* functionality allows interpretation of the model's predictions. To use it, select an image and apply the Grad-CAM function to display the heatmap overlaid on the original image, highlighting areas influencing the model.
|
84 |
+
|
85 |
+
Grad-CAM example usage:
|
86 |
+
|
87 |
+
python
|
88 |
+
# Example usage of the make_gradcam_heatmap function
|
89 |
+
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name="last_conv_layer_name")
|
90 |
+
|
91 |
+
# Superimpose the heatmap on the original image
|
92 |
+
superimposed_img = superimpose_heatmap(Image.open(image_path), heatmap)
|
93 |
+
superimposed_img.show()
|
94 |
+
|
95 |
+
|
96 |
+
## Evaluation
|
97 |
+
|
98 |
+
The model was evaluated on the test set with an average accuracy above random chance, showing promising results for a first from-scratch version.
|
99 |
+
|
100 |
+
## Contributing
|
101 |
+
|
102 |
+
Contributions to improve accuracy or add new features (e.g., other visualization techniques or advanced optimization) are welcome. Please submit a pull request with relevant modifications.
|
103 |
+
|
104 |
+
## License
|
105 |
+
|
106 |
+
This model is licensed under a controlled license: please refer to the LICENSE file for details. You may use this model for personal projects, but any modifications or redistribution must be approved.
|