Spaces:
Running
Running
import functools
import glob
import os
import threading
from itertools import groupby

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
def get_sample_images():
    """Return example image paths for the Gradio interface.

    Collects ``samples/*.jpg`` located next to this file and wraps each path in
    a one-element list (one row per example, matching the single image input
    of the interface).

    The paths are sorted so the examples always appear in the same order;
    ``glob.glob`` itself returns matches in arbitrary, filesystem-dependent
    order.
    """
    pattern = os.path.join(os.path.dirname(__file__), 'samples/*.jpg')
    return [[path] for path in sorted(glob.glob(pattern))]
_APP_DIR = os.path.dirname(__file__)

# The shared interpreter keeps per-call tensor state between set_tensor,
# invoke and get_tensor, so concurrent requests must not interleave on it.
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=None)
def _get_model():
    """Build the TFLite wrapper once and reuse it for every request."""
    return TFliteDemo(os.path.join(_APP_DIR, 'model.tflite'))


@functools.lru_cache(maxsize=None)
def _get_label_dict():
    """Load the index -> character table once and reuse it for every request."""
    return load_dict(os.path.join(_APP_DIR, 'label.names'))


def inference(image):
    """Recognise the plate text in *image*.

    Args:
        image: numpy image from the Gradio image input (RGB channel order),
            or ``None`` when nothing was submitted.  2-D grayscale and
            single-channel arrays are also accepted.

    Returns:
        ``(plate_text, confidence_string)``; ``('None', 'None')`` when
        *image* is ``None``.
    """
    # Check for empty input before touching the model, so an empty request
    # never pays for loading it.
    if image is None:
        return 'None', 'None'

    gray = image
    if gray.ndim == 3 and gray.shape[-1] == 1:
        gray = gray[..., 0]
    if gray.ndim == 3:
        # Gradio hands over RGB arrays (not OpenCV's BGR), so RGB2GRAY puts
        # the luma weights on the correct channels; it also accepts RGBA.
        gray = cv2.cvtColor(gray, cv2.COLOR_RGB2GRAY)

    # Fit onto a 128x64 (w x h) canvas, then add batch and channel axes:
    # shape (1, 64, 128, 1).
    fitted = center_fit(gray, 128, 64, top_left=True)
    batch = np.reshape(fitted, (1, *fitted.shape, 1)).astype(np.uint8)

    with _MODEL_LOCK:
        pred = _get_model().inference(batch)

    label_dict = _get_label_dict()
    return decode_label(pred, label_dict), get_confidence(pred)
class TFliteDemo:
    """Minimal wrapper that runs a TensorFlow Lite model.

    Tensors are allocated once at construction.  :meth:`inference` uses only
    the first input tensor and the first output tensor of the model.
    """

    def __init__(self, model_path):
        interp = tf.lite.Interpreter(model_path=model_path)
        interp.allocate_tensors()
        self.interpreter = interp
        self.input_details = interp.get_input_details()
        self.output_details = interp.get_output_details()

    def inference(self, x):
        """Feed *x* to the first input tensor, run the model, and return the first output tensor."""
        src_index = self.input_details[0]['index']
        dst_index = self.output_details[0]['index']
        runner = self.interpreter
        runner.set_tensor(src_index, x)
        runner.invoke()
        return runner.get_tensor(dst_index)
def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
    """Resize *img* to fit inside a ``w`` x ``h`` canvas and zero-pad the rest.

    The aspect ratio is preserved: the image is scaled by the largest factor
    that keeps both sides within the canvas.

    Args:
        img: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` image array.
        w: Canvas width in pixels.
        h: Canvas height in pixels.
        inter: OpenCV interpolation flag.  For 3-D inputs it is always
            replaced by ``cv2.INTER_AREA``.
        top_left: If True, place the resized image at the top-left corner;
            otherwise centre it on the canvas.

    Returns:
        ``uint8`` array of shape ``(h, w)``, or ``(h, w, C)`` when the
        resized image still has a channel axis.

    Raises:
        ValueError: If *img* has zero height or width.
    """
    img_h, img_w = img.shape[:2]
    if img_h == 0 or img_w == 0:
        raise ValueError(f"cannot fit an empty image of shape {img.shape}")

    ratio = min(w / img_w, h / img_h)
    if len(img.shape) == 3:
        inter = cv2.INTER_AREA

    # Truncation can produce 0 for extreme aspect ratios, which cv2.resize
    # rejects, so keep each side within [1, target].
    new_w = min(w, max(1, int(img_w * ratio)))
    new_h = min(h, max(1, int(img_h * ratio)))
    img = cv2.resize(img, (new_w, new_h), interpolation=inter)
    img_h, img_w = img.shape[:2]

    if top_left:
        start_w = start_h = 0
    else:
        start_w = (w - img_w) // 2
        start_h = (h - img_h) // 2

    # Take the channel layout from the *resized* image: cv2.resize drops a
    # trailing singleton channel, and any channel count (e.g. RGBA) works.
    new_img = np.zeros((h, w) + img.shape[2:], dtype=np.uint8)
    new_img[start_h:start_h + img_h, start_w:start_w + img_w] = img
    return new_img
def load_dict(dict_path='label.names'):
    """Read a UTF-8 label file and map each line's 0-based index to its text.

    Args:
        dict_path: Path to the label file.  The default is resolved relative
            to the current working directory.

    Returns:
        ``{0: first_line, 1: second_line, ...}`` with line terminators removed.
    """
    with open(dict_path, 'r', encoding='utf-8') as f:
        labels = f.read().splitlines()
    # Use a local name that does not shadow the builtin, so dict() is usable.
    return dict(enumerate(labels))
def get_confidence(mat) -> str:
    """Average the best-path scores of the non-blank steps, formatted to 4 decimals.

    Args:
        mat: Model output.  ``mat[0]`` is read as a ``(steps, classes)`` score
            matrix, and the last class is treated as the blank.  Scores are
            divided by 255, which presumably assumes a uint8-quantised output
            (TODO confirm against the model).

    Returns:
        Confidence as a string such as ``"0.9731"``.  Returns ``"0.0000"``
        when every step's best class is the blank; previously this case
        computed the mean of an empty list and produced ``"nan"``.
    """
    scores = np.asarray(mat[0])
    best_path = np.argmax(scores, axis=-1)
    best_scores = np.max(scores, axis=-1)
    blank_idx = scores.shape[-1] - 1

    kept = best_scores[best_path != blank_idx]
    if kept.size == 0:
        return "{:.4f}".format(0.0)

    conf = np.mean(kept) / 255.0
    # keep 4 decimal places
    return "{:.4f}".format(conf)
def decode_label(mat, chars) -> str:
    """Greedy CTC decoding of the model output into plate text.

    Takes the argmax class at every step of ``mat[0]``, merges consecutive
    repeats, and drops the blank class, whose index equals ``len(chars)``.
    The remaining indices are mapped through *chars*.  Spaces and
    underscores are then removed from the result.
    """
    blank = len(chars)
    path = np.argmax(mat[0], axis=-1)

    pieces = []
    for label_idx, _run in groupby(path):
        if label_idx == blank:
            continue
        pieces.append(chars[label_idx])

    text = ''.join(pieces)
    for unwanted in (' ', '_'):
        text = text.replace(unwanted, '')
    return text
if __name__ == '__main__':
    # Page header text for the Gradio demo.
    _TITLE = "South Korean License Plate Recognition"
    _DESCRIPTION = '''
<div>
<p style="text-align: center; font-size: 1.3em">This is a demo of South Korean License Plate Recognition.
<a style="display:inline-block; margin-left: .5em" href='https://github.com/noahzhy/KR_LPR_TF/'><img src='https://img.shields.io/github/stars/noahzhy/KR_LPR_TF?style=social' /></a>
</p>
</div>
'''

    # The two output boxes match the (plate text, confidence) pair that
    # inference() returns.
    plate_box = gr.Textbox(label="Plate Number", type="text")
    confidence_box = gr.Textbox(label="Confidence", type="text")

    interface = gr.Interface(
        fn=inference,
        inputs="image",
        outputs=[plate_box, confidence_box],
        title=_TITLE,
        description=_DESCRIPTION,
        examples=get_sample_images(),
    )
    interface.launch()