Davidzhangyuanhan committed
Commit df59928 · 1 Parent(s): d2d4aba
Add application file

app.py CHANGED
@@ -6,7 +6,7 @@ import cv2
 import torch
 import torch.nn as nn
 from PIL import Image
-
+import torchvision
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.data import create_transform

@@ -20,14 +20,14 @@ def pil_loader(filepath):
     img = img.convert('RGB')
     return img

-def build_transforms(input_size):
-    transform = transforms.Compose([
-        transforms.Resize(input_size * 8 // 7),
-        transforms.CenterCrop(input_size),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+def build_transforms(input_size, center_crop=True):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize(input_size * 8 // 7),
+        torchvision.transforms.CenterCrop(input_size),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
-    return
+    return transform

 # Download human-readable labels for Bamboo.
 with open('./trainid2name.json') as f:
@@ -40,11 +40,6 @@ build model
 model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
 model.eval()

-'''
-build data transform
-'''
-eval_transforms = build_transforms(224)
-
 '''
 borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
 '''
@@ -73,15 +68,19 @@ def show_cam_on_image(img: np.ndarray,
     # cam = cam / np.max(cam)
     return np.uint8(255 * cam)

+
+
+
 def recognize_image(image):
     img_t = eval_transforms(image)
-
     # compute output
     output = model(img_t.unsqueeze(0))
     prediction = output.softmax(-1).flatten()
     _,top5_idx = torch.topk(prediction, 5)
     return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}

+eval_transforms = build_transforms(224)
+
 
 image = gr.inputs.Image()
 label = gr.outputs.Label(num_top_classes=5)
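
The net effect of the diff is that build_transforms now actually returns the composed torchvision transform (the previous version referenced an undefined transforms name and returned None), and eval_transforms = build_transforms(224) is only created after the helpers it depends on exist. Below is a minimal, self-contained sketch of the resulting preprocessing-plus-inference path. It is not the Space's code verbatim: it passes the model, label map and transform in explicitly and adds torch.no_grad(), whereas app.py uses module-level globals, and it substitutes timm's IMAGENET_DEFAULT_MEAN/IMAGENET_DEFAULT_STD for the hard-coded values, which are numerically identical.

```python
# Sketch only: mirrors what the committed app.py wires up via globals.
# `model` stands in for the Space's timmvit(...) ViT-B/16 checkpoint and
# `id2name` for the ./trainid2name.json mapping; both are assumptions here.
import torch
import torchvision
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_transforms(input_size):
    # Resize to 8/7 of the crop size, centre-crop, convert to a tensor and
    # normalise with the ImageNet statistics (the same mean/std the commit
    # hard-codes).
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN,
                                         std=IMAGENET_DEFAULT_STD),
    ])


def recognize_image(model, id2name, eval_transforms, image):
    # image: a PIL RGB image; returns {class name: probability} for the top 5.
    img_t = eval_transforms(image)
    with torch.no_grad():  # no_grad() is an addition of this sketch
        output = model(img_t.unsqueeze(0))
    prediction = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}


eval_transforms = build_transforms(224)  # as in the new app.py
```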
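
The last two lines of the new file create the Gradio components using the legacy gr.inputs/gr.outputs namespace; the gr.Interface(...) call that ties them to recognize_image lies outside the hunks shown here. A hedged sketch of the usual wiring follows: the type='pil' argument, the title and the launch() call are assumptions and not part of the commit, and fn must be the one-argument recognize_image(image) defined in app.py.

```python
# Hypothetical wiring of the demo around recognize_image; only the two
# component lines below appear in the commit, the rest is assumed.
import gradio as gr

image = gr.inputs.Image(type='pil')          # 'pil' assumed so the torchvision
                                             # transforms receive a PIL image
label = gr.outputs.Label(num_top_classes=5)  # show the five most likely classes

gr.Interface(fn=recognize_image,             # recognize_image(image) -> {name: prob}
             inputs=image,
             outputs=label,
             title='Bamboo ViT-B/16 demo').launch()
```

Note that gr.inputs and gr.outputs were deprecated in later Gradio releases, so this wiring assumes the older Gradio version the Space appears to target.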