windy2612 commited on
Commit
bf10b27
·
verified ·
1 Parent(s): 7a785fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -22,7 +22,7 @@ with open(config_path, 'r') as data:
22
  config = yaml.safe_load(data)
23
  system = System(config)
24
  checkpoint_path = 'weights/parseq.ckpt'
25
- checkpoint = torch.load(checkpoint_path, map_location = 'cpu')
26
  system.load_state_dict(checkpoint['state_dict'])
27
  system.to(device)
28
 
@@ -30,13 +30,20 @@ def predict(image):
30
  if isinstance(image, str):
31
  image = cv2.imread(image)
32
  _, img_wapred, _, _ = detect_lp(wpod_net, image, 0.5)
33
- img = (img_wapred[0] * 255).astype(np.uint8)
34
- img = Image.fromarray(img).convert("RGB")
35
- image = trans(img).unsqueeze(0)
36
- with torch.no_grad():
37
- pred = system(image).softmax(-1)
38
- generated_text, _ = system.tokenizer.decode(pred)
39
- return generated_text[0]
 
 
 
 
 
 
 
40
 
41
  interface = gr.Interface(
42
  fn = predict,
 
22
  config = yaml.safe_load(data)
23
  system = System(config)
24
  checkpoint_path = 'weights/parseq.ckpt'
25
+ checkpoint = torch.load(checkpoint_path, map_location = 'cuda')
26
  system.load_state_dict(checkpoint['state_dict'])
27
  system.to(device)
28
 
 
30
  if isinstance(image, str):
31
  image = cv2.imread(image)
32
  _, img_wapred, _, _ = detect_lp(wpod_net, image, 0.5)
33
+ if len(img_warped) == 0:
34
+ return "Can not detect license plate from image"
35
+ else:
36
+ system.eval()
37
+ pred_labels = []
38
+ for i in range(len(img_warped)):
39
+ img = (img_wapred[i] * 255).astype(np.uint8)
40
+ img = Image.fromarray(img).convert("RGB")
41
+ image = trans(img).unsqueeze(0)
42
+ with torch.no_grad():
43
+ pred = system(image).softmax(-1)
44
+ generated_text, _ = system.tokenizer.decode(pred)
45
+ pred_labels.append(generated_text[0])
46
+ return pred_labels
47
 
48
  interface = gr.Interface(
49
  fn = predict,