cesar commited on
Commit
9472d8d
·
1 Parent(s): fe50644

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +113 -0
  2. packages.txt +1 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021, Mindee.
2
+
3
+ # This program is licensed under the Apache License version 2.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
+
6
+ import os
7
+
8
+ import matplotlib.pyplot as plt
9
+ import streamlit as st
10
+
11
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
12
+
13
+ import cv2
14
+ import tensorflow as tf
15
+
16
+ gpu_devices = tf.config.experimental.list_physical_devices('GPU')
17
+ if any(gpu_devices):
18
+ tf.config.experimental.set_memory_growth(gpu_devices[0], True)
19
+
20
+ from doctr.io import DocumentFile
21
+ from doctr.models import ocr_predictor
22
+ from doctr.utils.visualization import visualize_page
23
+
24
+ DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
25
+ RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]
26
+
27
+
28
+ def main():
29
+
30
+ # Wide mode
31
+ st.set_page_config(layout="wide")
32
+
33
+ # Designing the interface
34
+ st.title("docTR: Document Text Recognition")
35
+ # For newline
36
+ st.write('\n')
37
+ #
38
+ st.write('Find more info at: https://github.com/mindee/doctr')
39
+ # For newline
40
+ st.write('\n')
41
+ # Instructions
42
+ st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
43
+ # Set the columns
44
+ cols = st.beta_columns((1, 1, 1, 1))
45
+ cols[0].subheader("Input page")
46
+ cols[1].subheader("Segmentation heatmap")
47
+ cols[2].subheader("OCR output")
48
+ cols[3].subheader("Page reconstitution")
49
+
50
+ # Sidebar
51
+ # File selection
52
+ st.sidebar.title("Document selection")
53
+ # Disabling warning
54
+ st.set_option('deprecation.showfileUploaderEncoding', False)
55
+ # Choose your own image
56
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
57
+ if uploaded_file is not None:
58
+ if uploaded_file.name.endswith('.pdf'):
59
+ doc = DocumentFile.from_pdf(uploaded_file.read())
60
+ else:
61
+ doc = DocumentFile.from_images(uploaded_file.read())
62
+ page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
63
+ cols[0].image(doc[page_idx])
64
+
65
+ # Model selection
66
+ st.sidebar.title("Model selection")
67
+ det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
68
+ reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
69
+
70
+ # For newline
71
+ st.sidebar.write('\n')
72
+
73
+ if st.sidebar.button("Analyze page"):
74
+
75
+ if uploaded_file is None:
76
+ st.sidebar.write("Please upload a document")
77
+
78
+ else:
79
+ with st.spinner('Loading model...'):
80
+ predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)
81
+
82
+ with st.spinner('Analyzing...'):
83
+
84
+ # Forward the image to the model
85
+ processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
86
+ out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
87
+ seg_map = out["out_map"]
88
+ seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
89
+ seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
90
+ interpolation=cv2.INTER_LINEAR)
91
+ # Plot the raw heatmap
92
+ fig, ax = plt.subplots()
93
+ ax.imshow(seg_map)
94
+ ax.axis('off')
95
+ cols[1].pyplot(fig)
96
+
97
+ # Plot OCR output
98
+ out = predictor([doc[page_idx]])
99
+ fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False)
100
+ cols[2].pyplot(fig)
101
+
102
+ # Page reconsitution under input page
103
+ page_export = out.pages[0].export()
104
+ img = out.pages[0].synthesize()
105
+ cols[3].image(img, clamp=True)
106
+
107
+ # Display JSON
108
+ st.markdown("\nHere are your analysis results in JSON format:")
109
+ st.json(page_export)
110
+
111
+
112
+ if __name__ == '__main__':
113
+ main()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf]
2
+ streamlit>=0.65.0
3
+ PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12,!=1.19.5