Spaces:
Running
Running
File size: 4,420 Bytes
17add13 39ee507 ffef866 39ee507 eb74463 39ee507 074635e 39ee507 5bfef7f 39ee507 5bfef7f 39ee507 eb74463 39ee507 17add13 39ee507 5bfef7f eb74463 5bfef7f 17add13 cb30e70 5bfef7f 17add13 5bfef7f 17add13 cb30e70 5bfef7f 17add13 eb74463 c393f6b 5bfef7f eb74463 17add13 c393f6b eb74463 5bfef7f eb74463 5bfef7f eb74463 ffef866 cb30e70 5bfef7f 931f551 eb74463 cb30e70 eb74463 cb30e70 eb74463 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os, glob
from itertools import groupby
import cv2
import numpy as np
import gradio as gr
import tensorflow as tf
def get_samples(sample_dir=None):
    """Return the bundled sample images as example rows for the demo.

    Scans ``sample_dir`` (default: the ``samples`` folder next to this script)
    for ``*.jpg`` files and orders them numerically by the part of the file
    name before the first dot, so ``2.jpg`` sorts before ``10.jpg``.

    Args:
        sample_dir: directory to scan. Defaults to ``<script dir>/samples``.

    Returns:
        A list of single-element lists ``[[path], ...]``: one row per example,
        because the demo has a single image input.

    Raises:
        ValueError: if a file name does not start with an integer
            (e.g. ``car.jpg``).
    """
    if sample_dir is None:
        sample_dir = os.path.join(os.path.dirname(__file__), 'samples')
    paths = glob.glob(os.path.join(sample_dir, '*.jpg'))
    # os.path.basename understands the platform separator; the previous
    # split('/') returned the whole path on Windows and int() then failed.
    paths.sort(key=lambda p: int(os.path.basename(p).split('.')[0]))
    return [[p] for p in paths]
def cv2_imread(path):
    """Read an image file into an OpenCV array without calling ``cv2.imread``.

    The file's raw bytes are loaded with NumPy and decoded in memory.
    Presumably this is so that paths containing non-ASCII characters still
    work (a known ``cv2.imread`` limitation on some platforms); confirm if
    that matters here.

    Args:
        path: filesystem path of the image.

    Returns:
        The decoded image with its channels kept as stored
        (``cv2.IMREAD_UNCHANGED``), or ``None`` if OpenCV cannot decode the
        bytes.
    """
    file_bytes = np.fromfile(path, dtype=np.uint8)
    decoded = cv2.imdecode(file_bytes, cv2.IMREAD_UNCHANGED)
    return decoded
def decode_label(mat, chars) -> str:
    """Decode a model output into plate text along the best (argmax) path.

    Args:
        mat: model output. Only ``mat[0]`` is decoded, with class scores on
            the last axis.
        chars: mapping from class index to character. Index ``len(chars)`` is
            treated as the blank class and dropped.

    Returns:
        The text obtained by taking the argmax class per step, collapsing
        consecutive repeats, dropping blanks, and finally removing every
        space and underscore.
    """
    blank = len(chars)
    path = np.argmax(mat[0], axis=-1)

    decoded = []
    for idx, _run in groupby(path):
        # groupby yields one key per run, so repeats collapse to one entry.
        if idx == blank:
            continue
        decoded.append(chars[idx])

    text = ''.join(decoded)
    for junk in (' ', '_'):
        text = text.replace(junk, '')
    return text
def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
    """Fit ``img`` inside a ``w`` x ``h`` canvas and return it as grayscale.

    The image is scaled by the largest factor that keeps both sides within
    the canvas (aspect ratio preserved). It is then converted to a single
    channel and pasted onto a zero-filled uint8 canvas.

    Args:
        img: 2-D grayscale image, or a 3-D image with 1, 3 or 4 channels.
            3/4-channel input is converted with ``COLOR_BGR2GRAY``, i.e. it
            is assumed to be in BGR(A) order.
        w: canvas width in pixels.
        h: canvas height in pixels.
        inter: OpenCV interpolation flag used for 2-D input. 3-D input
            always uses ``cv2.INTER_AREA``.
        top_left: if True, paste at the top-left corner; otherwise centre
            the resized image on the canvas.

    Returns:
        A ``(h, w)`` uint8 array, zero wherever the resized image does not
        reach.

    Raises:
        ValueError: if ``img`` has a channel count other than 1, 3 or 4.
    """
    img_h, img_w = img.shape[:2]
    # Largest scale at which both sides still fit inside the canvas.
    ratio = min(w / img_w, h / img_h)
    if img.ndim == 3:
        inter = cv2.INTER_AREA
    # Clamp to 1 px so extreme aspect ratios cannot ask cv2.resize for a
    # zero-sized output.
    new_size = (max(1, int(img_w * ratio)), max(1, int(img_h * ratio)))
    img = cv2.resize(img, new_size, interpolation=inter)

    # Reduce to one channel. The previous version called BGR2GRAY
    # unconditionally, which raises on 2-D grayscale input. cv2.resize may
    # already have dropped a singleton channel axis, so check again here.
    if img.ndim == 3:
        channels = img.shape[2]
        if channels in (3, 4):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        elif channels == 1:
            img = img[:, :, 0]
        else:
            raise ValueError(f'unsupported number of channels: {channels}')

    img_h, img_w = img.shape[:2]
    if top_left:
        start_h = start_w = 0
    else:
        start_w = (w - img_w) // 2
        start_h = (h - img_h) // 2

    # The image is always single-channel at this point, so one 2-D canvas
    # covers every case (the old 3-channel branch was unreachable).
    canvas = np.zeros((h, w), dtype=np.uint8)
    canvas[start_h:start_h + img_h, start_w:start_w + img_w] = img
    return canvas
def load_dict(dict_path='label.names'):
    """Read a newline-separated label file into an index -> label mapping.

    Args:
        dict_path: UTF-8 text file with one label per line. A relative path is
            resolved against the current working directory.

    Returns:
        ``{0: first_line, 1: second_line, ...}`` in file order.
    """
    with open(dict_path, 'r', encoding='utf-8') as fh:
        names = fh.read().splitlines()
    return dict(enumerate(names))
class TFliteDemo:
    """Run a TFLite plate-recognition model on one image at a time.

    Pipeline: ``preprocess`` (load, fit onto a 192x96 grayscale canvas,
    scale to [0, 1]) -> ``inference`` (TFLite interpreter) -> ``postprocess``
    (decode the output into text plus a confidence score).
    """

    # Label mapping from load_dict(); read on first use and reused afterwards
    # instead of re-reading the file on every request.
    _chars = None

    def __init__(self, model_path, blank=0):
        """Load the TFLite model and allocate its tensors.

        Args:
            model_path: path to the ``.tflite`` file.
            blank: stored as ``self.blank`` but not read by any method here.
                Decoding treats the last class index as the blank instead.
        """
        self.blank = blank
        self.interpreter = tf.lite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        self.inputs = self.interpreter.get_input_details()
        self.outputs = self.interpreter.get_output_details()

    def inference(self, x):
        """Feed ``x`` to the first input tensor, run the model, and return the first output."""
        self.interpreter.set_tensor(self.inputs[0]['index'], x)
        self.interpreter.invoke()
        return self.interpreter.get_tensor(self.outputs[0]['index'])

    def preprocess(self, img):
        """Turn a path or image array into a ``(1, 96, 192, 1)`` float32 batch in [0, 1].

        Raises:
            ValueError: if ``img`` is None or the file at ``img`` cannot be decoded.
        """
        if isinstance(img, str):
            image = cv2_imread(img)
            # cv2_imread returns None for undecodable files. Previously this
            # slipped through and crashed later inside center_fit.
            if image is None:
                raise ValueError(f'could not decode image: {img!r}')
        else:
            if img is None:
                raise ValueError('img is None')
            # No copy needed: center_fit builds new arrays and never writes
            # to its input.
            image = img
        # NOTE(review): array input (e.g. from Gradio) may be RGB rather than
        # BGR, while center_fit converts with BGR weights. Confirm whether this
        # matters for accuracy.
        image = center_fit(image, 192, 96, top_left=True)
        # Add batch and channel axes: (96, 192) -> (1, 96, 192, 1).
        image = np.reshape(image, (1, *image.shape, 1)).astype(np.float32) / 255.0
        return image

    def get_confidence(self, pred):
        """Return the lowest probability among non-blank best-path frames.

        Args:
            pred: 2-D ``(T, C)`` scores for one sample (``postprocess`` passes
                ``pred[0]``). Rows are presumably time steps. The last class
                ``C - 1`` is the blank.

        Returns:
            ``min(exp(score))`` over frames whose argmax is not the blank,
            or ``0.0`` when every frame is blank (nothing was recognised).
        """
        best = np.argmax(pred, axis=-1)
        keep = best != pred.shape[-1] - 1
        if not np.any(keep):
            # np.min on an empty array raises ValueError, which used to crash
            # the demo for images with no readable characters.
            return 0.0
        scores = pred[keep, best[keep]]
        # exp() assumes the model emits log-probabilities. TODO: confirm
        # against the exported model.
        return np.min(np.exp(scores))

    def postprocess(self, pred):
        """Decode the model output into ``(label, confidence)``.

        The confidence is rounded to 4 decimal places.
        """
        if self._chars is None:
            self._chars = load_dict()
        # NOTE: decode_label treats index len(chars) as blank and
        # get_confidence treats C - 1 as blank. These agree only when the
        # model has len(chars) + 1 output classes.
        label = decode_label(pred, self._chars)
        conf = self.get_confidence(pred[0])
        # keep 4 decimal places
        conf = float('{:.4f}'.format(conf))
        return label, conf

    def get_results(self, img):
        """Run the full pipeline on a path or image array and return ``(label, confidence)``."""
        img = self.preprocess(img)
        pred = self.inference(img)
        return self.postprocess(pred)
if __name__ == '__main__':
    _TITLE = '''South Korean License Plate Recognition'''
    _DESCRIPTION = '''
<div>
<p style="text-align: center; font-size: 1.3em">This is a demo of South Korean License Plate Recognition.
<a style="display:inline-block; margin-left: .5em" href='https://github.com/noahzhy/KR_LPR_TF/'><img src='https://img.shields.io/github/stars/noahzhy/KR_LPR_TF?style=social' /></a>
</p>
</div>
'''
    # Recognizer backed by the model file shipped next to this script.
    model_file = os.path.join(os.path.dirname(__file__), 'model.tflite')
    demo = TFliteDemo(model_file)

    # One output box per element of the (label, confidence) tuple.
    plate_box = gr.Textbox(label="Plate Number", type="text")
    conf_box = gr.Textbox(label="Confidence", type="text")

    app = gr.Interface(
        fn=demo.get_results,
        inputs="image",
        outputs=[plate_box, conf_box],
        title=_TITLE,
        description=_DESCRIPTION,
        examples=get_samples(),
    )
    app.launch()
|