im2 commited on
Commit
81a6137
1 Parent(s): ccd5a57

draw and test

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
6
+
7
+ # Load the model and feature extractor from Hugging Face
8
+ model_name = "immartian/improved_digits_recognition"
9
+ model = AutoModelForImageClassification.from_pretrained(model_name)
10
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
+
12
+ # Preprocessing function to transform the drawn image into a format the model can recognize
13
+ def preprocess_image(image):
14
+ # Convert the image into a format suitable for the model
15
+ image = Image.fromarray(image).convert('L') # Convert to grayscale
16
+ image = image.resize((28, 28)) # Resize to 28x28 pixels
17
+ image = image.convert('RGB') # Model expects 3-channel images, so convert to RGB
18
+ inputs = feature_extractor(images=image, return_tensors="pt")
19
+ return inputs['pixel_values']
20
+
21
+ # Prediction function to classify the drawn digit
22
+ def predict_digit(image):
23
+ # Preprocess the input image
24
+ inputs = preprocess_image(image)
25
+
26
+ # Make the prediction
27
+ with torch.no_grad():
28
+ outputs = model(inputs)
29
+ predicted_label = outputs.logits.argmax(-1).item()
30
+
31
+ return f"Predicted Digit: {predicted_label}"
32
+
33
+ # Gradio interface for drawing the digit and displaying the prediction
34
+ demo = gr.Interface(
35
+ fn=predict_digit,
36
+ inputs="sketchpad", # Allow users to draw a digit
37
+ outputs="text",
38
+ title="MNIST Digit Recognition",
39
+ description="Draw a digit (0-9) and let the model recognize it!",
40
+ live=True # The prediction updates while the user draws
41
+ )
42
+
43
+ # Launch the app
44
+ if __name__ == "__main__":
45
+ demo.launch()