im2 committed
Commit 933911c • 1 Parent(s): 41a6271
improved
app.py CHANGED
@@ -1,8 +1,12 @@
 import gradio as gr
+from gradio import Brush
 import torch
 import numpy as np
-[removed line not recoverable from this view]
+import cv2  # For Gaussian blur
 from PIL import Image
+from torch import nn, save, load
+from torchvision.transforms import Compose, ToTensor, Normalize, Resize
+

 # Load the model using PyTorch
 model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
@@ -29,44 +33,72 @@ class ImageClassifier(torch.nn.Module):
 # Instantiate the model and load weights
 model = ImageClassifier()
 model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
-model.eval()
+# model.eval()

-#
-[removed line not recoverable from this view]
-    # Extract the 'composite' key from the sketchpad input dictionary
-    imArray = image['composite']  # 'composite' contains the drawn image
-    return imArray
+# with open('pytorch_model.bin', 'rb') as f:
+#     model.load_state_dict(load(f))

 # Gradio preprocessing and prediction pipeline
 def predict_digit(image):
-    #
-[7 removed lines not recoverable from this view]
+    # Extract the 'composite' key, which contains the drawn image
+    if isinstance(image, dict):
+        image = image.get('composite', None)  # Use the composite image
+
+    if image is None:
+        raise ValueError("No image data found in the input!")
+
+    #print("Unique pixel values in the image array:", np.unique(image))
+
+    # Ensure the input is a numpy array
+    image = np.array(image, dtype=np.uint8)
+
+
+    # Apply Gaussian blur to reduce noise
+    image = cv2.GaussianBlur(image, (5, 5), 0)
+
+    # If the image has multiple channels (e.g., BGR), convert it to grayscale
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    image = cv2.resize(image, (28, 28))
+    # Optional: Apply adaptive histogram equalization to improve contrast
+    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+    image = clahe.apply(image)
+
+
+
+    # Convert the numpy array back to a PIL Image for torchvision compatibility
+    img_pil = Image.fromarray(image)
+    img_pil.show()
+
+    transform = Compose([
+        Resize((28, 28)),
+        ToTensor(),
+        Normalize((0.5,), (0.5,))
     ])
-[removed line not recoverable from this view]
-    img_tensor = transform(
+
+    img_tensor = transform(img_pil).unsqueeze(0)  # Add batch dimension
+
+    # Debugging: Print tensor shape and some pixel values
+    print(f"Input Tensor Shape: {img_tensor.shape}")
+    print(f"First 5 pixels of the tensor: {img_tensor[0, 0, :5, :5]}")

     # Pass through the model
     with torch.no_grad():
         output = model(img_tensor)
-[removed line not recoverable from this view]
+        prediction = torch.argmax(output)

-    return f"Predicted Label: {
+    return f"Predicted Label: {prediction}"

-# Create Gradio Interface
-[5 removed lines not recoverable from this view]
-    description="Draw a digit (0-9) and the model will predict the number!"
-)
+# Create Gradio Interface using ImageEditor
+with gr.Blocks() as demo:
+    with gr.Row():
+        im = gr.ImageEditor(type="numpy", crop_size="1:1")
+        # im_preview = gr.Image()
+        prediction_box = gr.Textbox(label="Predicted Digit")

+    im.change(predict_digit, outputs=prediction_box, inputs=im, show_progress="hidden")
+
+    #im.change(predict_digit, outputs="text", inputs=im, show_progress=True)
+
 # Launch the app
 if __name__ == "__main__":
-[removed line not recoverable from this view]
+    demo.launch()
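
A note on the torchvision transform that the new predict_digit builds: ToTensor() scales the 8-bit grayscale sketch to [0, 1], and Normalize((0.5,), (0.5,)) then shifts that range to [-1, 1]. The short sketch below is not part of the commit; it only reuses the transform parameters shown in the diff to illustrate the resulting value range and tensor shape.

# Sketch only: value range produced by the transform used in predict_digit above.
# ToTensor() maps uint8 [0, 255] to float [0.0, 1.0]; Normalize((0.5,), (0.5,))
# then applies (x - 0.5) / 0.5, giving [-1.0, 1.0].
import numpy as np
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

transform = Compose([Resize((28, 28)), ToTensor(), Normalize((0.5,), (0.5,))])

white = Image.fromarray(np.full((28, 28), 255, dtype=np.uint8))  # all-white patch
black = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))      # all-black patch

print(transform(white).max().item())  # 1.0
print(transform(black).min().item())  # -1.0
print(transform(white).shape)         # torch.Size([1, 28, 28]); .unsqueeze(0) adds the batch dim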
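
For local testing of the new ImageEditor input contract, a hedged smoke test is sketched below. It assumes the Space's app.py is importable as app (importing it downloads the model weights from the Hub); the synthetic stroke and the printed label are illustrative only, and predict_digit() also calls img_pil.show(), which may open an image viewer.

# Sketch only: feed predict_digit() a payload shaped like gr.ImageEditor(type="numpy") output,
# i.e. a dict whose 'composite' key holds the merged drawing as an RGB array.
import numpy as np
from app import predict_digit  # assumed import path for this Space's app.py

canvas = np.zeros((300, 300, 3), dtype=np.uint8)  # blank RGB canvas
canvas[40:260, 140:165, :] = 255                  # rough vertical stroke, like a drawn "1"

payload = {"composite": canvas, "background": None, "layers": []}
print(predict_digit(payload))  # e.g. "Predicted Label: 1" - actual output depends on the model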