Amine-0047 commited on
Commit
0ec591c
·
verified ·
1 Parent(s): 361e064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -31
app.py CHANGED
@@ -5,19 +5,19 @@ import torch
5
  import yolov5
6
 
7
  # Load YOLOv5 model
8
- @st.cache(allow_output_mutation=True)
9
  def load_model():
10
  return yolov5.load('keremberke/yolov5m-license-plate')
11
 
12
  # Load TR-OCR model
13
- @st.cache(allow_output_mutation=True)
14
  def load_ocr_model():
15
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
16
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
17
  return processor, model
18
 
19
  # Load TTS model
20
- @st.cache(allow_output_mutation=True)
21
  def load_tts_model():
22
  model = VitsModel.from_pretrained("facebook/mms-tts-eng")
23
  tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
@@ -30,44 +30,46 @@ def main():
30
  # Static test image
31
  test_image_path = "test_image.jpg"
32
  test_image = Image.open(test_image_path)
33
- st.image(test_image, caption='Test Image', use_column_width=True)
34
 
35
  # Upload file
36
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
37
 
38
- # Load models on startup
39
- model = load_model()
40
- processor, ocr_model = load_ocr_model()
41
- tts_model, tokenizer = load_tts_model()
42
-
43
  if uploaded_file is not None:
44
  img = Image.open(uploaded_file)
45
- st.image(img, caption='Uploaded Image', use_column_width=True)
 
 
 
 
 
 
 
 
 
46
 
47
- if st.button("Run Inference"):
48
- results = model(img, size=640)
49
- # results.show()
50
- predictions = results.pred[0]
51
- boxes = predictions[:, :4] # x1, y1, x2, y2
52
- scores = predictions[:, 4]
53
- categories = predictions[:, 5]
54
 
55
- # Crop the image of the license plate
56
- cropped_image = img.crop(tuple(results.xyxy[0][0, :4].squeeze().tolist()[:4]))
57
- st.image(cropped_image, caption='Plate detected')
58
 
59
- # Extract text from the image
60
- pixel_values = processor(cropped_image, return_tensors="pt").pixel_values
61
- generated_ids = ocr_model.generate(pixel_values)
62
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
63
 
64
- st.write("Detected License Plate Text:", generated_text)
65
 
66
- # Convert the text to audio
67
- inputs = tokenizer(generated_text, return_tensors="pt")
68
- with torch.no_grad():
69
- output = tts_model(**inputs).waveform
70
- st.audio(output.numpy(), format="audio/wav", sample_rate=tts_model.config.sampling_rate)
71
 
72
  if __name__ == "__main__":
73
- main()
 
5
  import yolov5
6
 
7
  # Load YOLOv5 model
8
+ # @st.cache(allow_output_mutation=True)
9
  def load_model():
10
  return yolov5.load('keremberke/yolov5m-license-plate')
11
 
12
  # Load TR-OCR model
13
+ # @st.cache(allow_output_mutation=True)
14
  def load_ocr_model():
15
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
16
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
17
  return processor, model
18
 
19
  # Load TTS model
20
+ # @st.cache(allow_output_mutation=True)
21
  def load_tts_model():
22
  model = VitsModel.from_pretrained("facebook/mms-tts-eng")
23
  tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
30
  # Static test image
31
  test_image_path = "test_image.jpg"
32
  test_image = Image.open(test_image_path)
 
33
 
34
  # Upload file
35
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
36
 
 
 
 
 
 
37
  if uploaded_file is not None:
38
  img = Image.open(uploaded_file)
39
+ else:
40
+ img = test_image
41
+
42
+ st.image(img, caption='Image', use_column_width=True)
43
+
44
+ if st.button("Run Inference"):
45
+ # Load models on startup
46
+ model = load_model()
47
+ processor, ocr_model = load_ocr_model()
48
+ tts_model, tokenizer = load_tts_model()
49
 
50
+ results = model(img, size=640)
51
+ # results.show()
52
+ predictions = results.pred[0]
53
+ boxes = predictions[:, :4] # x1, y1, x2, y2
54
+ scores = predictions[:, 4]
55
+ categories = predictions[:, 5]
 
56
 
57
+ # Crop the image of the license plate
58
+ cropped_image = img.crop(tuple(results.xyxy[0][0, :4].squeeze().tolist()[:4]))
59
+ st.image(cropped_image, caption='Plate detected')
60
 
61
+ # Extract text from the image
62
+ pixel_values = processor(cropped_image, return_tensors="pt").pixel_values
63
+ generated_ids = ocr_model.generate(pixel_values)
64
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
65
 
66
+ st.write("Detected License Plate Text:", generated_text)
67
 
68
+ # Convert the text to audio
69
+ inputs = tokenizer(generated_text, return_tensors="pt")
70
+ with torch.no_grad():
71
+ output = tts_model(**inputs).waveform
72
+ st.audio(output.numpy(), format="audio/wav", sample_rate=tts_model.config.sampling_rate)
73
 
74
  if __name__ == "__main__":
75
+ main()