Your Name commited on
Commit
4f12f52
·
1 Parent(s): ad53307

error resolved

Browse files
Files changed (2) hide show
  1. app.py +22 -14
  2. model.py +1 -0
app.py CHANGED
@@ -4,28 +4,36 @@ import os
4
  from model import create_effnet_b2
5
  from timeit import default_timer as timer
6
  from typing import Tuple, Dict
 
 
 
7
 
8
  class_names = ["pizza", "steak", "sushi"]
9
 
10
- effbet_b2_model , efftnet_b2_transform = create_effnet_b2()
11
 
12
- effbet_b2_model.load_state_dict(torch.load(f = "./effnet_b2.pt", map_location = torch.device("cpu")))
13
 
14
- def predict(img)-> Tuple[Dict,float]:
15
-
16
  start_time = timer()
17
- img = efftnet_b2_transform(img).unsqueeze(0)
18
-
19
- effbet_b2_model.eval()
20
 
 
 
 
 
 
 
 
21
  with torch.inference_mode():
22
- pred_prob = torch.softmax(effbet_b2_model(img), 1)
23
-
24
- pred_label_probs = {class_names[i] : float(pred_prob[0][i]) for i in range(len(class_names))}
25
-
 
 
26
  end_time = timer()
27
- pred_time = round(end_time - start_time , 4)
28
-
29
  return pred_label_probs, pred_time
30
 
31
 
@@ -46,7 +54,7 @@ demo = gr.Interface(
46
  allow_flagging="never"
47
  )
48
 
49
- demo.launch(debug=False)
50
 
51
 
52
 
 
4
  from model import create_effnet_b2
5
  from timeit import default_timer as timer
6
  from typing import Tuple, Dict
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
 
11
  class_names = ["pizza", "steak", "sushi"]
12
 
13
+ effnet_b2_model , effnet_b2_transform = create_effnet_b2()
14
 
15
+ effnet_b2_model.load_state_dict(torch.load(f = "./effnet_b2.pt", map_location = torch.device("cpu")))
16
 
17
+ def predict(img) -> Tuple[Dict, float]:
 
18
  start_time = timer()
 
 
 
19
 
20
+ # Convert from NumPy array to PIL image
21
+ if isinstance(img, np.ndarray):
22
+ img = Image.fromarray(img.astype("uint8"), "RGB")
23
+
24
+ img = effnet_b2_transform(img).unsqueeze(0)
25
+
26
+ effnet_b2_model.eval()
27
  with torch.inference_mode():
28
+ pred_prob = torch.softmax(effnet_b2_model(img), dim=1)
29
+
30
+ pred_label_probs = {
31
+ class_names[i]: float(pred_prob[0][i]) for i in range(len(class_names))
32
+ }
33
+
34
  end_time = timer()
35
+ pred_time = round(end_time - start_time, 4)
36
+
37
  return pred_label_probs, pred_time
38
 
39
 
 
54
  allow_flagging="never"
55
  )
56
 
57
+ demo.launch(debug=False, share=True)
58
 
59
 
60
 
model.py CHANGED
@@ -1,4 +1,5 @@
1
 
 
2
  from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
3
  from torch import nn
4
 
 
1
 
2
+ import torch
3
  from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
4
  from torch import nn
5