Jugal-sheth commited on
Commit
bf2bb8a
·
1 Parent(s): a4c6f44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -1,24 +1,24 @@
1
  import gradio as gr
2
- import cv2
3
  import torch
4
- from model import model
 
5
  from torchvision import transforms
6
 
7
- #Load the Model
8
  model.load_state_dict(torch.load('mnist_model.pth'))
9
- # Set model to evaluation mode
10
  model.eval()
11
 
12
-
13
  def preprocess_image(image):
14
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
15
- threshold = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
16
- resized = cv2.resize(threshold, (28, 28), interpolation=cv2.INTER_AREA)
17
- tensor = transforms.ToTensor()(resized).unsqueeze(0)
18
- tensor = transforms.Normalize((0.5,), (0.5,))(tensor)
 
 
 
19
  return tensor
20
 
21
-
22
  def classify(image):
23
  tensor = preprocess_image(image)
24
  with torch.no_grad():
@@ -26,7 +26,6 @@ def classify(image):
26
  prediction = output.argmax(dim=1, keepdim=True).item()
27
  return str(prediction) # Convert prediction to string
28
 
29
-
30
  iface = gr.Interface(
31
  fn=classify,
32
  inputs="sketchpad",
 
1
  import gradio as gr
 
2
  import torch
3
+ from PIL import Image
4
+ from model import model
5
  from torchvision import transforms
6
 
7
+ # Load your own model
8
  model.load_state_dict(torch.load('mnist_model.pth'))
 
9
  model.eval()
10
 
 
11
  def preprocess_image(image):
12
+ transform = transforms.Compose([
13
+ transforms.Resize((28, 28)),
14
+ transforms.Grayscale(num_output_channels=1),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize((0.5,), (0.5,))
17
+ ])
18
+ image = Image.fromarray(image)
19
+ tensor = transform(image).unsqueeze(0)
20
  return tensor
21
 
 
22
  def classify(image):
23
  tensor = preprocess_image(image)
24
  with torch.no_grad():
 
26
  prediction = output.argmax(dim=1, keepdim=True).item()
27
  return str(prediction) # Convert prediction to string
28
 
 
29
  iface = gr.Interface(
30
  fn=classify,
31
  inputs="sketchpad",