import os
import glob
import threading
from functools import lru_cache
from itertools import groupby

import cv2
import numpy as np
import gradio as gr
import tensorflow as tf

# Model, label file and sample images are resolved relative to this script so
# the app works regardless of the current working directory.
_HERE = os.path.dirname(__file__)
_MODEL_PATH = os.path.join(_HERE, 'model.tflite')
_LABELS_PATH = os.path.join(_HERE, 'label.names')

# Canvas (width, height) the plate image is fitted into before inference.
# It must agree with the model's input tensor: set_tensor rejects a
# mismatched shape.
_INPUT_W, _INPUT_H = 128, 64

# OpenCV conversion code to grayscale, keyed by channel count of the input.
_TO_GRAY = {3: cv2.COLOR_RGB2GRAY, 4: cv2.COLOR_RGBA2GRAY}


def get_sample_images():
    """Return the bundled ``samples/*.jpg`` files as Gradio example rows.

    Each row is a one-element list (one value per interface input). Paths
    are sorted so the example gallery order is stable across runs.
    """
    paths = sorted(glob.glob(os.path.join(_HERE, 'samples', '*.jpg')))
    return [[p] for p in paths]


@lru_cache(maxsize=1)
def _get_model():
    """Load the TFLite model once; every request reuses the same interpreter."""
    return TFliteDemo(_MODEL_PATH)


@lru_cache(maxsize=1)
def _get_labels():
    """Load the index -> character table once and reuse it."""
    return load_dict(_LABELS_PATH)


def _to_gray(image):
    """Convert an RGB, RGBA or already-grayscale numpy image to a 2-D array.

    Gradio's numpy image input is RGB-ordered, not OpenCV's BGR, so the
    RGB conversion codes are used to get the correct luminance weights.
    """
    if image.ndim == 2:
        return image
    channels = image.shape[2]
    if channels == 1:
        return image[:, :, 0]
    if channels not in _TO_GRAY:
        raise ValueError(f'unsupported number of image channels: {channels}')
    return cv2.cvtColor(image, _TO_GRAY[channels])


def inference(image):
    """Recognise the licence-plate text in *image*.

    Args:
        image: numpy image array from Gradio's ``"image"`` input, or ``None``
            when the user submits without selecting an image.

    Returns:
        The decoded plate string, or ``''`` when no image was supplied.
    """
    if image is None:
        return ''

    gray = _to_gray(image)
    gray = center_fit(gray, _INPUT_W, _INPUT_H, top_left=True)
    # Add batch and channel axes: (1, H, W, 1).
    batch = np.reshape(gray, (1, *gray.shape, 1)).astype(np.uint8)

    pred = _get_model().inference(batch)
    return decode_label(pred, _get_labels())


class TFliteDemo:
    """Thin wrapper around a TFLite interpreter using its first input/output."""

    def __init__(self, model_path):
        self.interpreter = tf.lite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # One instance is shared between requests, and the interpreter holds
        # mutable input/output buffers. The set_tensor -> invoke -> get_tensor
        # sequence must therefore not interleave across threads.
        self._lock = threading.Lock()

    def inference(self, x):
        """Feed *x* to input 0, run the model, and return output 0's value."""
        with self._lock:
            self.interpreter.set_tensor(self.input_details[0]['index'], x)
            self.interpreter.invoke()
            return self.interpreter.get_tensor(self.output_details[0]['index'])


def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
    """Resize *img* to fit a ``w`` x ``h`` canvas, preserving aspect ratio.

    The image is scaled by the largest factor that fits both dimensions. It
    is then pasted onto a black ``uint8`` canvas, either at the top-left
    corner (``top_left=True``) or centred.

    Args:
        img: 2-D (grayscale) or 3-D (H x W x C) image array.
        w: Canvas width in pixels.
        h: Canvas height in pixels.
        inter: OpenCV interpolation flag. Ignored for 3-D input, which always
            uses ``cv2.INTER_AREA``.
        top_left: Anchor the resized image at (0, 0) instead of centring it.

    Returns:
        A ``uint8`` array of shape ``(h, w)`` or ``(h, w, C)``.
    """
    img_h, img_w = img.shape[:2]
    ratio = min(w / img_w, h / img_h)
    if img.ndim == 3:
        inter = cv2.INTER_AREA
    # Clamp each side to >= 1 px. Very wide or tall inputs would otherwise
    # truncate one side to 0, which cv2.resize rejects.
    new_size = (max(1, int(img_w * ratio)), max(1, int(img_h * ratio)))
    img = cv2.resize(img, new_size, interpolation=inter)
    img_h, img_w = img.shape[:2]

    if top_left:
        start_w = start_h = 0
    else:
        start_w = (w - img_w) // 2
        start_h = (h - img_h) // 2

    # img.shape[2:] is () for grayscale and (C,) for multi-channel images,
    # so one canvas shape handles any channel count.
    canvas = np.zeros((h, w, *img.shape[2:]), dtype=np.uint8)
    canvas[start_h:start_h + img_h, start_w:start_w + img_w] = img
    return canvas


def load_dict(dict_path='label.names'):
    """Read a UTF-8 label file into an ``{index: character}`` mapping.

    Line *i* of the file (without its line terminator) becomes the entry for
    class index *i*.
    """
    with open(dict_path, 'r', encoding='utf-8') as f:
        return dict(enumerate(f.read().splitlines()))


def decode_label(mat, chars) -> str:
    """Greedy (best-path) CTC decoding of the model output.

    Args:
        mat: Model output. ``mat[0]`` is reduced with ``argmax`` over its last
            axis to get one class index per timestep.
            NOTE(review): presumably shape (1, T, len(chars) + 1) — confirm.
        chars: ``{index: character}`` mapping as returned by ``load_dict``.
            Index ``len(chars)`` is treated as the CTC blank and dropped.

    Returns:
        The decoded string with all spaces and ``'_'`` characters removed.
    """
    best_path = np.argmax(mat[0], axis=-1)
    blank = len(chars)
    # Collapse runs of identical indices first, then drop blanks. A blank
    # between two equal symbols therefore keeps both of them.
    text = ''.join(chars[k] for k, _ in groupby(best_path) if k != blank)
    return text.replace(' ', '').replace('_', '')


_TITLE = '''South Korean License Plate Recognition'''
_DESCRIPTION = '''
'''

interface = gr.Interface(
    fn=inference,
    inputs="image",
    outputs="text",
    title=_TITLE,
    description=_DESCRIPTION,
    examples=get_sample_images(),
)
interface.launch()