noahzhy commited on
Commit
17add13
1 Parent(s): 5bfef7f

Refactor code for improved readability and efficiency

Browse files
Files changed (1) hide show
  1. app.py +17 -27
app.py CHANGED
@@ -1,5 +1,4 @@
1
- import os
2
- import glob
3
  from itertools import groupby
4
 
5
  import cv2
@@ -64,9 +63,9 @@ def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
64
 
65
  def load_dict(dict_path='label.names'):
66
  with open(dict_path, 'r', encoding='utf-8') as f:
67
- dict = f.read().splitlines()
68
- dict = {i: dict[i] for i in range(len(dict))}
69
- return dict
70
 
71
 
72
  class TFliteDemo:
@@ -75,37 +74,32 @@ class TFliteDemo:
75
  self.conf_mode = conf_mode
76
  self.interpreter = tf.lite.Interpreter(model_path=model_path)
77
  self.interpreter.allocate_tensors()
78
- self.input_details = self.interpreter.get_input_details()
79
- self.output_details = self.interpreter.get_output_details()
80
 
81
  def inference(self, x):
82
- self.interpreter.set_tensor(self.input_details[0]['index'], x)
83
  self.interpreter.invoke()
84
- return self.interpreter.get_tensor(self.output_details[0]['index'])
85
 
86
  def preprocess(self, img):
87
  if isinstance(img, str):
88
  image = cv2_imread(img)
89
  else:
90
- image = img
 
 
 
91
  image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
92
  image = center_fit(image, 128, 64, top_left=True)
93
  image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)
94
  return image
95
 
96
  def get_confidence(self, pred, mode="mean"):
97
- conf = []
98
- idxs = np.argmax(pred, axis=-1)
99
- values = np.max(pred, axis=-1)
100
-
101
- for idx, c in zip(idxs, values):
102
- if idx == self.blank: continue
103
- conf.append(c/255)
104
-
105
- if mode == "min":
106
- return np.min(conf)
107
-
108
- return np.mean(conf)
109
 
110
  def postprocess(self, pred):
111
  label = decode_label(pred, load_dict())
@@ -116,13 +110,9 @@ class TFliteDemo:
116
 
117
 
118
  def inference(img):
119
- # preprocess
120
  img = demo.preprocess(img)
121
- # inference
122
  pred = demo.inference(img)
123
- # postprocess
124
- label, conf = demo.postprocess(pred)
125
- return label, conf
126
 
127
 
128
  if __name__ == '__main__':
 
1
+ import os, glob
 
2
  from itertools import groupby
3
 
4
  import cv2
 
63
 
64
  def load_dict(dict_path='label.names'):
65
  with open(dict_path, 'r', encoding='utf-8') as f:
66
+ _dict = f.read().splitlines()
67
+ _dict = {i: _dict[i] for i in range(len(_dict))}
68
+ return _dict
69
 
70
 
71
  class TFliteDemo:
 
74
  self.conf_mode = conf_mode
75
  self.interpreter = tf.lite.Interpreter(model_path=model_path)
76
  self.interpreter.allocate_tensors()
77
+ self.inputs = self.interpreter.get_input_details()
78
+ self.outputs = self.interpreter.get_output_details()
79
 
80
  def inference(self, x):
81
+ self.interpreter.set_tensor(self.inputs[0]['index'], x)
82
  self.interpreter.invoke()
83
+ return self.interpreter.get_tensor(self.outputs[0]['index'])
84
 
85
  def preprocess(self, img):
86
  if isinstance(img, str):
87
  image = cv2_imread(img)
88
  else:
89
+ # check none
90
+ if img is None:
91
+ raise ValueError('img is None')
92
+ image = img.copy()
93
  image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
94
  image = center_fit(image, 128, 64, top_left=True)
95
  image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)
96
  return image
97
 
98
  def get_confidence(self, pred, mode="mean"):
99
+ _argmax = np.argmax(pred, axis=-1)
100
+ _idx = _argmax != pred.shape[-1] - 1
101
+ conf = pred[_idx, _argmax[_idx]] / 255.0
102
+ return np.min(conf) if mode == "min" else np.mean(conf)
 
 
 
 
 
 
 
 
103
 
104
  def postprocess(self, pred):
105
  label = decode_label(pred, load_dict())
 
110
 
111
 
112
  def inference(img):
 
113
  img = demo.preprocess(img)
 
114
  pred = demo.inference(img)
115
+ return demo.postprocess(pred)
 
 
116
 
117
 
118
  if __name__ == '__main__':