noahzhy commited on
Commit
5bfef7f
1 Parent(s): 074635e

Refactor inference function and add preprocessing

Browse files
Files changed (1) hide show
  1. app.py +66 -52
app.py CHANGED
@@ -15,37 +15,20 @@ def get_sample_images():
15
  return [[i] for i in list_]
16
 
17
 
18
- def inference(image):
19
- # load model
20
- demo = TFliteDemo(os.path.join(os.path.dirname(__file__), 'model.tflite'))
21
- # check image is not None
22
- if image is None:
23
- return 'None', 'None'
24
- # load image
25
- image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
26
- image = center_fit(image, 128, 64, top_left=True)
27
- image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)
28
- # inference
29
- pred = demo.inference(image)
30
- # decode
31
- dict = load_dict(os.path.join(os.path.dirname(__file__), 'label.names'))
32
- res = decode_label(pred, dict)
33
- # get confidence
34
- confidence = get_confidence(pred)
35
- return res, confidence
36
 
37
 
38
- class TFliteDemo:
39
- def __init__(self, model_path):
40
- self.interpreter = tf.lite.Interpreter(model_path=model_path)
41
- self.interpreter.allocate_tensors()
42
- self.input_details = self.interpreter.get_input_details()
43
- self.output_details = self.interpreter.get_output_details()
44
-
45
- def inference(self, x):
46
- self.interpreter.set_tensor(self.input_details[0]['index'], x)
47
- self.interpreter.invoke()
48
- return self.interpreter.get_tensor(self.output_details[0]['index'])
49
 
50
 
51
  def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
@@ -86,31 +69,60 @@ def load_dict(dict_path='label.names'):
86
  return dict
87
 
88
 
89
- def get_confidence(mat) -> float:
90
- # mat is the output of model
91
- # get char indices along best path
92
- best_path_indices = np.argmax(mat[0], axis=-1)
93
- confidence = np.max(mat[0], axis=-1)
94
- blank_idx = mat.shape[-1] - 1
95
- avg_confidence = []
96
- for idx, conf in zip(best_path_indices, confidence):
97
- if idx != blank_idx:
98
- avg_confidence.append(conf)
99
- conf = np.mean(avg_confidence) / 255.0
100
- # keep 4 decimal places
101
- return "{:.4f}".format(conf)
102
 
 
 
 
 
103
 
104
- def decode_label(mat, chars) -> str:
105
- # mat is the output of model
106
- # get char indices along best path
107
- best_path_indices = np.argmax(mat[0], axis=-1)
108
- # collapse best path (using itertools.groupby), map to chars, join char list to string
109
- best_chars_collapsed = [chars[k] for k, _ in groupby(best_path_indices) if k != len(chars)]
110
- res = ''.join(best_chars_collapsed)
111
- # remove space and '_'
112
- res = res.replace(' ', '').replace('_', '')
113
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  if __name__ == '__main__':
@@ -122,6 +134,8 @@ if __name__ == '__main__':
122
  </p>
123
  </div>
124
  '''
 
 
125
  interface = gr.Interface(
126
  fn=inference,
127
  inputs="image",
 
15
  return [[i] for i in list_]
16
 
17
 
18
+ def cv2_imread(path):
19
+ return cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
+ def decode_label(mat, chars) -> str:
23
+ # mat is the output of model
24
+ # get char indices along best path
25
+ best_path_indices = np.argmax(mat[0], axis=-1)
26
+ # collapse best path (using itertools.groupby), map to chars, join char list to string
27
+ best_chars_collapsed = [chars[k] for k, _ in groupby(best_path_indices) if k != len(chars)]
28
+ res = ''.join(best_chars_collapsed)
29
+ # remove space and '_'
30
+ res = res.replace(' ', '').replace('_', '')
31
+ return res
 
32
 
33
 
34
  def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
 
69
  return dict
70
 
71
 
72
+ class TFliteDemo:
73
+ def __init__(self, model_path, blank=85, conf_mode="mean"):
74
+ self.blank = blank
75
+ self.conf_mode = conf_mode
76
+ self.interpreter = tf.lite.Interpreter(model_path=model_path)
77
+ self.interpreter.allocate_tensors()
78
+ self.input_details = self.interpreter.get_input_details()
79
+ self.output_details = self.interpreter.get_output_details()
 
 
 
 
 
80
 
81
+ def inference(self, x):
82
+ self.interpreter.set_tensor(self.input_details[0]['index'], x)
83
+ self.interpreter.invoke()
84
+ return self.interpreter.get_tensor(self.output_details[0]['index'])
85
 
86
+ def preprocess(self, img):
87
+ if isinstance(img, str):
88
+ image = cv2_imread(img)
89
+ else:
90
+ image = img
91
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
92
+ image = center_fit(image, 128, 64, top_left=True)
93
+ image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)
94
+ return image
95
+
96
+ def get_confidence(self, pred, mode="mean"):
97
+ conf = []
98
+ idxs = np.argmax(pred, axis=-1)
99
+ values = np.max(pred, axis=-1)
100
+
101
+ for idx, c in zip(idxs, values):
102
+ if idx == self.blank: continue
103
+ conf.append(c/255)
104
+
105
+ if mode == "min":
106
+ return np.min(conf)
107
+
108
+ return np.mean(conf)
109
+
110
+ def postprocess(self, pred):
111
+ label = decode_label(pred, load_dict())
112
+ conf = self.get_confidence(pred[0], mode=self.conf_mode)
113
+ # keep 4 decimal places
114
+ conf = float('{:.4f}'.format(conf))
115
+ return label, conf
116
+
117
+
118
+ def inference(img):
119
+ # preprocess
120
+ img = demo.preprocess(img)
121
+ # inference
122
+ pred = demo.inference(img)
123
+ # postprocess
124
+ label, conf = demo.postprocess(pred)
125
+ return label, conf
126
 
127
 
128
  if __name__ == '__main__':
 
134
  </p>
135
  </div>
136
  '''
137
+ # init model
138
+ demo = TFliteDemo(os.path.join(os.path.dirname(__file__), 'model.tflite'))
139
  interface = gr.Interface(
140
  fn=inference,
141
  inputs="image",