noahzhy commited on
Commit
39ee507
1 Parent(s): ffef866

add online demo

Browse files
Files changed (10) hide show
  1. .gitignore +1 -0
  2. app.py +96 -4
  3. label.names +84 -0
  4. model.tflite +3 -0
  5. samples/00.jpg +0 -0
  6. samples/01.jpg +0 -0
  7. samples/02.jpg +0 -0
  8. samples/03.jpg +0 -0
  9. samples/04.jpg +0 -0
  10. samples/06.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
app.py CHANGED
@@ -1,7 +1,99 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from itertools import groupby
4
+
5
+ import cv2
6
+ import numpy as np
7
  import gradio as gr
8
+ import tensorflow as tf
9
+
10
+
11
+ def get_sample_images():
12
+ list_ = glob.glob(os.path.join(os.path.dirname(__file__), 'samples/*.jpg'))
13
+ return [[i] for i in list_]
14
+
15
+
16
+ def inference(image):
17
+ # load model
18
+ demo = TFliteDemo(os.path.join(os.path.dirname(__file__), 'model.tflite'))
19
+ # load image
20
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
21
+ image = center_fit(image, 128, 64, top_left=True)
22
+ image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)
23
+ # inference
24
+ pred = demo.inference(image)
25
+ # decode
26
+ dict = load_dict(os.path.join(os.path.dirname(__file__), 'label.names'))
27
+ res = decode_label(pred, dict)
28
+ return res
29
+
30
+
31
+ class TFliteDemo:
32
+ def __init__(self, model_path):
33
+ self.interpreter = tf.lite.Interpreter(model_path=model_path)
34
+ self.interpreter.allocate_tensors()
35
+ self.input_details = self.interpreter.get_input_details()
36
+ self.output_details = self.interpreter.get_output_details()
37
+
38
+ def inference(self, x):
39
+ self.interpreter.set_tensor(self.input_details[0]['index'], x)
40
+ self.interpreter.invoke()
41
+ return self.interpreter.get_tensor(self.output_details[0]['index'])
42
+
43
+
44
+ def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
45
+ # get img shape
46
+ img_h, img_w = img.shape[:2]
47
+ # get ratio
48
+ ratio = min(w / img_w, h / img_h)
49
+
50
+ if len(img.shape) == 3:
51
+ inter = cv2.INTER_AREA
52
+ # resize img
53
+ img = cv2.resize(img, (int(img_w * ratio), int(img_h * ratio)), interpolation=inter)
54
+ # get new img shape
55
+ img_h, img_w = img.shape[:2]
56
+ # get start point
57
+ start_w = (w - img_w) // 2
58
+ start_h = (h - img_h) // 2
59
+
60
+ if top_left:
61
+ start_w = 0
62
+ start_h = 0
63
+
64
+ if len(img.shape) == 2:
65
+ # create new img
66
+ new_img = np.zeros((h, w), dtype=np.uint8)
67
+ new_img[start_h:start_h+img_h, start_w:start_w+img_w] = img
68
+ else:
69
+ new_img = np.zeros((h, w, 3), dtype=np.uint8)
70
+ new_img[start_h:start_h+img_h, start_w:start_w+img_w, :] = img
71
+
72
+ return new_img
73
+
74
+
75
+ def load_dict(dict_path='label.names'):
76
+ with open(dict_path, 'r', encoding='utf-8') as f:
77
+ dict = f.read().splitlines()
78
+ dict = {i: dict[i] for i in range(len(dict))}
79
+ return dict
80
+
81
+
82
+ def decode_label(mat, chars) -> str:
83
+ # mat is the output of model
84
+ # get char indices along best path
85
+ best_path_indices = np.argmax(mat[0], axis=-1)
86
+ # collapse best path (using itertools.groupby), map to chars, join char list to string
87
+ best_chars_collapsed = [chars[k] for k, _ in groupby(best_path_indices) if k != len(chars)]
88
+ res = ''.join(best_chars_collapsed)
89
+ return res
90
 
 
 
91
 
92
+ interface = gr.Interface(
93
+ fn=inference,
94
+ inputs="image",
95
+ outputs="text",
96
+ title="South Korean License Plate Recognition",
97
+ examples=get_sample_images(),
98
+ )
99
+ interface.launch()
label.names ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0
2
+ 1
3
+ 2
4
+ 3
5
+ 4
6
+ 5
7
+ 6
8
+ 7
9
+ 8
10
+ 9
11
+
12
+
13
+
14
+
15
+
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+ 서울
52
+ 부산
53
+ 대구
54
+ 인천
55
+ 광주
56
+ 대전
57
+ 울산
58
+ 세종
59
+ 경기
60
+ 강원
61
+ 충북
62
+ 충남
63
+ 전북
64
+ 전남
65
+ 경북
66
+ 경남
67
+ 제주
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
model.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a16ab2c9446d3edfa5e620ae418ed68952dd0047d8656eba2001a269341c0f81
3
+ size 307488
samples/00.jpg ADDED
samples/01.jpg ADDED
samples/02.jpg ADDED
samples/03.jpg ADDED
samples/04.jpg ADDED
samples/06.jpg ADDED