# KR_LPR / app.py
# noahzhy's picture
# Add image check and confidence calculation in
# cb30e70
# raw
# history blame
# 4.26 kB
import os
import glob
from itertools import groupby
import cv2
import numpy as np
import gradio as gr
import tensorflow as tf
def get_sample_images():
    """Collect the bundled sample pictures as example rows for the UI.

    Returns:
        A list of single-element lists, one per ``*.jpg`` file found in the
        ``samples`` directory located next to this module. The order is
        whatever ``glob.glob`` yields.
    """
    module_dir = os.path.dirname(__file__)
    pattern = os.path.join(module_dir, 'samples/*.jpg')
    # each example row holds exactly one value: the image path
    return [[sample_path] for sample_path in glob.glob(pattern)]
def inference(image):
    """Recognise the license plate shown in ``image``.

    Args:
        image: Image array handed over by the Gradio ``"image"`` input, or
            ``None`` when nothing was submitted.

    Returns:
        A ``(plate_text, confidence)`` tuple of strings. Both entries are the
        literal string ``'None'`` when no image was supplied.
    """
    # Bail out before touching the model: building the interpreter is the
    # expensive step and is pointless without an input image.
    if image is None:
        return 'None', 'None'

    base_dir = os.path.dirname(__file__)

    # Gradio delivers colour images in RGB channel order, so convert with
    # COLOR_RGB2GRAY; COLOR_BGR2GRAY would swap the red and blue weights.
    # Inputs that are already single-channel (2-D) are used as they are.
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # fit into a 128 (wide) x 64 (high) canvas, anchored at the top-left
    image = center_fit(image, 128, 64, top_left=True)
    # add leading batch axis and trailing channel axis
    image = np.reshape(image, (1, *image.shape, 1)).astype(np.uint8)

    # load model and run inference
    demo = TFliteDemo(os.path.join(base_dir, 'model.tflite'))
    pred = demo.inference(image)

    # decode the prediction with the label table (index -> character)
    labels = load_dict(os.path.join(base_dir, 'label.names'))
    res = decode_label(pred, labels)

    # get confidence
    confidence = get_confidence(pred)
    return res, confidence
class TFliteDemo:
    """Small convenience wrapper around a TensorFlow Lite interpreter."""

    def __init__(self, model_path):
        """Create the interpreter for ``model_path`` and cache its tensor details.

        Args:
            model_path: Path of the ``.tflite`` model file to load.
        """
        interpreter = tf.lite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()
        self.interpreter = interpreter
        self.input_details = interpreter.get_input_details()
        self.output_details = interpreter.get_output_details()

    def inference(self, x):
        """Run the model once on ``x``.

        Args:
            x: Value written to the model's first input tensor.

        Returns:
            The contents of the model's first output tensor after invocation.
        """
        input_index = self.input_details[0]['index']
        output_index = self.output_details[0]['index']
        self.interpreter.set_tensor(input_index, x)
        self.interpreter.invoke()
        return self.interpreter.get_tensor(output_index)
def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
    """Scale ``img`` to fit inside a ``w`` x ``h`` canvas, keeping its aspect ratio.

    The scaled image is pasted onto a zero-filled ``uint8`` canvas; the area
    it does not cover stays zero.

    Args:
        img: Image array, either 2-D ``(H, W)`` or 3-D ``(H, W, C)``.
        w: Canvas width in pixels.
        h: Canvas height in pixels.
        inter: OpenCV interpolation flag used for 2-D images. Images with a
            channel axis are always resized with ``cv2.INTER_AREA``.
        top_left: When true, paste the scaled image at the top-left corner;
            otherwise centre it on the canvas.

    Returns:
        A ``uint8`` array of shape ``(h, w)`` for 2-D results, or
        ``(h, w, C)`` when the resized image keeps a channel axis.
    """
    # get img shape
    img_h, img_w = img.shape[:2]
    # scale factor that makes the image fit in both dimensions
    ratio = min(w / img_w, h / img_h)
    if img.ndim == 3:
        inter = cv2.INTER_AREA

    # int() truncation can produce 0 for very elongated inputs, and
    # cv2.resize rejects an empty target size, so keep at least one pixel
    # (and never more than the canvas) per side.
    new_w = min(w, max(1, int(img_w * ratio)))
    new_h = min(h, max(1, int(img_h * ratio)))
    img = cv2.resize(img, (new_w, new_h), interpolation=inter)

    # get new img shape
    img_h, img_w = img.shape[:2]

    # get start point
    if top_left:
        start_w = start_h = 0
    else:
        start_w = (w - img_w) // 2
        start_h = (h - img_h) // 2

    # the canvas mirrors the channel layout of the resized image, so any
    # channel count works instead of only grayscale and 3-channel images
    new_img = np.zeros((h, w) + img.shape[2:], dtype=np.uint8)
    new_img[start_h:start_h + img_h, start_w:start_w + img_w] = img
    return new_img
def load_dict(dict_path='label.names'):
    """Read a label file into an ``index -> label`` mapping.

    Args:
        dict_path: Path of a UTF-8 text file holding one label per line.

    Returns:
        A dict whose keys are the zero-based line numbers and whose values
        are the corresponding lines without their line endings.
    """
    with open(dict_path, 'r', encoding='utf-8') as f:
        labels = f.read().splitlines()
    return dict(enumerate(labels))
def get_confidence(mat) -> str:
    """Average the best-path scores of all non-blank time steps.

    Args:
        mat: Model output array; ``mat[0]`` is indexed as
            ``(time step, class)`` and the last class index is treated as
            the blank.

    Returns:
        The mean of the per-step maximum scores over the steps whose best
        class is not the blank, divided by 255 and formatted with four
        decimal places. ``"0.0000"`` when every step predicts the blank.
    """
    scores = mat[0]
    # get class indices and their scores along the best path
    best_path_indices = np.argmax(scores, axis=-1)
    best_scores = np.max(scores, axis=-1)
    blank_idx = mat.shape[-1] - 1

    # keep only the steps that predict a real character
    kept = best_scores[best_path_indices != blank_idx]
    if kept.size == 0:
        # nothing but blanks: averaging an empty selection would give NaN
        # (plus a RuntimeWarning), so report zero confidence instead
        return "{:.4f}".format(0.0)

    # NOTE(review): the division by 255 presumably maps quantised uint8
    # scores to [0, 1] - confirm against the model's output dtype.
    conf = np.mean(kept) / 255.0
    # keep 4 decimal places
    return "{:.4f}".format(conf)
def decode_label(mat, chars) -> str:
    """Turn the model output into text with greedy best-path decoding.

    Args:
        mat: Model output array; ``mat[0]`` is indexed as
            ``(time step, class)``.
        chars: Mapping from class index to character. The index equal to
            ``len(chars)`` is treated as the blank and skipped.

    Returns:
        The decoded string: consecutive repeats of the same class are merged,
        blanks are dropped, and any spaces and underscores are removed.
    """
    blank = len(chars)
    # most likely class at every time step
    path = np.argmax(mat[0], axis=-1)

    pieces = []
    # groupby merges runs of identical consecutive indices into one entry
    for index, _run in groupby(path):
        if index == blank:
            continue
        pieces.append(chars[index])

    text = ''.join(pieces)
    # strip the filler characters: space and '_'
    for filler in (' ', '_'):
        text = text.replace(filler, '')
    return text
if __name__ == '__main__':
    # Heading and HTML blurb passed to the interface as title/description.
    _TITLE = '''South Korean License Plate Recognition'''
    _DESCRIPTION = '''
<div>
<p style="text-align: center; font-size: 1.3em">This is a demo of South Korean License Plate Recognition.
<a style="display:inline-block; margin-left: .5em" href='https://github.com/noahzhy/KR_LPR_TF/'><img src='https://img.shields.io/github/stars/noahzhy/KR_LPR_TF?style=social' /></a>
</p>
</div>
'''

    # One text box per value returned by `inference`, in the same order.
    plate_box = gr.Textbox(label="Plate Number", type="text")
    confidence_box = gr.Textbox(label="Confidence", type="text")

    # Wire the recognition function to an image input and the two text boxes,
    # offering the bundled sample images as examples.
    interface = gr.Interface(
        fn=inference,
        inputs="image",
        outputs=[plate_box, confidence_box],
        title=_TITLE,
        description=_DESCRIPTION,
        examples=get_sample_images(),
    )
    interface.launch()