File size: 2,190 Bytes
ab576ba
 
 
 
 
 
 
 
 
c80b9ba
1c9ac2f
c80b9ba
 
 
 
 
ab576ba
 
 
0487ed0
ab576ba
 
 
 
 
 
 
1041d2e
ab576ba
 
 
 
 
 
1c9ac2f
 
 
6c3ee18
010a14f
bf10b27
 
1c9ac2f
6c3ee18
bf10b27
 
 
 
 
 
1c9ac2f
 
 
 
 
8aca60d
1c9ac2f
ab576ba
 
 
 
1c9ac2f
02f32ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch 
import torch.nn as nn
from parseq.system import System
import yaml
import cv2
from PIL import Image
from wpodnet.lib_detection import load_model_wpod, detect_lp
import numpy as np
import gradio as gr 
from torchvision import transforms as T
import matplotlib.pyplot as plt
# Preprocessing pipeline for a single plate crop: bicubic resize to 224x224,
# conversion to a tensor, then normalisation with mean 0.5 and std 0.5.
trans = T.Compose(
    [
        T.Resize(size=(224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ]
)

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Locations of the model artefacts, relative to the working directory.
checkpoint_path = 'weights/best.ckpt'
config_path = 'parseq/config.yaml'
wpod_path = 'weights/wpod-net.h5'
# Detector handle that predict() passes to detect_lp().
wpod_net = load_model_wpod(wpod_path)

# Build the recognition model from its YAML config, then restore its weights.
with open(config_path, 'r') as data:
    config = yaml.safe_load(data)
system = System(config)
# The checkpoint is deserialised onto the CPU first; the model is moved to
# `device` only after the state dict has been loaded.
checkpoint = torch.load(checkpoint_path, map_location = 'cpu')
system.load_state_dict(checkpoint['state_dict'])
system.to(device)

def predict(image):
    """Detect license plates in an image and draw the recognised text on it.

    Args:
        image: Either a path to an image file (read with ``cv2.imread``) or an
            image array such as the one supplied by the Gradio ``Image``
            input. ``None`` is tolerated (nothing uploaded / unreadable path).

    Returns:
        A copy of the input with every accepted plate outlined and labelled;
        the unmodified copy when no plate is detected; ``None`` when there is
        no image to process.
    """
    if isinstance(image, str):
        # cv2.imread returns None instead of raising when the file is unreadable.
        image = cv2.imread(image)
    if image is None:
        return None

    draw_image = image.copy()
    _, plate_crops, _, bounding_boxes = detect_lp(wpod_net, image, 0.5)
    if len(plate_crops) == 0:
        return draw_image

    system.eval()
    # cv2.polylines only accepts 32-bit integer points; a plain `int` cast
    # yields 64-bit values on most platforms.
    bounding_boxes = np.array(bounding_boxes).astype(np.int32)
    for i, plate_crop in enumerate(plate_crops):
        # NOTE(review): the * 255 scaling assumes detect_lp returns float
        # crops in [0, 1] - confirm against wpodnet.lib_detection.
        plate_rgb = Image.fromarray((plate_crop * 255).astype(np.uint8)).convert("RGB")
        # The model lives on `device`, so the input batch must be moved there too.
        batch = trans(plate_rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = system(batch).softmax(-1)
        generated_text, _ = system.tokenizer.decode(pred)
        plate_text = generated_text[0]
        # Readings shorter than 5 characters are not drawn.
        if len(plate_text) < 5:
            continue

        points = bounding_boxes[i]
        cv2.polylines(draw_image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
        # Anchor the label at the top-left corner of the box; OpenCV wants
        # plain Python ints for the text origin.
        position = (int(points[:, 0].min()), int(points[:, 1].min()))
        cv2.putText(draw_image, plate_text, position,
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.8,
                    color=(226, 218, 53), thickness=2)
    return draw_image

# Gradio front end: a single image input fed to predict(), a single image output.
interface = gr.Interface(
    fn=predict,
    inputs=[gr.components.Image()],
    outputs=[gr.components.Image()],
)
# Start the app with Gradio's debug flag enabled.
interface.launch(debug=True)