Henry Du commited on
Commit
d85c7f7
·
1 Parent(s): d2a009f

Update files

Browse files
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+ from torch import nn
6
+ from typing import Dict
7
+ import torchvision
8
+ from torchvision import transforms
9
+ import PIL
10
+ import numpy
11
+
12
+ # import os
13
+ # os.system("pip uninstall -y gradio")
14
+ # os.system("pip install gradio==3.50.2")
15
+
16
+ with open("class_names.txt", "r") as f: # reading them in from class_names.txt
17
+ class_names = [food_name.strip() for food_name in f.readlines()]
18
+
19
+
20
+ mobileNetV3_transform = torchvision.models.MobileNet_V3_Large_Weights.DEFAULT.transforms()
21
+
22
+ mobileNetV3 = torchvision.models.mobilenet_v3_large()
23
+
24
+ mobileNetV3.classifier = nn.Sequential(
25
+ nn.Linear(in_features=960, out_features=1280, bias=True),
26
+ nn.Hardswish(),
27
+ nn.Dropout(p=0.2, inplace=True),
28
+ nn.Linear(in_features=1280, out_features=len(class_names), bias=True)
29
+ )
30
+
31
+ mobileNetV3.load_state_dict(torch.load("models/mobileNetV3_quickdraw_animals_epoch_80_Adam.pth",map_location=torch.device('cpu')))
32
+
33
+
34
+
35
+ def convert_img_to_tensor(img):
36
+ convert_img = torch.from_numpy(img)
37
+ convert_img = convert_img.repeat(3, 1, 1)
38
+ return convert_img
39
+
40
+
41
+
42
+ def predict(img) -> Dict:
43
+ """Transforms and performs a prediction on img and returns prediction and time taken.
44
+ """
45
+ # Start the timer
46
+ # print(type(img))
47
+ # img = img['composite']
48
+ # if type(img) == numpy.ndarray:
49
+ # img = PIL.Image.fromarray(img)
50
+ img = convert_img_to_tensor(img)
51
+ # padding = transforms.CenterCrop([224, 224])
52
+ # img = padding(img)
53
+ # img = torch.from_numpy(img)
54
+ # print(img.shape)
55
+ # Transform the target image and add a batch dimension
56
+ # # img = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255
57
+ img = mobileNetV3_transform(img).unsqueeze(0)
58
+ # img = image_transform(img).unsqueeze(0)
59
+
60
+
61
+ # Put model into evaluation mode and turn on inference mode
62
+ mobileNetV3.eval()
63
+ with torch.inference_mode():
64
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
65
+ y_pred = mobileNetV3(img)
66
+ pred_probs = torch.softmax(y_pred, dim=1)
67
+ # y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
68
+
69
+ # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
70
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
71
+
72
+ # Calculate the prediction time
73
+
74
+ # Return the prediction dictionary and prediction time
75
+ return pred_labels_and_probs
76
+
77
+ title = "MobileNetV3 - Quick Draw - Animals 🔢"
78
+ description = "An MobileNetV3 feature extractor computer vision model to classify doodling of animals."
79
+ article = "Created using transfer learning from MobileNetV3"
80
+
81
+
82
+ sp = gr.Sketchpad(shape = (28,28), brush_radius = 1)
83
+ # sp = gr.Sketchpad(type = "pil")
84
+ demo = gr.Interface(
85
+ fn=predict,
86
+ inputs=sp,
87
+ outputs= gr.Label(num_top_classes=5, label="Predictions"),
88
+ # outputs= gr.Image(),
89
+ # examples=example_list,
90
+ # title=title,
91
+ # description=description,
92
+ # article=article,
93
+ )
94
+ demo.launch(debug=True)
class_names.txt ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Ant
2
+ Bat
3
+ Bear
4
+ Bee
5
+ Bird
6
+ Crab
7
+ Crocodile
8
+ Dog
9
+ Dolphin
10
+ Donut
11
+ Dragon
12
+ Elephant
13
+ Flamingo
14
+ Frog
15
+ Giraffe
16
+ Hedgehog
17
+ Horse
18
+ Kangaroo
19
+ Lion
20
+ Lobster
21
+ Monkey
22
+ Octopus
23
+ Owl
24
+ Panda
25
+ Parrot
26
+ Penguin
27
+ Pig
28
+ Rabbit
29
+ Raccoon
30
+ Rhinoceros
31
+ Scorpion
32
+ Sea Turtle
33
+ Shark
34
+ Sheep
35
+ Snail
36
+ Snake
37
+ Spider
38
+ Squirrel
39
+ Swan
40
+ Tiger
41
+ Whale
42
+ Zebra
models/mobileNetV3_quickdraw_animals_epoch_80_Adam.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57405814360d60d10c8bb7562be945e338844ffbca061796ace78d4bf2a74493
3
+ size 17245682
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio==3.50.2