windy2612 commited on
Commit
1c9ac2f
·
verified ·
1 Parent(s): 6c3ee18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -8,7 +8,7 @@ from wpodnet.lib_detection import load_model_wpod, detect_lp
8
  import numpy as np
9
  import gradio as gr
10
  from torchvision import transforms as T
11
-
12
  trans = T.Compose([
13
  T.Resize((224, 224), T.InterpolationMode.BICUBIC),
14
  T.ToTensor(),
@@ -17,7 +17,7 @@ trans = T.Compose([
17
 
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
- checkpoint_path = 'weights/parseq.ckpt'
21
  config_path = 'parseq/config.yaml'
22
  wpod_path = 'weights/wpod-net.h5'
23
  wpod_net = load_model_wpod(wpod_path)
@@ -25,7 +25,6 @@ wpod_net = load_model_wpod(wpod_path)
25
  with open(config_path, 'r') as data:
26
  config = yaml.safe_load(data)
27
  system = System(config)
28
- checkpoint_path = 'weights/parseq.ckpt'
29
  checkpoint = torch.load(checkpoint_path, map_location = 'cpu')
30
  system.load_state_dict(checkpoint['state_dict'])
31
  system.to(device)
@@ -33,12 +32,14 @@ system.to(device)
33
  def predict(image):
34
  if isinstance(image, str):
35
  image = cv2.imread(image)
36
- _, img_wapred, _, _ = detect_lp(wpod_net, image, 0.5)
 
 
37
  if len(img_wapred) == 0:
38
  return "Can not detect license plate from image"
39
  else:
40
  system.eval()
41
- pred_labels = []
42
  for i in range(len(img_wapred)):
43
  img = (img_wapred[i] * 255).astype(np.uint8)
44
  img = Image.fromarray(img).convert("RGB")
@@ -46,11 +47,16 @@ def predict(image):
46
  with torch.no_grad():
47
  pred = system(image).softmax(-1)
48
  generated_text, _ = system.tokenizer.decode(pred)
49
- pred_labels.append(generated_text[0])
50
- return pred_labels
 
 
 
 
 
51
 
52
  interface = gr.Interface(
53
  fn = predict,
54
  inputs =[gr.components.Image()],
55
- outputs=[gr.components.Textbox(label = "License plate", lines = 2)])
56
  interface.launch(share = True, debug = True)
 
8
  import numpy as np
9
  import gradio as gr
10
  from torchvision import transforms as T
11
+ import matplotlib.pyplot as plt
12
  trans = T.Compose([
13
  T.Resize((224, 224), T.InterpolationMode.BICUBIC),
14
  T.ToTensor(),
 
17
 
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
+ checkpoint_path = 'weights/best.ckpt'
21
  config_path = 'parseq/config.yaml'
22
  wpod_path = 'weights/wpod-net.h5'
23
  wpod_net = load_model_wpod(wpod_path)
 
25
  with open(config_path, 'r') as data:
26
  config = yaml.safe_load(data)
27
  system = System(config)
 
28
  checkpoint = torch.load(checkpoint_path, map_location = 'cpu')
29
  system.load_state_dict(checkpoint['state_dict'])
30
  system.to(device)
 
32
  def predict(image):
33
  if isinstance(image, str):
34
  image = cv2.imread(image)
35
+
36
+ draw_image = image.copy()
37
+ _, img_wapred, _, bounding_boxes = detect_lp(wpod_net, image, 0.5)
38
  if len(img_wapred) == 0:
39
  return "Can not detect license plate from image"
40
  else:
41
  system.eval()
42
+ bounding_boxes = np.array(bounding_boxes).astype(int)
43
  for i in range(len(img_wapred)):
44
  img = (img_wapred[i] * 255).astype(np.uint8)
45
  img = Image.fromarray(img).convert("RGB")
 
47
  with torch.no_grad():
48
  pred = system(image).softmax(-1)
49
  generated_text, _ = system.tokenizer.decode(pred)
50
+ if len(generated_text[0]) >= 5:
51
+ points = bounding_boxes[i]
52
+ cv2.polylines(draw_image, [points], isClosed = True, color = (0, 255, 0), thickness = 2)
53
+ position = (points[:, 0].min(), points[:, 1].min())
54
+ cv2.putText(draw_image, generated_text[0], position,
55
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.8, color=(255, 255, 0), thickness = 2)
56
+ return draw_image
57
 
58
  interface = gr.Interface(
59
  fn = predict,
60
  inputs =[gr.components.Image()],
61
+ outputs=[gr.components.Image()])
62
  interface.launch(share = True, debug = True)