Spaces:
Runtime error
Runtime error
paddleocr
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- README.md +1 -0
- app.py +93 -35
- ocr/.gitignore +129 -0
- ocr/README.md +1 -0
- ocr/__init__.py +0 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdiparams +3 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info +0 -0
- ocr/ch_PP-OCRv3_det_infer/inference.pdmodel +3 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams +3 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info +0 -0
- ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel +3 -0
- ocr/detector.py +248 -0
- ocr/inference.py +68 -0
- ocr/postprocess/__init__.py +66 -0
- ocr/postprocess/cls_postprocess.py +30 -0
- ocr/postprocess/db_postprocess.py +207 -0
- ocr/postprocess/east_postprocess.py +122 -0
- ocr/postprocess/extract_textpoint_fast.py +464 -0
- ocr/postprocess/extract_textpoint_slow.py +608 -0
- ocr/postprocess/fce_postprocess.py +234 -0
- ocr/postprocess/locality_aware_nms.py +198 -0
- ocr/postprocess/pg_postprocess.py +189 -0
- ocr/postprocess/poly_nms.py +132 -0
- ocr/postprocess/pse_postprocess/__init__.py +1 -0
- ocr/postprocess/pse_postprocess/pse/__init__.py +20 -0
- ocr/postprocess/pse_postprocess/pse/pse.pyx +72 -0
- ocr/postprocess/pse_postprocess/pse/setup.py +19 -0
- ocr/postprocess/pse_postprocess/pse_postprocess.py +100 -0
- ocr/postprocess/rec_postprocess.py +731 -0
- ocr/postprocess/sast_postprocess.py +355 -0
- ocr/postprocess/vqa_token_re_layoutlm_postprocess.py +36 -0
- ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py +96 -0
- ocr/ppocr/__init__.py +0 -0
- ocr/ppocr/data/__init__.py +79 -0
- ocr/ppocr/data/collate_fn.py +59 -0
- ocr/ppocr/data/imaug/ColorJitter.py +14 -0
- ocr/ppocr/data/imaug/__init__.py +61 -0
- ocr/ppocr/data/imaug/copy_paste.py +167 -0
- ocr/ppocr/data/imaug/east_process.py +427 -0
- ocr/ppocr/data/imaug/fce_aug.py +563 -0
- ocr/ppocr/data/imaug/fce_targets.py +671 -0
- ocr/ppocr/data/imaug/gen_table_mask.py +228 -0
- ocr/ppocr/data/imaug/iaa_augment.py +72 -0
- ocr/ppocr/data/imaug/label_ops.py +1046 -0
- ocr/ppocr/data/imaug/make_border_map.py +155 -0
- ocr/ppocr/data/imaug/make_pse_gt.py +88 -0
- ocr/ppocr/data/imaug/make_shrink_map.py +100 -0
- ocr/ppocr/data/imaug/operators.py +458 -0
- ocr/ppocr/data/imaug/pg_process.py +961 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.pdiparams filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.pdmodel filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -7,6 +7,7 @@ sdk: gradio
|
|
7 |
sdk_version: 3.17.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
7 |
sdk_version: 3.17.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
duplicated_from: mertcobanov/deprem-ocr-2
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
from easyocr import Reader
|
3 |
-
from PIL import Image
|
4 |
-
import io
|
5 |
import json
|
6 |
import csv
|
7 |
import openai
|
@@ -9,18 +6,64 @@ import ast
|
|
9 |
import os
|
10 |
from deta import Deta
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
openai.api_key = os.getenv('API_KEY')
|
14 |
-
reader = Reader(["tr"])
|
15 |
|
16 |
def get_parsed_address(input_img):
|
17 |
|
18 |
address_full_text = get_text(input_img)
|
19 |
return openai_response(address_full_text)
|
20 |
-
|
21 |
|
22 |
def get_text(input_img):
|
23 |
-
|
|
|
|
|
24 |
return " ".join(result)
|
25 |
|
26 |
|
@@ -38,9 +81,10 @@ def get_json(mahalle, il, sokak, apartman):
|
|
38 |
dump = json.dumps(adres, indent=4, ensure_ascii=False)
|
39 |
return dump
|
40 |
|
|
|
41 |
def write_db(data_dict):
|
42 |
# 2) initialize with a project key
|
43 |
-
deta_key = os.getenv(
|
44 |
deta = Deta(deta_key)
|
45 |
|
46 |
# 3) create and use as many DBs as you want!
|
@@ -53,16 +97,17 @@ def text_dict(input):
|
|
53 |
write_db(eval_result)
|
54 |
|
55 |
return (
|
56 |
-
str(eval_result[
|
57 |
-
str(eval_result[
|
58 |
-
str(eval_result[
|
59 |
-
str(eval_result[
|
60 |
-
str(eval_result[
|
61 |
-
str(eval_result[
|
62 |
-
str(eval_result[
|
63 |
-
str(eval_result[
|
64 |
)
|
65 |
-
|
|
|
66 |
def openai_response(ocr_input):
|
67 |
prompt = f"""Tabular Data Extraction You are a highly intelligent and accurate tabular data extractor from
|
68 |
plain text input and especially from emergency text that carries address information, your inputs can be text
|
@@ -91,28 +136,31 @@ def openai_response(ocr_input):
|
|
91 |
resp = eval(resp.replace("'{", "{").replace("}'", "}"))
|
92 |
resp["input"] = ocr_input
|
93 |
dict_keys = [
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
]
|
104 |
for key in dict_keys:
|
105 |
if key not in resp.keys():
|
106 |
-
resp[key] =
|
107 |
return resp
|
108 |
|
109 |
|
110 |
with gr.Blocks() as demo:
|
111 |
gr.Markdown(
|
112 |
-
|
113 |
# Enkaz Bildirme Uygulaması
|
114 |
-
"""
|
115 |
-
|
|
|
|
|
|
|
116 |
with gr.Row():
|
117 |
img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇")
|
118 |
ocr_result = gr.Textbox(label="Metin yükleyin 👇 ")
|
@@ -133,13 +181,23 @@ with gr.Blocks() as demo:
|
|
133 |
with gr.Row():
|
134 |
no = gr.Textbox(label="Kapı No")
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
|
141 |
-
open_api_text.change(
|
|
|
|
|
|
|
|
|
142 |
|
143 |
|
144 |
if __name__ == "__main__":
|
145 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
import json
|
3 |
import csv
|
4 |
import openai
|
|
|
6 |
import os
|
7 |
from deta import Deta
|
8 |
|
9 |
+
import numpy as np
|
10 |
+
from ocr import utility
|
11 |
+
from ocr.detector import TextDetector
|
12 |
+
from ocr.recognizer import TextRecognizer
|
13 |
+
|
14 |
+
# Global Detector and Recognizer
|
15 |
+
args = utility.parse_args()
|
16 |
+
text_recognizer = TextRecognizer(args)
|
17 |
+
text_detector = TextDetector(args)
|
18 |
+
|
19 |
+
openai.api_key = os.getenv("API_KEY")
|
20 |
+
|
21 |
+
args = utility.parse_args()
|
22 |
+
text_recognizer = TextRecognizer(args)
|
23 |
+
text_detector = TextDetector(args)
|
24 |
+
|
25 |
+
|
26 |
+
def apply_ocr(img):
|
27 |
+
# Detect text regions
|
28 |
+
dt_boxes, _ = text_detector(img)
|
29 |
+
|
30 |
+
boxes = []
|
31 |
+
for box in dt_boxes:
|
32 |
+
p1, p2, p3, p4 = box
|
33 |
+
x1 = min(p1[0], p2[0], p3[0], p4[0])
|
34 |
+
y1 = min(p1[1], p2[1], p3[1], p4[1])
|
35 |
+
x2 = max(p1[0], p2[0], p3[0], p4[0])
|
36 |
+
y2 = max(p1[1], p2[1], p3[1], p4[1])
|
37 |
+
boxes.append([x1, y1, x2, y2])
|
38 |
+
|
39 |
+
# Recognize text
|
40 |
+
img_list = []
|
41 |
+
for i in range(len(boxes)):
|
42 |
+
x1, y1, x2, y2 = map(int, boxes[i])
|
43 |
+
img_list.append(img.copy()[y1:y2, x1:x2])
|
44 |
+
img_list.reverse()
|
45 |
+
|
46 |
+
rec_res, _ = text_recognizer(img_list)
|
47 |
+
|
48 |
+
# Postprocess
|
49 |
+
total_text = ""
|
50 |
+
for i in range(len(rec_res)):
|
51 |
+
total_text += rec_res[i][0] + " "
|
52 |
+
|
53 |
+
total_text = total_text.strip()
|
54 |
+
return total_text
|
55 |
|
|
|
|
|
56 |
|
57 |
def get_parsed_address(input_img):
|
58 |
|
59 |
address_full_text = get_text(input_img)
|
60 |
return openai_response(address_full_text)
|
61 |
+
|
62 |
|
63 |
def get_text(input_img):
|
64 |
+
input_img = np.array(input_img)
|
65 |
+
result = apply_ocr(input_img)
|
66 |
+
print(result)
|
67 |
return " ".join(result)
|
68 |
|
69 |
|
|
|
81 |
dump = json.dumps(adres, indent=4, ensure_ascii=False)
|
82 |
return dump
|
83 |
|
84 |
+
|
85 |
def write_db(data_dict):
|
86 |
# 2) initialize with a project key
|
87 |
+
deta_key = os.getenv("DETA_KEY")
|
88 |
deta = Deta(deta_key)
|
89 |
|
90 |
# 3) create and use as many DBs as you want!
|
|
|
97 |
write_db(eval_result)
|
98 |
|
99 |
return (
|
100 |
+
str(eval_result["city"]),
|
101 |
+
str(eval_result["distinct"]),
|
102 |
+
str(eval_result["neighbourhood"]),
|
103 |
+
str(eval_result["street"]),
|
104 |
+
str(eval_result["address"]),
|
105 |
+
str(eval_result["tel"]),
|
106 |
+
str(eval_result["name_surname"]),
|
107 |
+
str(eval_result["no"]),
|
108 |
)
|
109 |
+
|
110 |
+
|
111 |
def openai_response(ocr_input):
|
112 |
prompt = f"""Tabular Data Extraction You are a highly intelligent and accurate tabular data extractor from
|
113 |
plain text input and especially from emergency text that carries address information, your inputs can be text
|
|
|
136 |
resp = eval(resp.replace("'{", "{").replace("}'", "}"))
|
137 |
resp["input"] = ocr_input
|
138 |
dict_keys = [
|
139 |
+
"city",
|
140 |
+
"distinct",
|
141 |
+
"neighbourhood",
|
142 |
+
"street",
|
143 |
+
"no",
|
144 |
+
"tel",
|
145 |
+
"name_surname",
|
146 |
+
"address",
|
147 |
+
"input",
|
148 |
]
|
149 |
for key in dict_keys:
|
150 |
if key not in resp.keys():
|
151 |
+
resp[key] = ""
|
152 |
return resp
|
153 |
|
154 |
|
155 |
with gr.Blocks() as demo:
|
156 |
gr.Markdown(
|
157 |
+
"""
|
158 |
# Enkaz Bildirme Uygulaması
|
159 |
+
"""
|
160 |
+
)
|
161 |
+
gr.Markdown(
|
162 |
+
"Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın."
|
163 |
+
)
|
164 |
with gr.Row():
|
165 |
img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇")
|
166 |
ocr_result = gr.Textbox(label="Metin yükleyin 👇 ")
|
|
|
181 |
with gr.Row():
|
182 |
no = gr.Textbox(label="Kapı No")
|
183 |
|
184 |
+
submit_button.click(
|
185 |
+
get_parsed_address,
|
186 |
+
inputs=img_area,
|
187 |
+
outputs=open_api_text,
|
188 |
+
api_name="upload_image",
|
189 |
+
)
|
190 |
|
191 |
+
ocr_result.change(
|
192 |
+
openai_response, ocr_result, open_api_text, api_name="upload-text"
|
193 |
+
)
|
194 |
|
195 |
+
open_api_text.change(
|
196 |
+
text_dict,
|
197 |
+
open_api_text,
|
198 |
+
[city, distinct, neighbourhood, street, address, tel, name_surname, no],
|
199 |
+
)
|
200 |
|
201 |
|
202 |
if __name__ == "__main__":
|
203 |
+
demo.launch()
|
ocr/.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
ocr/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# deprem-ocr
|
ocr/__init__.py
ADDED
File without changes
|
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e9518c6ab706fe87842a8de1c098f990e67f9212b67c9ef8bc4bca6dc17b91a
|
3 |
+
size 2377917
|
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info
ADDED
Binary file (26.4 kB). View file
|
|
ocr/ch_PP-OCRv3_det_infer/inference.pdmodel
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:74b075e6cfbc8206dab2eee86a6a8bd015a7be612b2bf6d1a1ef878d31df84f7
|
3 |
+
size 1413260
|
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d99d4279f7c64471b8f0be426ee09a46c0f1ecb344406bf0bb9571f670e8d0c7
|
3 |
+
size 10614098
|
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info
ADDED
Binary file (22 kB). View file
|
|
ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9beb0b9520d34bde2a0f92581ed64db7e4d6c76abead8b859189ea72db9ee20
|
3 |
+
size 1266415
|
ocr/detector.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
sys.path.append(__dir__)
|
6 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
|
7 |
+
|
8 |
+
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
|
9 |
+
|
10 |
+
import json
|
11 |
+
import sys
|
12 |
+
import time
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import utility
|
18 |
+
from postprocess import build_post_process
|
19 |
+
from ppocr.data import create_operators, transform
|
20 |
+
|
21 |
+
|
22 |
+
class TextDetector(object):
|
23 |
+
def __init__(self, args):
|
24 |
+
self.args = args
|
25 |
+
self.det_algorithm = args.det_algorithm
|
26 |
+
self.use_onnx = args.use_onnx
|
27 |
+
pre_process_list = [
|
28 |
+
{
|
29 |
+
"DetResizeForTest": {
|
30 |
+
"limit_side_len": args.det_limit_side_len,
|
31 |
+
"limit_type": args.det_limit_type,
|
32 |
+
}
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"NormalizeImage": {
|
36 |
+
"std": [0.229, 0.224, 0.225],
|
37 |
+
"mean": [0.485, 0.456, 0.406],
|
38 |
+
"scale": "1./255.",
|
39 |
+
"order": "hwc",
|
40 |
+
}
|
41 |
+
},
|
42 |
+
{"ToCHWImage": None},
|
43 |
+
{"KeepKeys": {"keep_keys": ["image", "shape"]}},
|
44 |
+
]
|
45 |
+
postprocess_params = {}
|
46 |
+
if self.det_algorithm == "DB":
|
47 |
+
postprocess_params["name"] = "DBPostProcess"
|
48 |
+
postprocess_params["thresh"] = args.det_db_thresh
|
49 |
+
postprocess_params["box_thresh"] = args.det_db_box_thresh
|
50 |
+
postprocess_params["max_candidates"] = 1000
|
51 |
+
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
52 |
+
postprocess_params["use_dilation"] = args.use_dilation
|
53 |
+
postprocess_params["score_mode"] = args.det_db_score_mode
|
54 |
+
elif self.det_algorithm == "EAST":
|
55 |
+
postprocess_params["name"] = "EASTPostProcess"
|
56 |
+
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
57 |
+
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
58 |
+
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
59 |
+
elif self.det_algorithm == "SAST":
|
60 |
+
pre_process_list[0] = {
|
61 |
+
"DetResizeForTest": {"resize_long": args.det_limit_side_len}
|
62 |
+
}
|
63 |
+
postprocess_params["name"] = "SASTPostProcess"
|
64 |
+
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
65 |
+
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
66 |
+
self.det_sast_polygon = args.det_sast_polygon
|
67 |
+
if self.det_sast_polygon:
|
68 |
+
postprocess_params["sample_pts_num"] = 6
|
69 |
+
postprocess_params["expand_scale"] = 1.2
|
70 |
+
postprocess_params["shrink_ratio_of_width"] = 0.2
|
71 |
+
else:
|
72 |
+
postprocess_params["sample_pts_num"] = 2
|
73 |
+
postprocess_params["expand_scale"] = 1.0
|
74 |
+
postprocess_params["shrink_ratio_of_width"] = 0.3
|
75 |
+
elif self.det_algorithm == "PSE":
|
76 |
+
postprocess_params["name"] = "PSEPostProcess"
|
77 |
+
postprocess_params["thresh"] = args.det_pse_thresh
|
78 |
+
postprocess_params["box_thresh"] = args.det_pse_box_thresh
|
79 |
+
postprocess_params["min_area"] = args.det_pse_min_area
|
80 |
+
postprocess_params["box_type"] = args.det_pse_box_type
|
81 |
+
postprocess_params["scale"] = args.det_pse_scale
|
82 |
+
self.det_pse_box_type = args.det_pse_box_type
|
83 |
+
elif self.det_algorithm == "FCE":
|
84 |
+
pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}}
|
85 |
+
postprocess_params["name"] = "FCEPostProcess"
|
86 |
+
postprocess_params["scales"] = args.scales
|
87 |
+
postprocess_params["alpha"] = args.alpha
|
88 |
+
postprocess_params["beta"] = args.beta
|
89 |
+
postprocess_params["fourier_degree"] = args.fourier_degree
|
90 |
+
postprocess_params["box_type"] = args.det_fce_box_type
|
91 |
+
|
92 |
+
self.preprocess_op = create_operators(pre_process_list)
|
93 |
+
self.postprocess_op = build_post_process(postprocess_params)
|
94 |
+
(
|
95 |
+
self.predictor,
|
96 |
+
self.input_tensor,
|
97 |
+
self.output_tensors,
|
98 |
+
self.config,
|
99 |
+
) = utility.create_predictor(args, "det")
|
100 |
+
|
101 |
+
if self.use_onnx:
|
102 |
+
img_h, img_w = self.input_tensor.shape[2:]
|
103 |
+
if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
|
104 |
+
pre_process_list[0] = {
|
105 |
+
"DetResizeForTest": {"image_shape": [img_h, img_w]}
|
106 |
+
}
|
107 |
+
self.preprocess_op = create_operators(pre_process_list)
|
108 |
+
|
109 |
+
def order_points_clockwise(self, pts):
|
110 |
+
rect = np.zeros((4, 2), dtype="float32")
|
111 |
+
s = pts.sum(axis=1)
|
112 |
+
rect[0] = pts[np.argmin(s)]
|
113 |
+
rect[2] = pts[np.argmax(s)]
|
114 |
+
diff = np.diff(pts, axis=1)
|
115 |
+
rect[1] = pts[np.argmin(diff)]
|
116 |
+
rect[3] = pts[np.argmax(diff)]
|
117 |
+
return rect
|
118 |
+
|
119 |
+
def clip_det_res(self, points, img_height, img_width):
|
120 |
+
for pno in range(points.shape[0]):
|
121 |
+
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
122 |
+
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
123 |
+
return points
|
124 |
+
|
125 |
+
def filter_tag_det_res(self, dt_boxes, image_shape):
|
126 |
+
img_height, img_width = image_shape[0:2]
|
127 |
+
dt_boxes_new = []
|
128 |
+
for box in dt_boxes:
|
129 |
+
box = self.order_points_clockwise(box)
|
130 |
+
box = self.clip_det_res(box, img_height, img_width)
|
131 |
+
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
132 |
+
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
133 |
+
if rect_width <= 3 or rect_height <= 3:
|
134 |
+
continue
|
135 |
+
dt_boxes_new.append(box)
|
136 |
+
dt_boxes = np.array(dt_boxes_new)
|
137 |
+
return dt_boxes
|
138 |
+
|
139 |
+
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
140 |
+
img_height, img_width = image_shape[0:2]
|
141 |
+
dt_boxes_new = []
|
142 |
+
for box in dt_boxes:
|
143 |
+
box = self.clip_det_res(box, img_height, img_width)
|
144 |
+
dt_boxes_new.append(box)
|
145 |
+
dt_boxes = np.array(dt_boxes_new)
|
146 |
+
return dt_boxes
|
147 |
+
|
148 |
+
def __call__(self, img):
|
149 |
+
ori_im = img.copy()
|
150 |
+
data = {"image": img}
|
151 |
+
|
152 |
+
st = time.time()
|
153 |
+
|
154 |
+
data = transform(data, self.preprocess_op)
|
155 |
+
img, shape_list = data
|
156 |
+
if img is None:
|
157 |
+
return None, 0
|
158 |
+
img = np.expand_dims(img, axis=0)
|
159 |
+
shape_list = np.expand_dims(shape_list, axis=0)
|
160 |
+
img = img.copy()
|
161 |
+
|
162 |
+
if self.use_onnx:
|
163 |
+
input_dict = {}
|
164 |
+
input_dict[self.input_tensor.name] = img
|
165 |
+
outputs = self.predictor.run(self.output_tensors, input_dict)
|
166 |
+
else:
|
167 |
+
self.input_tensor.copy_from_cpu(img)
|
168 |
+
self.predictor.run()
|
169 |
+
outputs = []
|
170 |
+
for output_tensor in self.output_tensors:
|
171 |
+
output = output_tensor.copy_to_cpu()
|
172 |
+
outputs.append(output)
|
173 |
+
|
174 |
+
preds = {}
|
175 |
+
if self.det_algorithm == "EAST":
|
176 |
+
preds["f_geo"] = outputs[0]
|
177 |
+
preds["f_score"] = outputs[1]
|
178 |
+
elif self.det_algorithm == "SAST":
|
179 |
+
preds["f_border"] = outputs[0]
|
180 |
+
preds["f_score"] = outputs[1]
|
181 |
+
preds["f_tco"] = outputs[2]
|
182 |
+
preds["f_tvo"] = outputs[3]
|
183 |
+
elif self.det_algorithm in ["DB", "PSE"]:
|
184 |
+
preds["maps"] = outputs[0]
|
185 |
+
elif self.det_algorithm == "FCE":
|
186 |
+
for i, output in enumerate(outputs):
|
187 |
+
preds["level_{}".format(i)] = output
|
188 |
+
else:
|
189 |
+
raise NotImplementedError
|
190 |
+
|
191 |
+
# self.predictor.try_shrink_memory()
|
192 |
+
post_result = self.postprocess_op(preds, shape_list)
|
193 |
+
dt_boxes = post_result[0]["points"]
|
194 |
+
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
|
195 |
+
self.det_algorithm in ["PSE", "FCE"]
|
196 |
+
and self.postprocess_op.box_type == "poly"
|
197 |
+
):
|
198 |
+
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
|
199 |
+
else:
|
200 |
+
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
201 |
+
|
202 |
+
et = time.time()
|
203 |
+
return dt_boxes, et - st
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
args = utility.parse_args()
|
208 |
+
image_file_list = ["images/y.png"]
|
209 |
+
text_detector = TextDetector(args)
|
210 |
+
count = 0
|
211 |
+
total_time = 0
|
212 |
+
draw_img_save = "./inference_results"
|
213 |
+
|
214 |
+
if args.warmup:
|
215 |
+
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
|
216 |
+
for i in range(2):
|
217 |
+
res = text_detector(img)
|
218 |
+
|
219 |
+
if not os.path.exists(draw_img_save):
|
220 |
+
os.makedirs(draw_img_save)
|
221 |
+
|
222 |
+
save_results = []
|
223 |
+
for image_file in image_file_list:
|
224 |
+
img = cv2.imread(image_file)
|
225 |
+
|
226 |
+
for _ in range(10):
|
227 |
+
st = time.time()
|
228 |
+
dt_boxes, _ = text_detector(img)
|
229 |
+
elapse = time.time() - st
|
230 |
+
print(elapse * 1000)
|
231 |
+
if count > 0:
|
232 |
+
total_time += elapse
|
233 |
+
count += 1
|
234 |
+
save_pred = (
|
235 |
+
os.path.basename(image_file)
|
236 |
+
+ "\t"
|
237 |
+
+ str(json.dumps([x.tolist() for x in dt_boxes]))
|
238 |
+
+ "\n"
|
239 |
+
)
|
240 |
+
save_results.append(save_pred)
|
241 |
+
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
242 |
+
img_name_pure = os.path.split(image_file)[-1]
|
243 |
+
img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure))
|
244 |
+
cv2.imwrite(img_path, src_im)
|
245 |
+
|
246 |
+
with open(os.path.join(draw_img_save, "det_results.txt"), "w") as f:
|
247 |
+
f.writelines(save_results)
|
248 |
+
f.close()
|
ocr/inference.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
4 |
+
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
import utility
|
10 |
+
from detector import *
|
11 |
+
from recognizer import *
|
12 |
+
|
13 |
+
# Global Detector and Recognizer
|
14 |
+
args = utility.parse_args()
|
15 |
+
text_recognizer = TextRecognizer(args)
|
16 |
+
text_detector = TextDetector(args)
|
17 |
+
|
18 |
+
|
19 |
+
def apply_ocr(img):
|
20 |
+
# Detect text regions
|
21 |
+
dt_boxes, _ = text_detector(img)
|
22 |
+
|
23 |
+
boxes = []
|
24 |
+
for box in dt_boxes:
|
25 |
+
p1, p2, p3, p4 = box
|
26 |
+
x1 = min(p1[0], p2[0], p3[0], p4[0])
|
27 |
+
y1 = min(p1[1], p2[1], p3[1], p4[1])
|
28 |
+
x2 = max(p1[0], p2[0], p3[0], p4[0])
|
29 |
+
y2 = max(p1[1], p2[1], p3[1], p4[1])
|
30 |
+
boxes.append([x1, y1, x2, y2])
|
31 |
+
|
32 |
+
# Recognize text
|
33 |
+
img_list = []
|
34 |
+
for i in range(len(boxes)):
|
35 |
+
x1, y1, x2, y2 = map(int, boxes[i])
|
36 |
+
img_list.append(img.copy()[y1:y2, x1:x2])
|
37 |
+
img_list.reverse()
|
38 |
+
|
39 |
+
rec_res, _ = text_recognizer(img_list)
|
40 |
+
|
41 |
+
# Postprocess
|
42 |
+
total_text = ""
|
43 |
+
table = dict()
|
44 |
+
for i in range(len(rec_res)):
|
45 |
+
table[i] = {
|
46 |
+
"text": rec_res[i][0],
|
47 |
+
}
|
48 |
+
total_text += rec_res[i][0] + " "
|
49 |
+
|
50 |
+
total_text = total_text.strip()
|
51 |
+
return total_text
|
52 |
+
|
53 |
+
|
54 |
+
def main():
|
55 |
+
image_url = "https://i.ibb.co/kQvHGjj/aewrg.png"
|
56 |
+
response = requests.get(image_url)
|
57 |
+
img = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
|
58 |
+
|
59 |
+
t0 = time.time()
|
60 |
+
epoch = 1
|
61 |
+
for _ in range(epoch):
|
62 |
+
ocr_text = apply_ocr(img)
|
63 |
+
print("Elapsed time:", (time.time() - t0) * 1000 / epoch, "ms")
|
64 |
+
print("Output:", ocr_text)
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
main()
|
ocr/postprocess/__init__.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import copy
|
4 |
+
|
5 |
+
__all__ = ["build_post_process"]
|
6 |
+
|
7 |
+
from .cls_postprocess import ClsPostProcess
|
8 |
+
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
9 |
+
from .east_postprocess import EASTPostProcess
|
10 |
+
from .fce_postprocess import FCEPostProcess
|
11 |
+
from .pg_postprocess import PGPostProcess
|
12 |
+
from .rec_postprocess import (
|
13 |
+
AttnLabelDecode,
|
14 |
+
CTCLabelDecode,
|
15 |
+
DistillationCTCLabelDecode,
|
16 |
+
NRTRLabelDecode,
|
17 |
+
PRENLabelDecode,
|
18 |
+
SARLabelDecode,
|
19 |
+
SEEDLabelDecode,
|
20 |
+
SRNLabelDecode,
|
21 |
+
TableLabelDecode,
|
22 |
+
)
|
23 |
+
from .sast_postprocess import SASTPostProcess
|
24 |
+
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
|
25 |
+
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
|
26 |
+
|
27 |
+
|
28 |
+
def build_post_process(config, global_config=None):
|
29 |
+
support_dict = [
|
30 |
+
"DBPostProcess",
|
31 |
+
"EASTPostProcess",
|
32 |
+
"SASTPostProcess",
|
33 |
+
"FCEPostProcess",
|
34 |
+
"CTCLabelDecode",
|
35 |
+
"AttnLabelDecode",
|
36 |
+
"ClsPostProcess",
|
37 |
+
"SRNLabelDecode",
|
38 |
+
"PGPostProcess",
|
39 |
+
"DistillationCTCLabelDecode",
|
40 |
+
"TableLabelDecode",
|
41 |
+
"DistillationDBPostProcess",
|
42 |
+
"NRTRLabelDecode",
|
43 |
+
"SARLabelDecode",
|
44 |
+
"SEEDLabelDecode",
|
45 |
+
"VQASerTokenLayoutLMPostProcess",
|
46 |
+
"VQAReTokenLayoutLMPostProcess",
|
47 |
+
"PRENLabelDecode",
|
48 |
+
"DistillationSARLabelDecode",
|
49 |
+
]
|
50 |
+
|
51 |
+
if config["name"] == "PSEPostProcess":
|
52 |
+
from .pse_postprocess import PSEPostProcess
|
53 |
+
|
54 |
+
support_dict.append("PSEPostProcess")
|
55 |
+
|
56 |
+
config = copy.deepcopy(config)
|
57 |
+
module_name = config.pop("name")
|
58 |
+
if module_name == "None":
|
59 |
+
return
|
60 |
+
if global_config is not None:
|
61 |
+
config.update(global_config)
|
62 |
+
assert module_name in support_dict, Exception(
|
63 |
+
"post process only support {}".format(support_dict)
|
64 |
+
)
|
65 |
+
module_class = eval(module_name)(**config)
|
66 |
+
return module_class
|
ocr/postprocess/cls_postprocess.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import paddle
|
2 |
+
|
3 |
+
|
4 |
+
class ClsPostProcess(object):
|
5 |
+
"""Convert between text-label and text-index"""
|
6 |
+
|
7 |
+
def __init__(self, label_list=None, key=None, **kwargs):
|
8 |
+
super(ClsPostProcess, self).__init__()
|
9 |
+
self.label_list = label_list
|
10 |
+
self.key = key
|
11 |
+
|
12 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
13 |
+
if self.key is not None:
|
14 |
+
preds = preds[self.key]
|
15 |
+
|
16 |
+
label_list = self.label_list
|
17 |
+
if label_list is None:
|
18 |
+
label_list = {idx: idx for idx in range(preds.shape[-1])}
|
19 |
+
|
20 |
+
if isinstance(preds, paddle.Tensor):
|
21 |
+
preds = preds.numpy()
|
22 |
+
|
23 |
+
pred_idxs = preds.argmax(axis=1)
|
24 |
+
decode_out = [
|
25 |
+
(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
|
26 |
+
]
|
27 |
+
if label is None:
|
28 |
+
return decode_out
|
29 |
+
label = [(label_list[idx], 1.0) for idx in label]
|
30 |
+
return decode_out, label
|
ocr/postprocess/db_postprocess.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import paddle
|
6 |
+
import pyclipper
|
7 |
+
from shapely.geometry import Polygon
|
8 |
+
|
9 |
+
|
10 |
+
class DBPostProcess(object):
|
11 |
+
"""
|
12 |
+
The post process for Differentiable Binarization (DB).
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
thresh=0.3,
|
18 |
+
box_thresh=0.7,
|
19 |
+
max_candidates=1000,
|
20 |
+
unclip_ratio=2.0,
|
21 |
+
use_dilation=False,
|
22 |
+
score_mode="fast",
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
self.thresh = thresh
|
26 |
+
self.box_thresh = box_thresh
|
27 |
+
self.max_candidates = max_candidates
|
28 |
+
self.unclip_ratio = unclip_ratio
|
29 |
+
self.min_size = 3
|
30 |
+
self.score_mode = score_mode
|
31 |
+
assert score_mode in [
|
32 |
+
"slow",
|
33 |
+
"fast",
|
34 |
+
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
35 |
+
|
36 |
+
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
|
37 |
+
|
38 |
+
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
39 |
+
"""
|
40 |
+
_bitmap: single map with shape (1, H, W),
|
41 |
+
whose values are binarized as {0, 1}
|
42 |
+
"""
|
43 |
+
|
44 |
+
bitmap = _bitmap
|
45 |
+
height, width = bitmap.shape
|
46 |
+
|
47 |
+
outs = cv2.findContours(
|
48 |
+
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
|
49 |
+
)
|
50 |
+
if len(outs) == 3:
|
51 |
+
img, contours, _ = outs[0], outs[1], outs[2]
|
52 |
+
elif len(outs) == 2:
|
53 |
+
contours, _ = outs[0], outs[1]
|
54 |
+
|
55 |
+
num_contours = min(len(contours), self.max_candidates)
|
56 |
+
|
57 |
+
boxes = []
|
58 |
+
scores = []
|
59 |
+
for index in range(num_contours):
|
60 |
+
contour = contours[index]
|
61 |
+
points, sside = self.get_mini_boxes(contour)
|
62 |
+
if sside < self.min_size:
|
63 |
+
continue
|
64 |
+
points = np.array(points)
|
65 |
+
if self.score_mode == "fast":
|
66 |
+
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
67 |
+
else:
|
68 |
+
score = self.box_score_slow(pred, contour)
|
69 |
+
if self.box_thresh > score:
|
70 |
+
continue
|
71 |
+
|
72 |
+
box = self.unclip(points).reshape(-1, 1, 2)
|
73 |
+
box, sside = self.get_mini_boxes(box)
|
74 |
+
if sside < self.min_size + 2:
|
75 |
+
continue
|
76 |
+
box = np.array(box)
|
77 |
+
|
78 |
+
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
79 |
+
box[:, 1] = np.clip(
|
80 |
+
np.round(box[:, 1] / height * dest_height), 0, dest_height
|
81 |
+
)
|
82 |
+
boxes.append(box.astype(np.int16))
|
83 |
+
scores.append(score)
|
84 |
+
return np.array(boxes, dtype=np.int16), scores
|
85 |
+
|
86 |
+
def unclip(self, box):
|
87 |
+
unclip_ratio = self.unclip_ratio
|
88 |
+
poly = Polygon(box)
|
89 |
+
distance = poly.area * unclip_ratio / poly.length
|
90 |
+
offset = pyclipper.PyclipperOffset()
|
91 |
+
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
92 |
+
expanded = np.array(offset.Execute(distance))
|
93 |
+
return expanded
|
94 |
+
|
95 |
+
def get_mini_boxes(self, contour):
|
96 |
+
bounding_box = cv2.minAreaRect(contour)
|
97 |
+
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
98 |
+
|
99 |
+
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
100 |
+
if points[1][1] > points[0][1]:
|
101 |
+
index_1 = 0
|
102 |
+
index_4 = 1
|
103 |
+
else:
|
104 |
+
index_1 = 1
|
105 |
+
index_4 = 0
|
106 |
+
if points[3][1] > points[2][1]:
|
107 |
+
index_2 = 2
|
108 |
+
index_3 = 3
|
109 |
+
else:
|
110 |
+
index_2 = 3
|
111 |
+
index_3 = 2
|
112 |
+
|
113 |
+
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
|
114 |
+
return box, min(bounding_box[1])
|
115 |
+
|
116 |
+
def box_score_fast(self, bitmap, _box):
|
117 |
+
"""
|
118 |
+
box_score_fast: use bbox mean score as the mean score
|
119 |
+
"""
|
120 |
+
h, w = bitmap.shape[:2]
|
121 |
+
box = _box.copy()
|
122 |
+
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
|
123 |
+
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
|
124 |
+
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
|
125 |
+
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
|
126 |
+
|
127 |
+
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
128 |
+
box[:, 0] = box[:, 0] - xmin
|
129 |
+
box[:, 1] = box[:, 1] - ymin
|
130 |
+
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
131 |
+
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
|
132 |
+
|
133 |
+
def box_score_slow(self, bitmap, contour):
|
134 |
+
"""
|
135 |
+
box_score_slow: use polyon mean score as the mean score
|
136 |
+
"""
|
137 |
+
h, w = bitmap.shape[:2]
|
138 |
+
contour = contour.copy()
|
139 |
+
contour = np.reshape(contour, (-1, 2))
|
140 |
+
|
141 |
+
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
142 |
+
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
143 |
+
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
144 |
+
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
145 |
+
|
146 |
+
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
147 |
+
|
148 |
+
contour[:, 0] = contour[:, 0] - xmin
|
149 |
+
contour[:, 1] = contour[:, 1] - ymin
|
150 |
+
|
151 |
+
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
|
152 |
+
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
|
153 |
+
|
154 |
+
def __call__(self, outs_dict, shape_list):
|
155 |
+
pred = outs_dict["maps"]
|
156 |
+
if isinstance(pred, paddle.Tensor):
|
157 |
+
pred = pred.numpy()
|
158 |
+
pred = pred[:, 0, :, :]
|
159 |
+
segmentation = pred > self.thresh
|
160 |
+
|
161 |
+
boxes_batch = []
|
162 |
+
for batch_index in range(pred.shape[0]):
|
163 |
+
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
164 |
+
if self.dilation_kernel is not None:
|
165 |
+
mask = cv2.dilate(
|
166 |
+
np.array(segmentation[batch_index]).astype(np.uint8),
|
167 |
+
self.dilation_kernel,
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
mask = segmentation[batch_index]
|
171 |
+
boxes, scores = self.boxes_from_bitmap(
|
172 |
+
pred[batch_index], mask, src_w, src_h
|
173 |
+
)
|
174 |
+
|
175 |
+
boxes_batch.append({"points": boxes})
|
176 |
+
return boxes_batch
|
177 |
+
|
178 |
+
|
179 |
+
class DistillationDBPostProcess(object):
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
model_name=["student"],
|
183 |
+
key=None,
|
184 |
+
thresh=0.3,
|
185 |
+
box_thresh=0.6,
|
186 |
+
max_candidates=1000,
|
187 |
+
unclip_ratio=1.5,
|
188 |
+
use_dilation=False,
|
189 |
+
score_mode="fast",
|
190 |
+
**kwargs
|
191 |
+
):
|
192 |
+
self.model_name = model_name
|
193 |
+
self.key = key
|
194 |
+
self.post_process = DBPostProcess(
|
195 |
+
thresh=thresh,
|
196 |
+
box_thresh=box_thresh,
|
197 |
+
max_candidates=max_candidates,
|
198 |
+
unclip_ratio=unclip_ratio,
|
199 |
+
use_dilation=use_dilation,
|
200 |
+
score_mode=score_mode,
|
201 |
+
)
|
202 |
+
|
203 |
+
def __call__(self, predicts, shape_list):
|
204 |
+
results = {}
|
205 |
+
for k in self.model_name:
|
206 |
+
results[k] = self.post_process(predicts[k], shape_list=shape_list)
|
207 |
+
return results
|
ocr/postprocess/east_postprocess.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import paddle
|
6 |
+
|
7 |
+
from .locality_aware_nms import nms_locality
|
8 |
+
|
9 |
+
|
10 |
+
class EASTPostProcess(object):
|
11 |
+
"""
|
12 |
+
The post process for EAST.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
|
16 |
+
|
17 |
+
self.score_thresh = score_thresh
|
18 |
+
self.cover_thresh = cover_thresh
|
19 |
+
self.nms_thresh = nms_thresh
|
20 |
+
|
21 |
+
def restore_rectangle_quad(self, origin, geometry):
|
22 |
+
"""
|
23 |
+
Restore rectangle from quadrangle.
|
24 |
+
"""
|
25 |
+
# quad
|
26 |
+
origin_concat = np.concatenate(
|
27 |
+
(origin, origin, origin, origin), axis=1
|
28 |
+
) # (n, 8)
|
29 |
+
pred_quads = origin_concat - geometry
|
30 |
+
pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
|
31 |
+
return pred_quads
|
32 |
+
|
33 |
+
def detect(
|
34 |
+
self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
|
35 |
+
):
|
36 |
+
"""
|
37 |
+
restore text boxes from score map and geo map
|
38 |
+
"""
|
39 |
+
|
40 |
+
score_map = score_map[0]
|
41 |
+
geo_map = np.swapaxes(geo_map, 1, 0)
|
42 |
+
geo_map = np.swapaxes(geo_map, 1, 2)
|
43 |
+
# filter the score map
|
44 |
+
xy_text = np.argwhere(score_map > score_thresh)
|
45 |
+
if len(xy_text) == 0:
|
46 |
+
return []
|
47 |
+
# sort the text boxes via the y axis
|
48 |
+
xy_text = xy_text[np.argsort(xy_text[:, 0])]
|
49 |
+
# restore quad proposals
|
50 |
+
text_box_restored = self.restore_rectangle_quad(
|
51 |
+
xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
|
52 |
+
)
|
53 |
+
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
|
54 |
+
boxes[:, :8] = text_box_restored.reshape((-1, 8))
|
55 |
+
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
|
56 |
+
|
57 |
+
try:
|
58 |
+
import lanms
|
59 |
+
|
60 |
+
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
|
61 |
+
except:
|
62 |
+
print(
|
63 |
+
"you should install lanms by pip3 install lanms-nova to speed up nms_locality"
|
64 |
+
)
|
65 |
+
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
|
66 |
+
if boxes.shape[0] == 0:
|
67 |
+
return []
|
68 |
+
# Here we filter some low score boxes by the average score map,
|
69 |
+
# this is different from the orginal paper.
|
70 |
+
for i, box in enumerate(boxes):
|
71 |
+
mask = np.zeros_like(score_map, dtype=np.uint8)
|
72 |
+
cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
|
73 |
+
boxes[i, 8] = cv2.mean(score_map, mask)[0]
|
74 |
+
boxes = boxes[boxes[:, 8] > cover_thresh]
|
75 |
+
return boxes
|
76 |
+
|
77 |
+
def sort_poly(self, p):
|
78 |
+
"""
|
79 |
+
Sort polygons.
|
80 |
+
"""
|
81 |
+
min_axis = np.argmin(np.sum(p, axis=1))
|
82 |
+
p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
|
83 |
+
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
|
84 |
+
return p
|
85 |
+
else:
|
86 |
+
return p[[0, 3, 2, 1]]
|
87 |
+
|
88 |
+
def __call__(self, outs_dict, shape_list):
|
89 |
+
score_list = outs_dict["f_score"]
|
90 |
+
geo_list = outs_dict["f_geo"]
|
91 |
+
if isinstance(score_list, paddle.Tensor):
|
92 |
+
score_list = score_list.numpy()
|
93 |
+
geo_list = geo_list.numpy()
|
94 |
+
img_num = len(shape_list)
|
95 |
+
dt_boxes_list = []
|
96 |
+
for ino in range(img_num):
|
97 |
+
score = score_list[ino]
|
98 |
+
geo = geo_list[ino]
|
99 |
+
boxes = self.detect(
|
100 |
+
score_map=score,
|
101 |
+
geo_map=geo,
|
102 |
+
score_thresh=self.score_thresh,
|
103 |
+
cover_thresh=self.cover_thresh,
|
104 |
+
nms_thresh=self.nms_thresh,
|
105 |
+
)
|
106 |
+
boxes_norm = []
|
107 |
+
if len(boxes) > 0:
|
108 |
+
h, w = score.shape[1:]
|
109 |
+
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
110 |
+
boxes = boxes[:, :8].reshape((-1, 4, 2))
|
111 |
+
boxes[:, :, 0] /= ratio_w
|
112 |
+
boxes[:, :, 1] /= ratio_h
|
113 |
+
for i_box, box in enumerate(boxes):
|
114 |
+
box = self.sort_poly(box.astype(np.int32))
|
115 |
+
if (
|
116 |
+
np.linalg.norm(box[0] - box[1]) < 5
|
117 |
+
or np.linalg.norm(box[3] - box[0]) < 5
|
118 |
+
):
|
119 |
+
continue
|
120 |
+
boxes_norm.append(box)
|
121 |
+
dt_boxes_list.append({"points": np.array(boxes_norm)})
|
122 |
+
return dt_boxes_list
|
ocr/postprocess/extract_textpoint_fast.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
from itertools import groupby
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from skimage.morphology._skeletonize import thin
|
8 |
+
|
9 |
+
|
10 |
+
def get_dict(character_dict_path):
|
11 |
+
character_str = ""
|
12 |
+
with open(character_dict_path, "rb") as fin:
|
13 |
+
lines = fin.readlines()
|
14 |
+
for line in lines:
|
15 |
+
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
16 |
+
character_str += line
|
17 |
+
dict_character = list(character_str)
|
18 |
+
return dict_character
|
19 |
+
|
20 |
+
|
21 |
+
def softmax(logits):
|
22 |
+
"""
|
23 |
+
logits: N x d
|
24 |
+
"""
|
25 |
+
max_value = np.max(logits, axis=1, keepdims=True)
|
26 |
+
exp = np.exp(logits - max_value)
|
27 |
+
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
28 |
+
dist = exp / exp_sum
|
29 |
+
return dist
|
30 |
+
|
31 |
+
|
32 |
+
def get_keep_pos_idxs(labels, remove_blank=None):
|
33 |
+
"""
|
34 |
+
Remove duplicate and get pos idxs of keep items.
|
35 |
+
The value of keep_blank should be [None, 95].
|
36 |
+
"""
|
37 |
+
duplicate_len_list = []
|
38 |
+
keep_pos_idx_list = []
|
39 |
+
keep_char_idx_list = []
|
40 |
+
for k, v_ in groupby(labels):
|
41 |
+
current_len = len(list(v_))
|
42 |
+
if k != remove_blank:
|
43 |
+
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
44 |
+
keep_pos_idx_list.append(current_idx)
|
45 |
+
keep_char_idx_list.append(k)
|
46 |
+
duplicate_len_list.append(current_len)
|
47 |
+
return keep_char_idx_list, keep_pos_idx_list
|
48 |
+
|
49 |
+
|
50 |
+
def remove_blank(labels, blank=0):
|
51 |
+
new_labels = [x for x in labels if x != blank]
|
52 |
+
return new_labels
|
53 |
+
|
54 |
+
|
55 |
+
def insert_blank(labels, blank=0):
|
56 |
+
new_labels = [blank]
|
57 |
+
for l in labels:
|
58 |
+
new_labels += [l, blank]
|
59 |
+
return new_labels
|
60 |
+
|
61 |
+
|
62 |
+
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
63 |
+
"""
|
64 |
+
CTC greedy (best path) decoder.
|
65 |
+
"""
|
66 |
+
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
67 |
+
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
68 |
+
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
69 |
+
raw_str, remove_blank=remove_blank_in_pos
|
70 |
+
)
|
71 |
+
dst_str = remove_blank(dedup_str, blank=blank)
|
72 |
+
return dst_str, keep_idx_list
|
73 |
+
|
74 |
+
|
75 |
+
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
|
76 |
+
_, _, C = logits_map.shape
|
77 |
+
ys, xs = zip(*gather_info)
|
78 |
+
logits_seq = logits_map[list(ys), list(xs)]
|
79 |
+
probs_seq = logits_seq
|
80 |
+
labels = np.argmax(probs_seq, axis=1)
|
81 |
+
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
|
82 |
+
detal = len(gather_info) // (pts_num - 1)
|
83 |
+
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
|
84 |
+
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
85 |
+
return dst_str, keep_gather_list
|
86 |
+
|
87 |
+
|
88 |
+
def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6):
|
89 |
+
"""
|
90 |
+
CTC decoder using multiple processes.
|
91 |
+
"""
|
92 |
+
decoder_str = []
|
93 |
+
decoder_xys = []
|
94 |
+
for gather_info in gather_info_list:
|
95 |
+
if len(gather_info) < pts_num:
|
96 |
+
continue
|
97 |
+
dst_str, xys_list = instance_ctc_greedy_decoder(
|
98 |
+
gather_info, logits_map, pts_num=pts_num
|
99 |
+
)
|
100 |
+
dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
|
101 |
+
if len(dst_str_readable) < 2:
|
102 |
+
continue
|
103 |
+
decoder_str.append(dst_str_readable)
|
104 |
+
decoder_xys.append(xys_list)
|
105 |
+
return decoder_str, decoder_xys
|
106 |
+
|
107 |
+
|
108 |
+
def sort_with_direction(pos_list, f_direction):
|
109 |
+
"""
|
110 |
+
f_direction: h x w x 2
|
111 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
112 |
+
"""
|
113 |
+
|
114 |
+
def sort_part_with_direction(pos_list, point_direction):
|
115 |
+
pos_list = np.array(pos_list).reshape(-1, 2)
|
116 |
+
point_direction = np.array(point_direction).reshape(-1, 2)
|
117 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
118 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
119 |
+
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
|
120 |
+
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
121 |
+
return sorted_list, sorted_direction
|
122 |
+
|
123 |
+
pos_list = np.array(pos_list).reshape(-1, 2)
|
124 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
125 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
126 |
+
sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
|
127 |
+
|
128 |
+
point_num = len(sorted_point)
|
129 |
+
if point_num >= 16:
|
130 |
+
middle_num = point_num // 2
|
131 |
+
first_part_point = sorted_point[:middle_num]
|
132 |
+
first_point_direction = sorted_direction[:middle_num]
|
133 |
+
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
134 |
+
first_part_point, first_point_direction
|
135 |
+
)
|
136 |
+
|
137 |
+
last_part_point = sorted_point[middle_num:]
|
138 |
+
last_point_direction = sorted_direction[middle_num:]
|
139 |
+
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
140 |
+
last_part_point, last_point_direction
|
141 |
+
)
|
142 |
+
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
143 |
+
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
144 |
+
|
145 |
+
return sorted_point, np.array(sorted_direction)
|
146 |
+
|
147 |
+
|
148 |
+
def add_id(pos_list, image_id=0):
|
149 |
+
"""
|
150 |
+
Add id for gather feature, for inference.
|
151 |
+
"""
|
152 |
+
new_list = []
|
153 |
+
for item in pos_list:
|
154 |
+
new_list.append((image_id, item[0], item[1]))
|
155 |
+
return new_list
|
156 |
+
|
157 |
+
|
158 |
+
def sort_and_expand_with_direction(pos_list, f_direction):
|
159 |
+
"""
|
160 |
+
f_direction: h x w x 2
|
161 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
162 |
+
"""
|
163 |
+
h, w, _ = f_direction.shape
|
164 |
+
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
165 |
+
|
166 |
+
point_num = len(sorted_list)
|
167 |
+
sub_direction_len = max(point_num // 3, 2)
|
168 |
+
left_direction = point_direction[:sub_direction_len, :]
|
169 |
+
right_dirction = point_direction[point_num - sub_direction_len :, :]
|
170 |
+
|
171 |
+
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
172 |
+
left_average_len = np.linalg.norm(left_average_direction)
|
173 |
+
left_start = np.array(sorted_list[0])
|
174 |
+
left_step = left_average_direction / (left_average_len + 1e-6)
|
175 |
+
|
176 |
+
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
177 |
+
right_average_len = np.linalg.norm(right_average_direction)
|
178 |
+
right_step = right_average_direction / (right_average_len + 1e-6)
|
179 |
+
right_start = np.array(sorted_list[-1])
|
180 |
+
|
181 |
+
append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
182 |
+
left_list = []
|
183 |
+
right_list = []
|
184 |
+
for i in range(append_num):
|
185 |
+
ly, lx = (
|
186 |
+
np.round(left_start + left_step * (i + 1))
|
187 |
+
.flatten()
|
188 |
+
.astype("int32")
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
if ly < h and lx < w and (ly, lx) not in left_list:
|
192 |
+
left_list.append((ly, lx))
|
193 |
+
ry, rx = (
|
194 |
+
np.round(right_start + right_step * (i + 1))
|
195 |
+
.flatten()
|
196 |
+
.astype("int32")
|
197 |
+
.tolist()
|
198 |
+
)
|
199 |
+
if ry < h and rx < w and (ry, rx) not in right_list:
|
200 |
+
right_list.append((ry, rx))
|
201 |
+
|
202 |
+
all_list = left_list[::-1] + sorted_list + right_list
|
203 |
+
return all_list
|
204 |
+
|
205 |
+
|
206 |
+
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
207 |
+
"""
|
208 |
+
f_direction: h x w x 2
|
209 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
210 |
+
binary_tcl_map: h x w
|
211 |
+
"""
|
212 |
+
h, w, _ = f_direction.shape
|
213 |
+
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
214 |
+
|
215 |
+
point_num = len(sorted_list)
|
216 |
+
sub_direction_len = max(point_num // 3, 2)
|
217 |
+
left_direction = point_direction[:sub_direction_len, :]
|
218 |
+
right_dirction = point_direction[point_num - sub_direction_len :, :]
|
219 |
+
|
220 |
+
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
221 |
+
left_average_len = np.linalg.norm(left_average_direction)
|
222 |
+
left_start = np.array(sorted_list[0])
|
223 |
+
left_step = left_average_direction / (left_average_len + 1e-6)
|
224 |
+
|
225 |
+
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
226 |
+
right_average_len = np.linalg.norm(right_average_direction)
|
227 |
+
right_step = right_average_direction / (right_average_len + 1e-6)
|
228 |
+
right_start = np.array(sorted_list[-1])
|
229 |
+
|
230 |
+
append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
231 |
+
max_append_num = 2 * append_num
|
232 |
+
|
233 |
+
left_list = []
|
234 |
+
right_list = []
|
235 |
+
for i in range(max_append_num):
|
236 |
+
ly, lx = (
|
237 |
+
np.round(left_start + left_step * (i + 1))
|
238 |
+
.flatten()
|
239 |
+
.astype("int32")
|
240 |
+
.tolist()
|
241 |
+
)
|
242 |
+
if ly < h and lx < w and (ly, lx) not in left_list:
|
243 |
+
if binary_tcl_map[ly, lx] > 0.5:
|
244 |
+
left_list.append((ly, lx))
|
245 |
+
else:
|
246 |
+
break
|
247 |
+
|
248 |
+
for i in range(max_append_num):
|
249 |
+
ry, rx = (
|
250 |
+
np.round(right_start + right_step * (i + 1))
|
251 |
+
.flatten()
|
252 |
+
.astype("int32")
|
253 |
+
.tolist()
|
254 |
+
)
|
255 |
+
if ry < h and rx < w and (ry, rx) not in right_list:
|
256 |
+
if binary_tcl_map[ry, rx] > 0.5:
|
257 |
+
right_list.append((ry, rx))
|
258 |
+
else:
|
259 |
+
break
|
260 |
+
|
261 |
+
all_list = left_list[::-1] + sorted_list + right_list
|
262 |
+
return all_list
|
263 |
+
|
264 |
+
|
265 |
+
def point_pair2poly(point_pair_list):
|
266 |
+
"""
|
267 |
+
Transfer vertical point_pairs into poly point in clockwise.
|
268 |
+
"""
|
269 |
+
point_num = len(point_pair_list) * 2
|
270 |
+
point_list = [0] * point_num
|
271 |
+
for idx, point_pair in enumerate(point_pair_list):
|
272 |
+
point_list[idx] = point_pair[0]
|
273 |
+
point_list[point_num - 1 - idx] = point_pair[1]
|
274 |
+
return np.array(point_list).reshape(-1, 2)
|
275 |
+
|
276 |
+
|
277 |
+
def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
|
278 |
+
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
279 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
280 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
281 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
282 |
+
|
283 |
+
|
284 |
+
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
285 |
+
"""
|
286 |
+
expand poly along width.
|
287 |
+
"""
|
288 |
+
point_num = poly.shape[0]
|
289 |
+
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
290 |
+
left_ratio = (
|
291 |
+
-shrink_ratio_of_width
|
292 |
+
* np.linalg.norm(left_quad[0] - left_quad[3])
|
293 |
+
/ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
294 |
+
)
|
295 |
+
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
296 |
+
right_quad = np.array(
|
297 |
+
[
|
298 |
+
poly[point_num // 2 - 2],
|
299 |
+
poly[point_num // 2 - 1],
|
300 |
+
poly[point_num // 2],
|
301 |
+
poly[point_num // 2 + 1],
|
302 |
+
],
|
303 |
+
dtype=np.float32,
|
304 |
+
)
|
305 |
+
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
|
306 |
+
right_quad[0] - right_quad[3]
|
307 |
+
) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
308 |
+
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
309 |
+
poly[0] = left_quad_expand[0]
|
310 |
+
poly[-1] = left_quad_expand[-1]
|
311 |
+
poly[point_num // 2 - 1] = right_quad_expand[1]
|
312 |
+
poly[point_num // 2] = right_quad_expand[2]
|
313 |
+
return poly
|
314 |
+
|
315 |
+
|
316 |
+
def restore_poly(
|
317 |
+
instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set
|
318 |
+
):
|
319 |
+
poly_list = []
|
320 |
+
keep_str_list = []
|
321 |
+
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
322 |
+
if len(keep_str) < 2:
|
323 |
+
print("--> too short, {}".format(keep_str))
|
324 |
+
continue
|
325 |
+
|
326 |
+
offset_expand = 1.0
|
327 |
+
if valid_set == "totaltext":
|
328 |
+
offset_expand = 1.2
|
329 |
+
|
330 |
+
point_pair_list = []
|
331 |
+
for y, x in yx_center_line:
|
332 |
+
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
|
333 |
+
ori_yx = np.array([y, x], dtype=np.float32)
|
334 |
+
point_pair = (
|
335 |
+
(ori_yx + offset)[:, ::-1]
|
336 |
+
* 4.0
|
337 |
+
/ np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
338 |
+
)
|
339 |
+
point_pair_list.append(point_pair)
|
340 |
+
|
341 |
+
detected_poly = point_pair2poly(point_pair_list)
|
342 |
+
detected_poly = expand_poly_along_width(
|
343 |
+
detected_poly, shrink_ratio_of_width=0.2
|
344 |
+
)
|
345 |
+
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
346 |
+
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
347 |
+
|
348 |
+
keep_str_list.append(keep_str)
|
349 |
+
if valid_set == "partvgg":
|
350 |
+
middle_point = len(detected_poly) // 2
|
351 |
+
detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :]
|
352 |
+
poly_list.append(detected_poly)
|
353 |
+
elif valid_set == "totaltext":
|
354 |
+
poly_list.append(detected_poly)
|
355 |
+
else:
|
356 |
+
print("--> Not supported format.")
|
357 |
+
exit(-1)
|
358 |
+
return poly_list, keep_str_list
|
359 |
+
|
360 |
+
|
361 |
+
def generate_pivot_list_fast(
|
362 |
+
p_score, p_char_maps, f_direction, Lexicon_Table, score_thresh=0.5
|
363 |
+
):
|
364 |
+
"""
|
365 |
+
return center point and end point of TCL instance; filter with the char maps;
|
366 |
+
"""
|
367 |
+
p_score = p_score[0]
|
368 |
+
f_direction = f_direction.transpose(1, 2, 0)
|
369 |
+
p_tcl_map = (p_score > score_thresh) * 1.0
|
370 |
+
skeleton_map = thin(p_tcl_map.astype(np.uint8))
|
371 |
+
instance_count, instance_label_map = cv2.connectedComponents(
|
372 |
+
skeleton_map.astype(np.uint8), connectivity=8
|
373 |
+
)
|
374 |
+
|
375 |
+
# get TCL Instance
|
376 |
+
all_pos_yxs = []
|
377 |
+
if instance_count > 0:
|
378 |
+
for instance_id in range(1, instance_count):
|
379 |
+
pos_list = []
|
380 |
+
ys, xs = np.where(instance_label_map == instance_id)
|
381 |
+
pos_list = list(zip(ys, xs))
|
382 |
+
|
383 |
+
if len(pos_list) < 3:
|
384 |
+
continue
|
385 |
+
|
386 |
+
pos_list_sorted = sort_and_expand_with_direction_v2(
|
387 |
+
pos_list, f_direction, p_tcl_map
|
388 |
+
)
|
389 |
+
all_pos_yxs.append(pos_list_sorted)
|
390 |
+
|
391 |
+
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
392 |
+
decoded_str, keep_yxs_list = ctc_decoder_for_image(
|
393 |
+
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table
|
394 |
+
)
|
395 |
+
return keep_yxs_list, decoded_str
|
396 |
+
|
397 |
+
|
398 |
+
def extract_main_direction(pos_list, f_direction):
|
399 |
+
"""
|
400 |
+
f_direction: h x w x 2
|
401 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
402 |
+
"""
|
403 |
+
pos_list = np.array(pos_list)
|
404 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
405 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
406 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
407 |
+
average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
|
408 |
+
return average_direction
|
409 |
+
|
410 |
+
|
411 |
+
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
412 |
+
"""
|
413 |
+
f_direction: h x w x 2
|
414 |
+
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
415 |
+
"""
|
416 |
+
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
417 |
+
pos_list = pos_list_full[:, 1:]
|
418 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
419 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
420 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
421 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
422 |
+
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
423 |
+
return sorted_list
|
424 |
+
|
425 |
+
|
426 |
+
def sort_by_direction_with_image_id(pos_list, f_direction):
|
427 |
+
"""
|
428 |
+
f_direction: h x w x 2
|
429 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
430 |
+
"""
|
431 |
+
|
432 |
+
def sort_part_with_direction(pos_list_full, point_direction):
|
433 |
+
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
434 |
+
pos_list = pos_list_full[:, 1:]
|
435 |
+
point_direction = np.array(point_direction).reshape(-1, 2)
|
436 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
437 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
438 |
+
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
439 |
+
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
440 |
+
return sorted_list, sorted_direction
|
441 |
+
|
442 |
+
pos_list = np.array(pos_list).reshape(-1, 3)
|
443 |
+
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
444 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
445 |
+
sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
|
446 |
+
|
447 |
+
point_num = len(sorted_point)
|
448 |
+
if point_num >= 16:
|
449 |
+
middle_num = point_num // 2
|
450 |
+
first_part_point = sorted_point[:middle_num]
|
451 |
+
first_point_direction = sorted_direction[:middle_num]
|
452 |
+
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
453 |
+
first_part_point, first_point_direction
|
454 |
+
)
|
455 |
+
|
456 |
+
last_part_point = sorted_point[middle_num:]
|
457 |
+
last_point_direction = sorted_direction[middle_num:]
|
458 |
+
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
459 |
+
last_part_point, last_point_direction
|
460 |
+
)
|
461 |
+
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
462 |
+
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
463 |
+
|
464 |
+
return sorted_point
|
ocr/postprocess/extract_textpoint_slow.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import math
|
4 |
+
from itertools import groupby
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from skimage.morphology._skeletonize import thin
|
9 |
+
|
10 |
+
|
11 |
+
def get_dict(character_dict_path):
|
12 |
+
character_str = ""
|
13 |
+
with open(character_dict_path, "rb") as fin:
|
14 |
+
lines = fin.readlines()
|
15 |
+
for line in lines:
|
16 |
+
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
17 |
+
character_str += line
|
18 |
+
dict_character = list(character_str)
|
19 |
+
return dict_character
|
20 |
+
|
21 |
+
|
22 |
+
def point_pair2poly(point_pair_list):
|
23 |
+
"""
|
24 |
+
Transfer vertical point_pairs into poly point in clockwise.
|
25 |
+
"""
|
26 |
+
pair_length_list = []
|
27 |
+
for point_pair in point_pair_list:
|
28 |
+
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
|
29 |
+
pair_length_list.append(pair_length)
|
30 |
+
pair_length_list = np.array(pair_length_list)
|
31 |
+
pair_info = (
|
32 |
+
pair_length_list.max(),
|
33 |
+
pair_length_list.min(),
|
34 |
+
pair_length_list.mean(),
|
35 |
+
)
|
36 |
+
|
37 |
+
point_num = len(point_pair_list) * 2
|
38 |
+
point_list = [0] * point_num
|
39 |
+
for idx, point_pair in enumerate(point_pair_list):
|
40 |
+
point_list[idx] = point_pair[0]
|
41 |
+
point_list[point_num - 1 - idx] = point_pair[1]
|
42 |
+
return np.array(point_list).reshape(-1, 2), pair_info
|
43 |
+
|
44 |
+
|
45 |
+
def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
|
46 |
+
"""
|
47 |
+
Generate shrink_quad_along_width.
|
48 |
+
"""
|
49 |
+
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
50 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
51 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
52 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
53 |
+
|
54 |
+
|
55 |
+
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
56 |
+
"""
|
57 |
+
expand poly along width.
|
58 |
+
"""
|
59 |
+
point_num = poly.shape[0]
|
60 |
+
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
61 |
+
left_ratio = (
|
62 |
+
-shrink_ratio_of_width
|
63 |
+
* np.linalg.norm(left_quad[0] - left_quad[3])
|
64 |
+
/ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
65 |
+
)
|
66 |
+
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
67 |
+
right_quad = np.array(
|
68 |
+
[
|
69 |
+
poly[point_num // 2 - 2],
|
70 |
+
poly[point_num // 2 - 1],
|
71 |
+
poly[point_num // 2],
|
72 |
+
poly[point_num // 2 + 1],
|
73 |
+
],
|
74 |
+
dtype=np.float32,
|
75 |
+
)
|
76 |
+
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
|
77 |
+
right_quad[0] - right_quad[3]
|
78 |
+
) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
79 |
+
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
80 |
+
poly[0] = left_quad_expand[0]
|
81 |
+
poly[-1] = left_quad_expand[-1]
|
82 |
+
poly[point_num // 2 - 1] = right_quad_expand[1]
|
83 |
+
poly[point_num // 2] = right_quad_expand[2]
|
84 |
+
return poly
|
85 |
+
|
86 |
+
|
87 |
+
def softmax(logits):
|
88 |
+
"""
|
89 |
+
logits: N x d
|
90 |
+
"""
|
91 |
+
max_value = np.max(logits, axis=1, keepdims=True)
|
92 |
+
exp = np.exp(logits - max_value)
|
93 |
+
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
94 |
+
dist = exp / exp_sum
|
95 |
+
return dist
|
96 |
+
|
97 |
+
|
98 |
+
def get_keep_pos_idxs(labels, remove_blank=None):
|
99 |
+
"""
|
100 |
+
Remove duplicate and get pos idxs of keep items.
|
101 |
+
The value of keep_blank should be [None, 95].
|
102 |
+
"""
|
103 |
+
duplicate_len_list = []
|
104 |
+
keep_pos_idx_list = []
|
105 |
+
keep_char_idx_list = []
|
106 |
+
for k, v_ in groupby(labels):
|
107 |
+
current_len = len(list(v_))
|
108 |
+
if k != remove_blank:
|
109 |
+
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
110 |
+
keep_pos_idx_list.append(current_idx)
|
111 |
+
keep_char_idx_list.append(k)
|
112 |
+
duplicate_len_list.append(current_len)
|
113 |
+
return keep_char_idx_list, keep_pos_idx_list
|
114 |
+
|
115 |
+
|
116 |
+
def remove_blank(labels, blank=0):
|
117 |
+
new_labels = [x for x in labels if x != blank]
|
118 |
+
return new_labels
|
119 |
+
|
120 |
+
|
121 |
+
def insert_blank(labels, blank=0):
|
122 |
+
new_labels = [blank]
|
123 |
+
for l in labels:
|
124 |
+
new_labels += [l, blank]
|
125 |
+
return new_labels
|
126 |
+
|
127 |
+
|
128 |
+
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
129 |
+
"""
|
130 |
+
CTC greedy (best path) decoder.
|
131 |
+
"""
|
132 |
+
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
133 |
+
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
134 |
+
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
135 |
+
raw_str, remove_blank=remove_blank_in_pos
|
136 |
+
)
|
137 |
+
dst_str = remove_blank(dedup_str, blank=blank)
|
138 |
+
return dst_str, keep_idx_list
|
139 |
+
|
140 |
+
|
141 |
+
def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True):
|
142 |
+
"""
|
143 |
+
gather_info: [[x, y], [x, y] ...]
|
144 |
+
logits_map: H x W X (n_chars + 1)
|
145 |
+
"""
|
146 |
+
_, _, C = logits_map.shape
|
147 |
+
ys, xs = zip(*gather_info)
|
148 |
+
logits_seq = logits_map[list(ys), list(xs)] # n x 96
|
149 |
+
probs_seq = softmax(logits_seq)
|
150 |
+
dst_str, keep_idx_list = ctc_greedy_decoder(
|
151 |
+
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs
|
152 |
+
)
|
153 |
+
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
154 |
+
return dst_str, keep_gather_list
|
155 |
+
|
156 |
+
|
157 |
+
def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True):
|
158 |
+
"""
|
159 |
+
CTC decoder using multiple processes.
|
160 |
+
"""
|
161 |
+
decoder_results = []
|
162 |
+
for gather_info in gather_info_list:
|
163 |
+
res = instance_ctc_greedy_decoder(
|
164 |
+
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs
|
165 |
+
)
|
166 |
+
decoder_results.append(res)
|
167 |
+
return decoder_results
|
168 |
+
|
169 |
+
|
170 |
+
def sort_with_direction(pos_list, f_direction):
|
171 |
+
"""
|
172 |
+
f_direction: h x w x 2
|
173 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
174 |
+
"""
|
175 |
+
|
176 |
+
def sort_part_with_direction(pos_list, point_direction):
|
177 |
+
pos_list = np.array(pos_list).reshape(-1, 2)
|
178 |
+
point_direction = np.array(point_direction).reshape(-1, 2)
|
179 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
180 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
181 |
+
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
|
182 |
+
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
183 |
+
return sorted_list, sorted_direction
|
184 |
+
|
185 |
+
pos_list = np.array(pos_list).reshape(-1, 2)
|
186 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
187 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
188 |
+
sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
|
189 |
+
|
190 |
+
point_num = len(sorted_point)
|
191 |
+
if point_num >= 16:
|
192 |
+
middle_num = point_num // 2
|
193 |
+
first_part_point = sorted_point[:middle_num]
|
194 |
+
first_point_direction = sorted_direction[:middle_num]
|
195 |
+
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
196 |
+
first_part_point, first_point_direction
|
197 |
+
)
|
198 |
+
|
199 |
+
last_part_point = sorted_point[middle_num:]
|
200 |
+
last_point_direction = sorted_direction[middle_num:]
|
201 |
+
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
202 |
+
last_part_point, last_point_direction
|
203 |
+
)
|
204 |
+
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
205 |
+
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
206 |
+
|
207 |
+
return sorted_point, np.array(sorted_direction)
|
208 |
+
|
209 |
+
|
210 |
+
def add_id(pos_list, image_id=0):
|
211 |
+
"""
|
212 |
+
Add id for gather feature, for inference.
|
213 |
+
"""
|
214 |
+
new_list = []
|
215 |
+
for item in pos_list:
|
216 |
+
new_list.append((image_id, item[0], item[1]))
|
217 |
+
return new_list
|
218 |
+
|
219 |
+
|
220 |
+
def sort_and_expand_with_direction(pos_list, f_direction):
|
221 |
+
"""
|
222 |
+
f_direction: h x w x 2
|
223 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
224 |
+
"""
|
225 |
+
h, w, _ = f_direction.shape
|
226 |
+
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
227 |
+
|
228 |
+
# expand along
|
229 |
+
point_num = len(sorted_list)
|
230 |
+
sub_direction_len = max(point_num // 3, 2)
|
231 |
+
left_direction = point_direction[:sub_direction_len, :]
|
232 |
+
right_dirction = point_direction[point_num - sub_direction_len :, :]
|
233 |
+
|
234 |
+
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
235 |
+
left_average_len = np.linalg.norm(left_average_direction)
|
236 |
+
left_start = np.array(sorted_list[0])
|
237 |
+
left_step = left_average_direction / (left_average_len + 1e-6)
|
238 |
+
|
239 |
+
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
240 |
+
right_average_len = np.linalg.norm(right_average_direction)
|
241 |
+
right_step = right_average_direction / (right_average_len + 1e-6)
|
242 |
+
right_start = np.array(sorted_list[-1])
|
243 |
+
|
244 |
+
append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
245 |
+
left_list = []
|
246 |
+
right_list = []
|
247 |
+
for i in range(append_num):
|
248 |
+
ly, lx = (
|
249 |
+
np.round(left_start + left_step * (i + 1))
|
250 |
+
.flatten()
|
251 |
+
.astype("int32")
|
252 |
+
.tolist()
|
253 |
+
)
|
254 |
+
if ly < h and lx < w and (ly, lx) not in left_list:
|
255 |
+
left_list.append((ly, lx))
|
256 |
+
ry, rx = (
|
257 |
+
np.round(right_start + right_step * (i + 1))
|
258 |
+
.flatten()
|
259 |
+
.astype("int32")
|
260 |
+
.tolist()
|
261 |
+
)
|
262 |
+
if ry < h and rx < w and (ry, rx) not in right_list:
|
263 |
+
right_list.append((ry, rx))
|
264 |
+
|
265 |
+
all_list = left_list[::-1] + sorted_list + right_list
|
266 |
+
return all_list
|
267 |
+
|
268 |
+
|
269 |
+
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
270 |
+
"""
|
271 |
+
f_direction: h x w x 2
|
272 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
273 |
+
binary_tcl_map: h x w
|
274 |
+
"""
|
275 |
+
h, w, _ = f_direction.shape
|
276 |
+
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
277 |
+
|
278 |
+
# expand along
|
279 |
+
point_num = len(sorted_list)
|
280 |
+
sub_direction_len = max(point_num // 3, 2)
|
281 |
+
left_direction = point_direction[:sub_direction_len, :]
|
282 |
+
right_dirction = point_direction[point_num - sub_direction_len :, :]
|
283 |
+
|
284 |
+
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
285 |
+
left_average_len = np.linalg.norm(left_average_direction)
|
286 |
+
left_start = np.array(sorted_list[0])
|
287 |
+
left_step = left_average_direction / (left_average_len + 1e-6)
|
288 |
+
|
289 |
+
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
290 |
+
right_average_len = np.linalg.norm(right_average_direction)
|
291 |
+
right_step = right_average_direction / (right_average_len + 1e-6)
|
292 |
+
right_start = np.array(sorted_list[-1])
|
293 |
+
|
294 |
+
append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
295 |
+
max_append_num = 2 * append_num
|
296 |
+
|
297 |
+
left_list = []
|
298 |
+
right_list = []
|
299 |
+
for i in range(max_append_num):
|
300 |
+
ly, lx = (
|
301 |
+
np.round(left_start + left_step * (i + 1))
|
302 |
+
.flatten()
|
303 |
+
.astype("int32")
|
304 |
+
.tolist()
|
305 |
+
)
|
306 |
+
if ly < h and lx < w and (ly, lx) not in left_list:
|
307 |
+
if binary_tcl_map[ly, lx] > 0.5:
|
308 |
+
left_list.append((ly, lx))
|
309 |
+
else:
|
310 |
+
break
|
311 |
+
|
312 |
+
for i in range(max_append_num):
|
313 |
+
ry, rx = (
|
314 |
+
np.round(right_start + right_step * (i + 1))
|
315 |
+
.flatten()
|
316 |
+
.astype("int32")
|
317 |
+
.tolist()
|
318 |
+
)
|
319 |
+
if ry < h and rx < w and (ry, rx) not in right_list:
|
320 |
+
if binary_tcl_map[ry, rx] > 0.5:
|
321 |
+
right_list.append((ry, rx))
|
322 |
+
else:
|
323 |
+
break
|
324 |
+
|
325 |
+
all_list = left_list[::-1] + sorted_list + right_list
|
326 |
+
return all_list
|
327 |
+
|
328 |
+
|
329 |
+
def generate_pivot_list_curved(
|
330 |
+
p_score,
|
331 |
+
p_char_maps,
|
332 |
+
f_direction,
|
333 |
+
score_thresh=0.5,
|
334 |
+
is_expand=True,
|
335 |
+
is_backbone=False,
|
336 |
+
image_id=0,
|
337 |
+
):
|
338 |
+
"""
|
339 |
+
return center point and end point of TCL instance; filter with the char maps;
|
340 |
+
"""
|
341 |
+
p_score = p_score[0]
|
342 |
+
f_direction = f_direction.transpose(1, 2, 0)
|
343 |
+
p_tcl_map = (p_score > score_thresh) * 1.0
|
344 |
+
skeleton_map = thin(p_tcl_map)
|
345 |
+
instance_count, instance_label_map = cv2.connectedComponents(
|
346 |
+
skeleton_map.astype(np.uint8), connectivity=8
|
347 |
+
)
|
348 |
+
|
349 |
+
# get TCL Instance
|
350 |
+
all_pos_yxs = []
|
351 |
+
center_pos_yxs = []
|
352 |
+
end_points_yxs = []
|
353 |
+
instance_center_pos_yxs = []
|
354 |
+
pred_strs = []
|
355 |
+
if instance_count > 0:
|
356 |
+
for instance_id in range(1, instance_count):
|
357 |
+
pos_list = []
|
358 |
+
ys, xs = np.where(instance_label_map == instance_id)
|
359 |
+
pos_list = list(zip(ys, xs))
|
360 |
+
|
361 |
+
### FIX-ME, eliminate outlier
|
362 |
+
if len(pos_list) < 3:
|
363 |
+
continue
|
364 |
+
|
365 |
+
if is_expand:
|
366 |
+
pos_list_sorted = sort_and_expand_with_direction_v2(
|
367 |
+
pos_list, f_direction, p_tcl_map
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
|
371 |
+
all_pos_yxs.append(pos_list_sorted)
|
372 |
+
|
373 |
+
# use decoder to filter backgroud points.
|
374 |
+
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
375 |
+
decode_res = ctc_decoder_for_image(
|
376 |
+
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
|
377 |
+
)
|
378 |
+
for decoded_str, keep_yxs_list in decode_res:
|
379 |
+
if is_backbone:
|
380 |
+
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
381 |
+
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
382 |
+
pred_strs.append(decoded_str)
|
383 |
+
else:
|
384 |
+
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
385 |
+
center_pos_yxs.extend(keep_yxs_list)
|
386 |
+
|
387 |
+
if is_backbone:
|
388 |
+
return pred_strs, instance_center_pos_yxs
|
389 |
+
else:
|
390 |
+
return center_pos_yxs, end_points_yxs
|
391 |
+
|
392 |
+
|
393 |
+
def generate_pivot_list_horizontal(
|
394 |
+
p_score, p_char_maps, f_direction, score_thresh=0.5, is_backbone=False, image_id=0
|
395 |
+
):
|
396 |
+
"""
|
397 |
+
return center point and end point of TCL instance; filter with the char maps;
|
398 |
+
"""
|
399 |
+
p_score = p_score[0]
|
400 |
+
f_direction = f_direction.transpose(1, 2, 0)
|
401 |
+
p_tcl_map_bi = (p_score > score_thresh) * 1.0
|
402 |
+
instance_count, instance_label_map = cv2.connectedComponents(
|
403 |
+
p_tcl_map_bi.astype(np.uint8), connectivity=8
|
404 |
+
)
|
405 |
+
|
406 |
+
# get TCL Instance
|
407 |
+
all_pos_yxs = []
|
408 |
+
center_pos_yxs = []
|
409 |
+
end_points_yxs = []
|
410 |
+
instance_center_pos_yxs = []
|
411 |
+
|
412 |
+
if instance_count > 0:
|
413 |
+
for instance_id in range(1, instance_count):
|
414 |
+
pos_list = []
|
415 |
+
ys, xs = np.where(instance_label_map == instance_id)
|
416 |
+
pos_list = list(zip(ys, xs))
|
417 |
+
|
418 |
+
### FIX-ME, eliminate outlier
|
419 |
+
if len(pos_list) < 5:
|
420 |
+
continue
|
421 |
+
|
422 |
+
# add rule here
|
423 |
+
main_direction = extract_main_direction(pos_list, f_direction) # y x
|
424 |
+
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
|
425 |
+
is_h_angle = abs(np.sum(main_direction * reference_directin)) < math.cos(
|
426 |
+
math.pi / 180 * 70
|
427 |
+
)
|
428 |
+
|
429 |
+
point_yxs = np.array(pos_list)
|
430 |
+
max_y, max_x = np.max(point_yxs, axis=0)
|
431 |
+
min_y, min_x = np.min(point_yxs, axis=0)
|
432 |
+
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
|
433 |
+
|
434 |
+
pos_list_final = []
|
435 |
+
if is_h_len:
|
436 |
+
xs = np.unique(xs)
|
437 |
+
for x in xs:
|
438 |
+
ys = instance_label_map[:, x].copy().reshape((-1,))
|
439 |
+
y = int(np.where(ys == instance_id)[0].mean())
|
440 |
+
pos_list_final.append((y, x))
|
441 |
+
else:
|
442 |
+
ys = np.unique(ys)
|
443 |
+
for y in ys:
|
444 |
+
xs = instance_label_map[y, :].copy().reshape((-1,))
|
445 |
+
x = int(np.where(xs == instance_id)[0].mean())
|
446 |
+
pos_list_final.append((y, x))
|
447 |
+
|
448 |
+
pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
|
449 |
+
all_pos_yxs.append(pos_list_sorted)
|
450 |
+
|
451 |
+
# use decoder to filter backgroud points.
|
452 |
+
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
453 |
+
decode_res = ctc_decoder_for_image(
|
454 |
+
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
|
455 |
+
)
|
456 |
+
for decoded_str, keep_yxs_list in decode_res:
|
457 |
+
if is_backbone:
|
458 |
+
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
459 |
+
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
460 |
+
else:
|
461 |
+
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
462 |
+
center_pos_yxs.extend(keep_yxs_list)
|
463 |
+
|
464 |
+
if is_backbone:
|
465 |
+
return instance_center_pos_yxs
|
466 |
+
else:
|
467 |
+
return center_pos_yxs, end_points_yxs
|
468 |
+
|
469 |
+
|
470 |
+
def generate_pivot_list_slow(
|
471 |
+
p_score,
|
472 |
+
p_char_maps,
|
473 |
+
f_direction,
|
474 |
+
score_thresh=0.5,
|
475 |
+
is_backbone=False,
|
476 |
+
is_curved=True,
|
477 |
+
image_id=0,
|
478 |
+
):
|
479 |
+
"""
|
480 |
+
Warp all the function together.
|
481 |
+
"""
|
482 |
+
if is_curved:
|
483 |
+
return generate_pivot_list_curved(
|
484 |
+
p_score,
|
485 |
+
p_char_maps,
|
486 |
+
f_direction,
|
487 |
+
score_thresh=score_thresh,
|
488 |
+
is_expand=True,
|
489 |
+
is_backbone=is_backbone,
|
490 |
+
image_id=image_id,
|
491 |
+
)
|
492 |
+
else:
|
493 |
+
return generate_pivot_list_horizontal(
|
494 |
+
p_score,
|
495 |
+
p_char_maps,
|
496 |
+
f_direction,
|
497 |
+
score_thresh=score_thresh,
|
498 |
+
is_backbone=is_backbone,
|
499 |
+
image_id=image_id,
|
500 |
+
)
|
501 |
+
|
502 |
+
|
503 |
+
# for refine module
|
504 |
+
def extract_main_direction(pos_list, f_direction):
|
505 |
+
"""
|
506 |
+
f_direction: h x w x 2
|
507 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
508 |
+
"""
|
509 |
+
pos_list = np.array(pos_list)
|
510 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
511 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
512 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
513 |
+
average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
|
514 |
+
return average_direction
|
515 |
+
|
516 |
+
|
517 |
+
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
518 |
+
"""
|
519 |
+
f_direction: h x w x 2
|
520 |
+
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
521 |
+
"""
|
522 |
+
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
523 |
+
pos_list = pos_list_full[:, 1:]
|
524 |
+
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
525 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
526 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
527 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
528 |
+
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
529 |
+
return sorted_list
|
530 |
+
|
531 |
+
|
532 |
+
def sort_by_direction_with_image_id(pos_list, f_direction):
|
533 |
+
"""
|
534 |
+
f_direction: h x w x 2
|
535 |
+
pos_list: [[y, x], [y, x], [y, x] ...]
|
536 |
+
"""
|
537 |
+
|
538 |
+
def sort_part_with_direction(pos_list_full, point_direction):
|
539 |
+
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
540 |
+
pos_list = pos_list_full[:, 1:]
|
541 |
+
point_direction = np.array(point_direction).reshape(-1, 2)
|
542 |
+
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
543 |
+
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
544 |
+
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
545 |
+
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
546 |
+
return sorted_list, sorted_direction
|
547 |
+
|
548 |
+
pos_list = np.array(pos_list).reshape(-1, 3)
|
549 |
+
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
550 |
+
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
551 |
+
sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
|
552 |
+
|
553 |
+
point_num = len(sorted_point)
|
554 |
+
if point_num >= 16:
|
555 |
+
middle_num = point_num // 2
|
556 |
+
first_part_point = sorted_point[:middle_num]
|
557 |
+
first_point_direction = sorted_direction[:middle_num]
|
558 |
+
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
559 |
+
first_part_point, first_point_direction
|
560 |
+
)
|
561 |
+
|
562 |
+
last_part_point = sorted_point[middle_num:]
|
563 |
+
last_point_direction = sorted_direction[middle_num:]
|
564 |
+
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
565 |
+
last_part_point, last_point_direction
|
566 |
+
)
|
567 |
+
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
568 |
+
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
569 |
+
|
570 |
+
return sorted_point
|
571 |
+
|
572 |
+
|
573 |
+
def generate_pivot_list_tt_inference(
|
574 |
+
p_score,
|
575 |
+
p_char_maps,
|
576 |
+
f_direction,
|
577 |
+
score_thresh=0.5,
|
578 |
+
is_backbone=False,
|
579 |
+
is_curved=True,
|
580 |
+
image_id=0,
|
581 |
+
):
|
582 |
+
"""
|
583 |
+
return center point and end point of TCL instance; filter with the char maps;
|
584 |
+
"""
|
585 |
+
p_score = p_score[0]
|
586 |
+
f_direction = f_direction.transpose(1, 2, 0)
|
587 |
+
p_tcl_map = (p_score > score_thresh) * 1.0
|
588 |
+
skeleton_map = thin(p_tcl_map)
|
589 |
+
instance_count, instance_label_map = cv2.connectedComponents(
|
590 |
+
skeleton_map.astype(np.uint8), connectivity=8
|
591 |
+
)
|
592 |
+
|
593 |
+
# get TCL Instance
|
594 |
+
all_pos_yxs = []
|
595 |
+
if instance_count > 0:
|
596 |
+
for instance_id in range(1, instance_count):
|
597 |
+
pos_list = []
|
598 |
+
ys, xs = np.where(instance_label_map == instance_id)
|
599 |
+
pos_list = list(zip(ys, xs))
|
600 |
+
### FIX-ME, eliminate outlier
|
601 |
+
if len(pos_list) < 3:
|
602 |
+
continue
|
603 |
+
pos_list_sorted = sort_and_expand_with_direction_v2(
|
604 |
+
pos_list, f_direction, p_tcl_map
|
605 |
+
)
|
606 |
+
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
|
607 |
+
all_pos_yxs.append(pos_list_sorted_with_id)
|
608 |
+
return all_pos_yxs
|
ocr/postprocess/fce_postprocess.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import paddle
|
4 |
+
from numpy.fft import ifft
|
5 |
+
|
6 |
+
from .poly_nms import *
|
7 |
+
|
8 |
+
|
9 |
+
def fill_hole(input_mask):
|
10 |
+
h, w = input_mask.shape
|
11 |
+
canvas = np.zeros((h + 2, w + 2), np.uint8)
|
12 |
+
canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
|
13 |
+
|
14 |
+
mask = np.zeros((h + 4, w + 4), np.uint8)
|
15 |
+
|
16 |
+
cv2.floodFill(canvas, mask, (0, 0), 1)
|
17 |
+
canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool)
|
18 |
+
|
19 |
+
return ~canvas | input_mask
|
20 |
+
|
21 |
+
|
22 |
+
def fourier2poly(fourier_coeff, num_reconstr_points=50):
|
23 |
+
"""Inverse Fourier transform
|
24 |
+
Args:
|
25 |
+
fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
|
26 |
+
with n and k being candidates number and Fourier degree
|
27 |
+
respectively.
|
28 |
+
num_reconstr_points (int): Number of reconstructed polygon points.
|
29 |
+
Returns:
|
30 |
+
Polygons (ndarray): The reconstructed polygons shaped (n, n')
|
31 |
+
"""
|
32 |
+
|
33 |
+
a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex")
|
34 |
+
k = (len(fourier_coeff[0]) - 1) // 2
|
35 |
+
|
36 |
+
a[:, 0 : k + 1] = fourier_coeff[:, k:]
|
37 |
+
a[:, -k:] = fourier_coeff[:, :k]
|
38 |
+
|
39 |
+
poly_complex = ifft(a) * num_reconstr_points
|
40 |
+
polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
|
41 |
+
polygon[:, :, 0] = poly_complex.real
|
42 |
+
polygon[:, :, 1] = poly_complex.imag
|
43 |
+
return polygon.astype("int32").reshape((len(fourier_coeff), -1))
|
44 |
+
|
45 |
+
|
46 |
+
class FCEPostProcess(object):
|
47 |
+
"""
|
48 |
+
The post process for FCENet.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
scales,
|
54 |
+
fourier_degree=5,
|
55 |
+
num_reconstr_points=50,
|
56 |
+
decoding_type="fcenet",
|
57 |
+
score_thr=0.3,
|
58 |
+
nms_thr=0.1,
|
59 |
+
alpha=1.0,
|
60 |
+
beta=1.0,
|
61 |
+
box_type="poly",
|
62 |
+
**kwargs
|
63 |
+
):
|
64 |
+
|
65 |
+
self.scales = scales
|
66 |
+
self.fourier_degree = fourier_degree
|
67 |
+
self.num_reconstr_points = num_reconstr_points
|
68 |
+
self.decoding_type = decoding_type
|
69 |
+
self.score_thr = score_thr
|
70 |
+
self.nms_thr = nms_thr
|
71 |
+
self.alpha = alpha
|
72 |
+
self.beta = beta
|
73 |
+
self.box_type = box_type
|
74 |
+
|
75 |
+
def __call__(self, preds, shape_list):
|
76 |
+
score_maps = []
|
77 |
+
for key, value in preds.items():
|
78 |
+
if isinstance(value, paddle.Tensor):
|
79 |
+
value = value.numpy()
|
80 |
+
cls_res = value[:, :4, :, :]
|
81 |
+
reg_res = value[:, 4:, :, :]
|
82 |
+
score_maps.append([cls_res, reg_res])
|
83 |
+
|
84 |
+
return self.get_boundary(score_maps, shape_list)
|
85 |
+
|
86 |
+
def resize_boundary(self, boundaries, scale_factor):
|
87 |
+
"""Rescale boundaries via scale_factor.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
boundaries (list[list[float]]): The boundary list. Each boundary
|
91 |
+
with size 2k+1 with k>=4.
|
92 |
+
scale_factor(ndarray): The scale factor of size (4,).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
boundaries (list[list[float]]): The scaled boundaries.
|
96 |
+
"""
|
97 |
+
boxes = []
|
98 |
+
scores = []
|
99 |
+
for b in boundaries:
|
100 |
+
sz = len(b)
|
101 |
+
valid_boundary(b, True)
|
102 |
+
scores.append(b[-1])
|
103 |
+
b = (
|
104 |
+
(
|
105 |
+
np.array(b[: sz - 1])
|
106 |
+
* (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1))
|
107 |
+
)
|
108 |
+
.flatten()
|
109 |
+
.tolist()
|
110 |
+
)
|
111 |
+
boxes.append(np.array(b).reshape([-1, 2]))
|
112 |
+
|
113 |
+
return np.array(boxes, dtype=np.float32), scores
|
114 |
+
|
115 |
+
def get_boundary(self, score_maps, shape_list):
|
116 |
+
assert len(score_maps) == len(self.scales)
|
117 |
+
boundaries = []
|
118 |
+
for idx, score_map in enumerate(score_maps):
|
119 |
+
scale = self.scales[idx]
|
120 |
+
boundaries = boundaries + self._get_boundary_single(score_map, scale)
|
121 |
+
|
122 |
+
# nms
|
123 |
+
boundaries = poly_nms(boundaries, self.nms_thr)
|
124 |
+
boundaries, scores = self.resize_boundary(
|
125 |
+
boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
|
126 |
+
)
|
127 |
+
|
128 |
+
boxes_batch = [dict(points=boundaries, scores=scores)]
|
129 |
+
return boxes_batch
|
130 |
+
|
131 |
+
def _get_boundary_single(self, score_map, scale):
|
132 |
+
assert len(score_map) == 2
|
133 |
+
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
|
134 |
+
|
135 |
+
return self.fcenet_decode(
|
136 |
+
preds=score_map,
|
137 |
+
fourier_degree=self.fourier_degree,
|
138 |
+
num_reconstr_points=self.num_reconstr_points,
|
139 |
+
scale=scale,
|
140 |
+
alpha=self.alpha,
|
141 |
+
beta=self.beta,
|
142 |
+
box_type=self.box_type,
|
143 |
+
score_thr=self.score_thr,
|
144 |
+
nms_thr=self.nms_thr,
|
145 |
+
)
|
146 |
+
|
147 |
+
def fcenet_decode(
|
148 |
+
self,
|
149 |
+
preds,
|
150 |
+
fourier_degree,
|
151 |
+
num_reconstr_points,
|
152 |
+
scale,
|
153 |
+
alpha=1.0,
|
154 |
+
beta=2.0,
|
155 |
+
box_type="poly",
|
156 |
+
score_thr=0.3,
|
157 |
+
nms_thr=0.1,
|
158 |
+
):
|
159 |
+
"""Decoding predictions of FCENet to instances.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
preds (list(Tensor)): The head output tensors.
|
163 |
+
fourier_degree (int): The maximum Fourier transform degree k.
|
164 |
+
num_reconstr_points (int): The points number of the polygon
|
165 |
+
reconstructed from predicted Fourier coefficients.
|
166 |
+
scale (int): The down-sample scale of the prediction.
|
167 |
+
alpha (float) : The parameter to calculate final scores. Score_{final}
|
168 |
+
= (Score_{text region} ^ alpha)
|
169 |
+
* (Score_{text center region}^ beta)
|
170 |
+
beta (float) : The parameter to calculate final score.
|
171 |
+
box_type (str): Boundary encoding type 'poly' or 'quad'.
|
172 |
+
score_thr (float) : The threshold used to filter out the final
|
173 |
+
candidates.
|
174 |
+
nms_thr (float) : The threshold of nms.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
boundaries (list[list[float]]): The instance boundary and confidence
|
178 |
+
list.
|
179 |
+
"""
|
180 |
+
assert isinstance(preds, list)
|
181 |
+
assert len(preds) == 2
|
182 |
+
assert box_type in ["poly", "quad"]
|
183 |
+
|
184 |
+
cls_pred = preds[0][0]
|
185 |
+
tr_pred = cls_pred[0:2]
|
186 |
+
tcl_pred = cls_pred[2:]
|
187 |
+
|
188 |
+
reg_pred = preds[1][0].transpose([1, 2, 0])
|
189 |
+
x_pred = reg_pred[:, :, : 2 * fourier_degree + 1]
|
190 |
+
y_pred = reg_pred[:, :, 2 * fourier_degree + 1 :]
|
191 |
+
|
192 |
+
score_pred = (tr_pred[1] ** alpha) * (tcl_pred[1] ** beta)
|
193 |
+
tr_pred_mask = (score_pred) > score_thr
|
194 |
+
tr_mask = fill_hole(tr_pred_mask)
|
195 |
+
|
196 |
+
tr_contours, _ = cv2.findContours(
|
197 |
+
tr_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
198 |
+
) # opencv4
|
199 |
+
|
200 |
+
mask = np.zeros_like(tr_mask)
|
201 |
+
boundaries = []
|
202 |
+
for cont in tr_contours:
|
203 |
+
deal_map = mask.copy().astype(np.int8)
|
204 |
+
cv2.drawContours(deal_map, [cont], -1, 1, -1)
|
205 |
+
|
206 |
+
score_map = score_pred * deal_map
|
207 |
+
score_mask = score_map > 0
|
208 |
+
xy_text = np.argwhere(score_mask)
|
209 |
+
dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
|
210 |
+
|
211 |
+
x, y = x_pred[score_mask], y_pred[score_mask]
|
212 |
+
c = x + y * 1j
|
213 |
+
c[:, fourier_degree] = c[:, fourier_degree] + dxy
|
214 |
+
c *= scale
|
215 |
+
|
216 |
+
polygons = fourier2poly(c, num_reconstr_points)
|
217 |
+
score = score_map[score_mask].reshape(-1, 1)
|
218 |
+
polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
|
219 |
+
|
220 |
+
boundaries = boundaries + polygons
|
221 |
+
|
222 |
+
boundaries = poly_nms(boundaries, nms_thr)
|
223 |
+
|
224 |
+
if box_type == "quad":
|
225 |
+
new_boundaries = []
|
226 |
+
for boundary in boundaries:
|
227 |
+
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
|
228 |
+
score = boundary[-1]
|
229 |
+
points = cv2.boxPoints(cv2.minAreaRect(poly))
|
230 |
+
points = np.int0(points)
|
231 |
+
new_boundaries.append(points.reshape(-1).tolist() + [score])
|
232 |
+
boundaries = new_boundaries
|
233 |
+
|
234 |
+
return boundaries
|
ocr/postprocess/locality_aware_nms.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Locality aware nms.
|
3 |
+
This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from shapely.geometry import Polygon
|
8 |
+
|
9 |
+
|
10 |
+
def intersection(g, p):
|
11 |
+
"""
|
12 |
+
Intersection.
|
13 |
+
"""
|
14 |
+
g = Polygon(g[:8].reshape((4, 2)))
|
15 |
+
p = Polygon(p[:8].reshape((4, 2)))
|
16 |
+
g = g.buffer(0)
|
17 |
+
p = p.buffer(0)
|
18 |
+
if not g.is_valid or not p.is_valid:
|
19 |
+
return 0
|
20 |
+
inter = Polygon(g).intersection(Polygon(p)).area
|
21 |
+
union = g.area + p.area - inter
|
22 |
+
if union == 0:
|
23 |
+
return 0
|
24 |
+
else:
|
25 |
+
return inter / union
|
26 |
+
|
27 |
+
|
28 |
+
def intersection_iog(g, p):
|
29 |
+
"""
|
30 |
+
Intersection_iog.
|
31 |
+
"""
|
32 |
+
g = Polygon(g[:8].reshape((4, 2)))
|
33 |
+
p = Polygon(p[:8].reshape((4, 2)))
|
34 |
+
if not g.is_valid or not p.is_valid:
|
35 |
+
return 0
|
36 |
+
inter = Polygon(g).intersection(Polygon(p)).area
|
37 |
+
# union = g.area + p.area - inter
|
38 |
+
union = p.area
|
39 |
+
if union == 0:
|
40 |
+
print("p_area is very small")
|
41 |
+
return 0
|
42 |
+
else:
|
43 |
+
return inter / union
|
44 |
+
|
45 |
+
|
46 |
+
def weighted_merge(g, p):
|
47 |
+
"""
|
48 |
+
Weighted merge.
|
49 |
+
"""
|
50 |
+
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
|
51 |
+
g[8] = g[8] + p[8]
|
52 |
+
return g
|
53 |
+
|
54 |
+
|
55 |
+
def standard_nms(S, thres):
|
56 |
+
"""
|
57 |
+
Standard nms.
|
58 |
+
"""
|
59 |
+
order = np.argsort(S[:, 8])[::-1]
|
60 |
+
keep = []
|
61 |
+
while order.size > 0:
|
62 |
+
i = order[0]
|
63 |
+
keep.append(i)
|
64 |
+
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
65 |
+
|
66 |
+
inds = np.where(ovr <= thres)[0]
|
67 |
+
order = order[inds + 1]
|
68 |
+
|
69 |
+
return S[keep]
|
70 |
+
|
71 |
+
|
72 |
+
def standard_nms_inds(S, thres):
|
73 |
+
"""
|
74 |
+
Standard nms, retun inds.
|
75 |
+
"""
|
76 |
+
order = np.argsort(S[:, 8])[::-1]
|
77 |
+
keep = []
|
78 |
+
while order.size > 0:
|
79 |
+
i = order[0]
|
80 |
+
keep.append(i)
|
81 |
+
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
82 |
+
|
83 |
+
inds = np.where(ovr <= thres)[0]
|
84 |
+
order = order[inds + 1]
|
85 |
+
|
86 |
+
return keep
|
87 |
+
|
88 |
+
|
89 |
+
def nms(S, thres):
|
90 |
+
"""
|
91 |
+
nms.
|
92 |
+
"""
|
93 |
+
order = np.argsort(S[:, 8])[::-1]
|
94 |
+
keep = []
|
95 |
+
while order.size > 0:
|
96 |
+
i = order[0]
|
97 |
+
keep.append(i)
|
98 |
+
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
99 |
+
|
100 |
+
inds = np.where(ovr <= thres)[0]
|
101 |
+
order = order[inds + 1]
|
102 |
+
|
103 |
+
return keep
|
104 |
+
|
105 |
+
|
106 |
+
def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
|
107 |
+
"""
|
108 |
+
soft_nms
|
109 |
+
:para boxes_in, N x 9 (coords + score)
|
110 |
+
:para threshould, eliminate cases min score(0.001)
|
111 |
+
:para Nt_thres, iou_threshi
|
112 |
+
:para sigma, gaussian weght
|
113 |
+
:method, linear or gaussian
|
114 |
+
"""
|
115 |
+
boxes = boxes_in.copy()
|
116 |
+
N = boxes.shape[0]
|
117 |
+
if N is None or N < 1:
|
118 |
+
return np.array([])
|
119 |
+
pos, maxpos = 0, 0
|
120 |
+
weight = 0.0
|
121 |
+
inds = np.arange(N)
|
122 |
+
tbox, sbox = boxes[0].copy(), boxes[0].copy()
|
123 |
+
for i in range(N):
|
124 |
+
maxscore = boxes[i, 8]
|
125 |
+
maxpos = i
|
126 |
+
tbox = boxes[i].copy()
|
127 |
+
ti = inds[i]
|
128 |
+
pos = i + 1
|
129 |
+
# get max box
|
130 |
+
while pos < N:
|
131 |
+
if maxscore < boxes[pos, 8]:
|
132 |
+
maxscore = boxes[pos, 8]
|
133 |
+
maxpos = pos
|
134 |
+
pos = pos + 1
|
135 |
+
# add max box as a detection
|
136 |
+
boxes[i, :] = boxes[maxpos, :]
|
137 |
+
inds[i] = inds[maxpos]
|
138 |
+
# swap
|
139 |
+
boxes[maxpos, :] = tbox
|
140 |
+
inds[maxpos] = ti
|
141 |
+
tbox = boxes[i].copy()
|
142 |
+
pos = i + 1
|
143 |
+
# NMS iteration
|
144 |
+
while pos < N:
|
145 |
+
sbox = boxes[pos].copy()
|
146 |
+
ts_iou_val = intersection(tbox, sbox)
|
147 |
+
if ts_iou_val > 0:
|
148 |
+
if method == 1:
|
149 |
+
if ts_iou_val > Nt_thres:
|
150 |
+
weight = 1 - ts_iou_val
|
151 |
+
else:
|
152 |
+
weight = 1
|
153 |
+
elif method == 2:
|
154 |
+
weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
|
155 |
+
else:
|
156 |
+
if ts_iou_val > Nt_thres:
|
157 |
+
weight = 0
|
158 |
+
else:
|
159 |
+
weight = 1
|
160 |
+
boxes[pos, 8] = weight * boxes[pos, 8]
|
161 |
+
# if box score falls below thresold, discard the box by
|
162 |
+
# swaping last box update N
|
163 |
+
if boxes[pos, 8] < threshold:
|
164 |
+
boxes[pos, :] = boxes[N - 1, :]
|
165 |
+
inds[pos] = inds[N - 1]
|
166 |
+
N = N - 1
|
167 |
+
pos = pos - 1
|
168 |
+
pos = pos + 1
|
169 |
+
|
170 |
+
return boxes[:N]
|
171 |
+
|
172 |
+
|
173 |
+
def nms_locality(polys, thres=0.3):
|
174 |
+
"""
|
175 |
+
locality aware nms of EAST
|
176 |
+
:param polys: a N*9 numpy array. first 8 coordinates, then prob
|
177 |
+
:return: boxes after nms
|
178 |
+
"""
|
179 |
+
S = []
|
180 |
+
p = None
|
181 |
+
for g in polys:
|
182 |
+
if p is not None and intersection(g, p) > thres:
|
183 |
+
p = weighted_merge(g, p)
|
184 |
+
else:
|
185 |
+
if p is not None:
|
186 |
+
S.append(p)
|
187 |
+
p = g
|
188 |
+
if p is not None:
|
189 |
+
S.append(p)
|
190 |
+
|
191 |
+
if len(S) == 0:
|
192 |
+
return np.array([])
|
193 |
+
return standard_nms(np.array(S), thres)
|
194 |
+
|
195 |
+
|
196 |
+
if __name__ == "__main__":
|
197 |
+
# 343,350,448,135,474,143,369,359
|
198 |
+
print(Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])).area)
|
ocr/postprocess/pg_postprocess.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import paddle
|
7 |
+
|
8 |
+
from .extract_textpoint_fast import *
|
9 |
+
from .extract_textpoint_slow import *
|
10 |
+
|
11 |
+
__dir__ = os.path.dirname(__file__)
|
12 |
+
sys.path.append(__dir__)
|
13 |
+
sys.path.append(os.path.join(__dir__, ".."))
|
14 |
+
|
15 |
+
|
16 |
+
class PGNet_PostProcess(object):
|
17 |
+
# two different post-process
|
18 |
+
def __init__(
|
19 |
+
self, character_dict_path, valid_set, score_thresh, outs_dict, shape_list
|
20 |
+
):
|
21 |
+
self.Lexicon_Table = get_dict(character_dict_path)
|
22 |
+
self.valid_set = valid_set
|
23 |
+
self.score_thresh = score_thresh
|
24 |
+
self.outs_dict = outs_dict
|
25 |
+
self.shape_list = shape_list
|
26 |
+
|
27 |
+
def pg_postprocess_fast(self):
|
28 |
+
p_score = self.outs_dict["f_score"]
|
29 |
+
p_border = self.outs_dict["f_border"]
|
30 |
+
p_char = self.outs_dict["f_char"]
|
31 |
+
p_direction = self.outs_dict["f_direction"]
|
32 |
+
if isinstance(p_score, paddle.Tensor):
|
33 |
+
p_score = p_score[0].numpy()
|
34 |
+
p_border = p_border[0].numpy()
|
35 |
+
p_direction = p_direction[0].numpy()
|
36 |
+
p_char = p_char[0].numpy()
|
37 |
+
else:
|
38 |
+
p_score = p_score[0]
|
39 |
+
p_border = p_border[0]
|
40 |
+
p_direction = p_direction[0]
|
41 |
+
p_char = p_char[0]
|
42 |
+
|
43 |
+
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
44 |
+
instance_yxs_list, seq_strs = generate_pivot_list_fast(
|
45 |
+
p_score,
|
46 |
+
p_char,
|
47 |
+
p_direction,
|
48 |
+
self.Lexicon_Table,
|
49 |
+
score_thresh=self.score_thresh,
|
50 |
+
)
|
51 |
+
poly_list, keep_str_list = restore_poly(
|
52 |
+
instance_yxs_list,
|
53 |
+
seq_strs,
|
54 |
+
p_border,
|
55 |
+
ratio_w,
|
56 |
+
ratio_h,
|
57 |
+
src_w,
|
58 |
+
src_h,
|
59 |
+
self.valid_set,
|
60 |
+
)
|
61 |
+
data = {
|
62 |
+
"points": poly_list,
|
63 |
+
"texts": keep_str_list,
|
64 |
+
}
|
65 |
+
return data
|
66 |
+
|
67 |
+
def pg_postprocess_slow(self):
|
68 |
+
p_score = self.outs_dict["f_score"]
|
69 |
+
p_border = self.outs_dict["f_border"]
|
70 |
+
p_char = self.outs_dict["f_char"]
|
71 |
+
p_direction = self.outs_dict["f_direction"]
|
72 |
+
if isinstance(p_score, paddle.Tensor):
|
73 |
+
p_score = p_score[0].numpy()
|
74 |
+
p_border = p_border[0].numpy()
|
75 |
+
p_direction = p_direction[0].numpy()
|
76 |
+
p_char = p_char[0].numpy()
|
77 |
+
else:
|
78 |
+
p_score = p_score[0]
|
79 |
+
p_border = p_border[0]
|
80 |
+
p_direction = p_direction[0]
|
81 |
+
p_char = p_char[0]
|
82 |
+
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
83 |
+
is_curved = self.valid_set == "totaltext"
|
84 |
+
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
|
85 |
+
p_score,
|
86 |
+
p_char,
|
87 |
+
p_direction,
|
88 |
+
score_thresh=self.score_thresh,
|
89 |
+
is_backbone=True,
|
90 |
+
is_curved=is_curved,
|
91 |
+
)
|
92 |
+
seq_strs = []
|
93 |
+
for char_idx_set in char_seq_idx_set:
|
94 |
+
pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
95 |
+
seq_strs.append(pr_str)
|
96 |
+
poly_list = []
|
97 |
+
keep_str_list = []
|
98 |
+
all_point_list = []
|
99 |
+
all_point_pair_list = []
|
100 |
+
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
101 |
+
if len(yx_center_line) == 1:
|
102 |
+
yx_center_line.append(yx_center_line[-1])
|
103 |
+
|
104 |
+
offset_expand = 1.0
|
105 |
+
if self.valid_set == "totaltext":
|
106 |
+
offset_expand = 1.2
|
107 |
+
|
108 |
+
point_pair_list = []
|
109 |
+
for batch_id, y, x in yx_center_line:
|
110 |
+
offset = p_border[:, y, x].reshape(2, 2)
|
111 |
+
if offset_expand != 1.0:
|
112 |
+
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
113 |
+
expand_length = np.clip(
|
114 |
+
offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
|
115 |
+
)
|
116 |
+
offset_detal = offset / offset_length * expand_length
|
117 |
+
offset = offset + offset_detal
|
118 |
+
ori_yx = np.array([y, x], dtype=np.float32)
|
119 |
+
point_pair = (
|
120 |
+
(ori_yx + offset)[:, ::-1]
|
121 |
+
* 4.0
|
122 |
+
/ np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
123 |
+
)
|
124 |
+
point_pair_list.append(point_pair)
|
125 |
+
|
126 |
+
all_point_list.append(
|
127 |
+
[int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))]
|
128 |
+
)
|
129 |
+
all_point_pair_list.append(point_pair.round().astype(np.int32).tolist())
|
130 |
+
|
131 |
+
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
132 |
+
detected_poly = expand_poly_along_width(
|
133 |
+
detected_poly, shrink_ratio_of_width=0.2
|
134 |
+
)
|
135 |
+
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
136 |
+
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
137 |
+
|
138 |
+
if len(keep_str) < 2:
|
139 |
+
continue
|
140 |
+
|
141 |
+
keep_str_list.append(keep_str)
|
142 |
+
detected_poly = np.round(detected_poly).astype("int32")
|
143 |
+
if self.valid_set == "partvgg":
|
144 |
+
middle_point = len(detected_poly) // 2
|
145 |
+
detected_poly = detected_poly[
|
146 |
+
[0, middle_point - 1, middle_point, -1], :
|
147 |
+
]
|
148 |
+
poly_list.append(detected_poly)
|
149 |
+
elif self.valid_set == "totaltext":
|
150 |
+
poly_list.append(detected_poly)
|
151 |
+
else:
|
152 |
+
print("--> Not supported format.")
|
153 |
+
exit(-1)
|
154 |
+
data = {
|
155 |
+
"points": poly_list,
|
156 |
+
"texts": keep_str_list,
|
157 |
+
}
|
158 |
+
return data
|
159 |
+
|
160 |
+
|
161 |
+
class PGPostProcess(object):
|
162 |
+
"""
|
163 |
+
The post process for PGNet.
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(self, character_dict_path, valid_set, score_thresh, mode, **kwargs):
|
167 |
+
self.character_dict_path = character_dict_path
|
168 |
+
self.valid_set = valid_set
|
169 |
+
self.score_thresh = score_thresh
|
170 |
+
self.mode = mode
|
171 |
+
|
172 |
+
# c++ la-nms is faster, but only support python 3.5
|
173 |
+
self.is_python35 = False
|
174 |
+
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
175 |
+
self.is_python35 = True
|
176 |
+
|
177 |
+
def __call__(self, outs_dict, shape_list):
|
178 |
+
post = PGNet_PostProcess(
|
179 |
+
self.character_dict_path,
|
180 |
+
self.valid_set,
|
181 |
+
self.score_thresh,
|
182 |
+
outs_dict,
|
183 |
+
shape_list,
|
184 |
+
)
|
185 |
+
if self.mode == "fast":
|
186 |
+
data = post.pg_postprocess_fast()
|
187 |
+
else:
|
188 |
+
data = post.pg_postprocess_slow()
|
189 |
+
return data
|
ocr/postprocess/poly_nms.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from shapely.geometry import Polygon
|
3 |
+
|
4 |
+
|
5 |
+
def points2polygon(points):
|
6 |
+
"""Convert k points to 1 polygon.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
points (ndarray or list): A ndarray or a list of shape (2k)
|
10 |
+
that indicates k points.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
polygon (Polygon): A polygon object.
|
14 |
+
"""
|
15 |
+
if isinstance(points, list):
|
16 |
+
points = np.array(points)
|
17 |
+
|
18 |
+
assert isinstance(points, np.ndarray)
|
19 |
+
assert (points.size % 2 == 0) and (points.size >= 8)
|
20 |
+
|
21 |
+
point_mat = points.reshape([-1, 2])
|
22 |
+
return Polygon(point_mat)
|
23 |
+
|
24 |
+
|
25 |
+
def poly_intersection(poly_det, poly_gt, buffer=0.0001):
|
26 |
+
"""Calculate the intersection area between two polygon.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
poly_det (Polygon): A polygon predicted by detector.
|
30 |
+
poly_gt (Polygon): A gt polygon.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
intersection_area (float): The intersection area between two polygons.
|
34 |
+
"""
|
35 |
+
assert isinstance(poly_det, Polygon)
|
36 |
+
assert isinstance(poly_gt, Polygon)
|
37 |
+
|
38 |
+
if buffer == 0:
|
39 |
+
poly_inter = poly_det & poly_gt
|
40 |
+
else:
|
41 |
+
poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
|
42 |
+
return poly_inter.area, poly_inter
|
43 |
+
|
44 |
+
|
45 |
+
def poly_union(poly_det, poly_gt):
|
46 |
+
"""Calculate the union area between two polygon.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
poly_det (Polygon): A polygon predicted by detector.
|
50 |
+
poly_gt (Polygon): A gt polygon.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
union_area (float): The union area between two polygons.
|
54 |
+
"""
|
55 |
+
assert isinstance(poly_det, Polygon)
|
56 |
+
assert isinstance(poly_gt, Polygon)
|
57 |
+
|
58 |
+
area_det = poly_det.area
|
59 |
+
area_gt = poly_gt.area
|
60 |
+
area_inters, _ = poly_intersection(poly_det, poly_gt)
|
61 |
+
return area_det + area_gt - area_inters
|
62 |
+
|
63 |
+
|
64 |
+
def valid_boundary(x, with_score=True):
|
65 |
+
num = len(x)
|
66 |
+
if num < 8:
|
67 |
+
return False
|
68 |
+
if num % 2 == 0 and (not with_score):
|
69 |
+
return True
|
70 |
+
if num % 2 == 1 and with_score:
|
71 |
+
return True
|
72 |
+
|
73 |
+
return False
|
74 |
+
|
75 |
+
|
76 |
+
def boundary_iou(src, target):
|
77 |
+
"""Calculate the IOU between two boundaries.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
src (list): Source boundary.
|
81 |
+
target (list): Target boundary.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
iou (float): The iou between two boundaries.
|
85 |
+
"""
|
86 |
+
assert valid_boundary(src, False)
|
87 |
+
assert valid_boundary(target, False)
|
88 |
+
src_poly = points2polygon(src)
|
89 |
+
target_poly = points2polygon(target)
|
90 |
+
|
91 |
+
return poly_iou(src_poly, target_poly)
|
92 |
+
|
93 |
+
|
94 |
+
def poly_iou(poly_det, poly_gt):
|
95 |
+
"""Calculate the IOU between two polygons.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
poly_det (Polygon): A polygon predicted by detector.
|
99 |
+
poly_gt (Polygon): A gt polygon.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
iou (float): The IOU between two polygons.
|
103 |
+
"""
|
104 |
+
assert isinstance(poly_det, Polygon)
|
105 |
+
assert isinstance(poly_gt, Polygon)
|
106 |
+
area_inters, _ = poly_intersection(poly_det, poly_gt)
|
107 |
+
area_union = poly_union(poly_det, poly_gt)
|
108 |
+
if area_union == 0:
|
109 |
+
return 0.0
|
110 |
+
return area_inters / area_union
|
111 |
+
|
112 |
+
|
113 |
+
def poly_nms(polygons, threshold):
|
114 |
+
assert isinstance(polygons, list)
|
115 |
+
|
116 |
+
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
|
117 |
+
|
118 |
+
keep_poly = []
|
119 |
+
index = [i for i in range(polygons.shape[0])]
|
120 |
+
|
121 |
+
while len(index) > 0:
|
122 |
+
keep_poly.append(polygons[index[-1]].tolist())
|
123 |
+
A = polygons[index[-1]][:-1]
|
124 |
+
index = np.delete(index, -1)
|
125 |
+
iou_list = np.zeros((len(index),))
|
126 |
+
for i in range(len(index)):
|
127 |
+
B = polygons[index[i]][:-1]
|
128 |
+
iou_list[i] = boundary_iou(A, B)
|
129 |
+
remove_index = np.where(iou_list > threshold)
|
130 |
+
index = np.delete(index, remove_index)
|
131 |
+
|
132 |
+
return keep_poly
|
ocr/postprocess/pse_postprocess/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pse_postprocess import PSEPostProcess
|
ocr/postprocess/pse_postprocess/pse/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import sys
|
4 |
+
|
5 |
+
python_path = sys.executable
|
6 |
+
|
7 |
+
ori_path = os.getcwd()
|
8 |
+
os.chdir("ppocr/postprocess/pse_postprocess/pse")
|
9 |
+
if (
|
10 |
+
subprocess.call("{} setup.py build_ext --inplace".format(python_path), shell=True)
|
11 |
+
!= 0
|
12 |
+
):
|
13 |
+
raise RuntimeError(
|
14 |
+
"Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+".format(
|
15 |
+
os.path.dirname(os.path.realpath(__file__))
|
16 |
+
)
|
17 |
+
)
|
18 |
+
os.chdir(ori_path)
|
19 |
+
|
20 |
+
from .pse import pse
|
ocr/postprocess/pse_postprocess/pse/pse.pyx
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
cimport cython
|
6 |
+
cimport libcpp
|
7 |
+
cimport libcpp.pair
|
8 |
+
cimport libcpp.queue
|
9 |
+
cimport numpy as np
|
10 |
+
from libcpp.pair cimport *
|
11 |
+
from libcpp.queue cimport *
|
12 |
+
|
13 |
+
|
14 |
+
@cython.boundscheck(False)
|
15 |
+
@cython.wraparound(False)
|
16 |
+
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
|
17 |
+
np.ndarray[np.int32_t, ndim=2] label,
|
18 |
+
int kernel_num,
|
19 |
+
int label_num,
|
20 |
+
float min_area=0):
|
21 |
+
cdef np.ndarray[np.int32_t, ndim=2] pred
|
22 |
+
pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
|
23 |
+
|
24 |
+
for label_idx in range(1, label_num):
|
25 |
+
if np.sum(label == label_idx) < min_area:
|
26 |
+
label[label == label_idx] = 0
|
27 |
+
|
28 |
+
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
|
29 |
+
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
|
30 |
+
cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
|
31 |
+
queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
|
32 |
+
cdef np.int16_t* dx = [-1, 1, 0, 0]
|
33 |
+
cdef np.int16_t* dy = [0, 0, -1, 1]
|
34 |
+
cdef np.int16_t tmpx, tmpy
|
35 |
+
|
36 |
+
points = np.array(np.where(label > 0)).transpose((1, 0))
|
37 |
+
for point_idx in range(points.shape[0]):
|
38 |
+
tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
|
39 |
+
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
|
40 |
+
pred[tmpx, tmpy] = label[tmpx, tmpy]
|
41 |
+
|
42 |
+
cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
|
43 |
+
cdef int cur_label
|
44 |
+
for kernel_idx in range(kernel_num - 1, -1, -1):
|
45 |
+
while not que.empty():
|
46 |
+
cur = que.front()
|
47 |
+
que.pop()
|
48 |
+
cur_label = pred[cur.first, cur.second]
|
49 |
+
|
50 |
+
is_edge = True
|
51 |
+
for j in range(4):
|
52 |
+
tmpx = cur.first + dx[j]
|
53 |
+
tmpy = cur.second + dy[j]
|
54 |
+
if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
|
55 |
+
continue
|
56 |
+
if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
|
57 |
+
continue
|
58 |
+
|
59 |
+
que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
|
60 |
+
pred[tmpx, tmpy] = cur_label
|
61 |
+
is_edge = False
|
62 |
+
if is_edge:
|
63 |
+
nxt_que.push(cur)
|
64 |
+
|
65 |
+
que, nxt_que = nxt_que, que
|
66 |
+
|
67 |
+
return pred
|
68 |
+
|
69 |
+
def pse(kernels, min_area):
|
70 |
+
kernel_num = kernels.shape[0]
|
71 |
+
label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
|
72 |
+
return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
|
ocr/postprocess/pse_postprocess/pse/setup.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from distutils.core import Extension, setup
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
from Cython.Build import cythonize
|
5 |
+
|
6 |
+
setup(
|
7 |
+
ext_modules=cythonize(
|
8 |
+
Extension(
|
9 |
+
"pse",
|
10 |
+
sources=["pse.pyx"],
|
11 |
+
language="c++",
|
12 |
+
include_dirs=[numpy.get_include()],
|
13 |
+
library_dirs=[],
|
14 |
+
libraries=[],
|
15 |
+
extra_compile_args=["-O3"],
|
16 |
+
extra_link_args=[],
|
17 |
+
)
|
18 |
+
)
|
19 |
+
)
|
ocr/postprocess/pse_postprocess/pse_postprocess.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import paddle
|
6 |
+
from paddle.nn import functional as F
|
7 |
+
|
8 |
+
from .pse import pse
|
9 |
+
|
10 |
+
|
11 |
+
class PSEPostProcess(object):
|
12 |
+
"""
|
13 |
+
The post process for PSE.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
thresh=0.5,
|
19 |
+
box_thresh=0.85,
|
20 |
+
min_area=16,
|
21 |
+
box_type="quad",
|
22 |
+
scale=4,
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
assert box_type in ["quad", "poly"], "Only quad and poly is supported"
|
26 |
+
self.thresh = thresh
|
27 |
+
self.box_thresh = box_thresh
|
28 |
+
self.min_area = min_area
|
29 |
+
self.box_type = box_type
|
30 |
+
self.scale = scale
|
31 |
+
|
32 |
+
def __call__(self, outs_dict, shape_list):
|
33 |
+
pred = outs_dict["maps"]
|
34 |
+
if not isinstance(pred, paddle.Tensor):
|
35 |
+
pred = paddle.to_tensor(pred)
|
36 |
+
pred = F.interpolate(pred, scale_factor=4 // self.scale, mode="bilinear")
|
37 |
+
|
38 |
+
score = F.sigmoid(pred[:, 0, :, :])
|
39 |
+
|
40 |
+
kernels = (pred > self.thresh).astype("float32")
|
41 |
+
text_mask = kernels[:, 0, :, :]
|
42 |
+
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
|
43 |
+
|
44 |
+
score = score.numpy()
|
45 |
+
kernels = kernels.numpy().astype(np.uint8)
|
46 |
+
|
47 |
+
boxes_batch = []
|
48 |
+
for batch_index in range(pred.shape[0]):
|
49 |
+
boxes, scores = self.boxes_from_bitmap(
|
50 |
+
score[batch_index], kernels[batch_index], shape_list[batch_index]
|
51 |
+
)
|
52 |
+
|
53 |
+
boxes_batch.append({"points": boxes, "scores": scores})
|
54 |
+
return boxes_batch
|
55 |
+
|
56 |
+
def boxes_from_bitmap(self, score, kernels, shape):
|
57 |
+
label = pse(kernels, self.min_area)
|
58 |
+
return self.generate_box(score, label, shape)
|
59 |
+
|
60 |
+
def generate_box(self, score, label, shape):
|
61 |
+
src_h, src_w, ratio_h, ratio_w = shape
|
62 |
+
label_num = np.max(label) + 1
|
63 |
+
|
64 |
+
boxes = []
|
65 |
+
scores = []
|
66 |
+
for i in range(1, label_num):
|
67 |
+
ind = label == i
|
68 |
+
points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
|
69 |
+
|
70 |
+
if points.shape[0] < self.min_area:
|
71 |
+
label[ind] = 0
|
72 |
+
continue
|
73 |
+
|
74 |
+
score_i = np.mean(score[ind])
|
75 |
+
if score_i < self.box_thresh:
|
76 |
+
label[ind] = 0
|
77 |
+
continue
|
78 |
+
|
79 |
+
if self.box_type == "quad":
|
80 |
+
rect = cv2.minAreaRect(points)
|
81 |
+
bbox = cv2.boxPoints(rect)
|
82 |
+
elif self.box_type == "poly":
|
83 |
+
box_height = np.max(points[:, 1]) + 10
|
84 |
+
box_width = np.max(points[:, 0]) + 10
|
85 |
+
|
86 |
+
mask = np.zeros((box_height, box_width), np.uint8)
|
87 |
+
mask[points[:, 1], points[:, 0]] = 255
|
88 |
+
|
89 |
+
contours, _ = cv2.findContours(
|
90 |
+
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
91 |
+
)
|
92 |
+
bbox = np.squeeze(contours[0], 1)
|
93 |
+
else:
|
94 |
+
raise NotImplementedError
|
95 |
+
|
96 |
+
bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
|
97 |
+
bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
|
98 |
+
boxes.append(bbox)
|
99 |
+
scores.append(score_i)
|
100 |
+
return boxes, scores
|
ocr/postprocess/rec_postprocess.py
ADDED
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import paddle
|
5 |
+
from paddle.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class BaseRecLabelDecode(object):
|
9 |
+
"""Convert between text-label and text-index"""
|
10 |
+
|
11 |
+
def __init__(self, character_dict_path=None, use_space_char=False):
|
12 |
+
self.beg_str = "sos"
|
13 |
+
self.end_str = "eos"
|
14 |
+
|
15 |
+
self.character_str = []
|
16 |
+
if character_dict_path is None:
|
17 |
+
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
18 |
+
dict_character = list(self.character_str)
|
19 |
+
else:
|
20 |
+
with open(character_dict_path, "rb") as fin:
|
21 |
+
lines = fin.readlines()
|
22 |
+
for line in lines:
|
23 |
+
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
24 |
+
self.character_str.append(line)
|
25 |
+
if use_space_char:
|
26 |
+
self.character_str.append(" ")
|
27 |
+
dict_character = list(self.character_str)
|
28 |
+
|
29 |
+
dict_character = self.add_special_char(dict_character)
|
30 |
+
self.dict = {}
|
31 |
+
for i, char in enumerate(dict_character):
|
32 |
+
self.dict[char] = i
|
33 |
+
self.character = dict_character
|
34 |
+
|
35 |
+
def add_special_char(self, dict_character):
|
36 |
+
return dict_character
|
37 |
+
|
38 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
39 |
+
"""convert text-index into text-label."""
|
40 |
+
result_list = []
|
41 |
+
ignored_tokens = self.get_ignored_tokens()
|
42 |
+
batch_size = len(text_index)
|
43 |
+
for batch_idx in range(batch_size):
|
44 |
+
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
45 |
+
if is_remove_duplicate:
|
46 |
+
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
|
47 |
+
for ignored_token in ignored_tokens:
|
48 |
+
selection &= text_index[batch_idx] != ignored_token
|
49 |
+
|
50 |
+
char_list = [
|
51 |
+
self.character[text_id] for text_id in text_index[batch_idx][selection]
|
52 |
+
]
|
53 |
+
if text_prob is not None:
|
54 |
+
conf_list = text_prob[batch_idx][selection]
|
55 |
+
else:
|
56 |
+
conf_list = [1] * len(selection)
|
57 |
+
if len(conf_list) == 0:
|
58 |
+
conf_list = [0]
|
59 |
+
|
60 |
+
text = "".join(char_list)
|
61 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
62 |
+
return result_list
|
63 |
+
|
64 |
+
def get_ignored_tokens(self):
|
65 |
+
return [0] # for ctc blank
|
66 |
+
|
67 |
+
|
68 |
+
class CTCLabelDecode(BaseRecLabelDecode):
|
69 |
+
"""Convert between text-label and text-index"""
|
70 |
+
|
71 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
72 |
+
super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
|
73 |
+
|
74 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
75 |
+
if isinstance(preds, tuple) or isinstance(preds, list):
|
76 |
+
preds = preds[-1]
|
77 |
+
if isinstance(preds, paddle.Tensor):
|
78 |
+
preds = preds.numpy()
|
79 |
+
preds_idx = preds.argmax(axis=2)
|
80 |
+
preds_prob = preds.max(axis=2)
|
81 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
82 |
+
if label is None:
|
83 |
+
return text
|
84 |
+
label = self.decode(label)
|
85 |
+
return text, label
|
86 |
+
|
87 |
+
def add_special_char(self, dict_character):
|
88 |
+
dict_character = ["blank"] + dict_character
|
89 |
+
return dict_character
|
90 |
+
|
91 |
+
|
92 |
+
class DistillationCTCLabelDecode(CTCLabelDecode):
|
93 |
+
"""
|
94 |
+
Convert
|
95 |
+
Convert between text-label and text-index
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
character_dict_path=None,
|
101 |
+
use_space_char=False,
|
102 |
+
model_name=["student"],
|
103 |
+
key=None,
|
104 |
+
multi_head=False,
|
105 |
+
**kwargs
|
106 |
+
):
|
107 |
+
super(DistillationCTCLabelDecode, self).__init__(
|
108 |
+
character_dict_path, use_space_char
|
109 |
+
)
|
110 |
+
if not isinstance(model_name, list):
|
111 |
+
model_name = [model_name]
|
112 |
+
self.model_name = model_name
|
113 |
+
|
114 |
+
self.key = key
|
115 |
+
self.multi_head = multi_head
|
116 |
+
|
117 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
118 |
+
output = dict()
|
119 |
+
for name in self.model_name:
|
120 |
+
pred = preds[name]
|
121 |
+
if self.key is not None:
|
122 |
+
pred = pred[self.key]
|
123 |
+
if self.multi_head and isinstance(pred, dict):
|
124 |
+
pred = pred["ctc"]
|
125 |
+
output[name] = super().__call__(pred, label=label, *args, **kwargs)
|
126 |
+
return output
|
127 |
+
|
128 |
+
|
129 |
+
class NRTRLabelDecode(BaseRecLabelDecode):
|
130 |
+
"""Convert between text-label and text-index"""
|
131 |
+
|
132 |
+
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
|
133 |
+
super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)
|
134 |
+
|
135 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
136 |
+
|
137 |
+
if len(preds) == 2:
|
138 |
+
preds_id = preds[0]
|
139 |
+
preds_prob = preds[1]
|
140 |
+
if isinstance(preds_id, paddle.Tensor):
|
141 |
+
preds_id = preds_id.numpy()
|
142 |
+
if isinstance(preds_prob, paddle.Tensor):
|
143 |
+
preds_prob = preds_prob.numpy()
|
144 |
+
if preds_id[0][0] == 2:
|
145 |
+
preds_idx = preds_id[:, 1:]
|
146 |
+
preds_prob = preds_prob[:, 1:]
|
147 |
+
else:
|
148 |
+
preds_idx = preds_id
|
149 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
150 |
+
if label is None:
|
151 |
+
return text
|
152 |
+
label = self.decode(label[:, 1:])
|
153 |
+
else:
|
154 |
+
if isinstance(preds, paddle.Tensor):
|
155 |
+
preds = preds.numpy()
|
156 |
+
preds_idx = preds.argmax(axis=2)
|
157 |
+
preds_prob = preds.max(axis=2)
|
158 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
159 |
+
if label is None:
|
160 |
+
return text
|
161 |
+
label = self.decode(label[:, 1:])
|
162 |
+
return text, label
|
163 |
+
|
164 |
+
def add_special_char(self, dict_character):
|
165 |
+
dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
|
166 |
+
return dict_character
|
167 |
+
|
168 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
169 |
+
"""convert text-index into text-label."""
|
170 |
+
result_list = []
|
171 |
+
batch_size = len(text_index)
|
172 |
+
for batch_idx in range(batch_size):
|
173 |
+
char_list = []
|
174 |
+
conf_list = []
|
175 |
+
for idx in range(len(text_index[batch_idx])):
|
176 |
+
if text_index[batch_idx][idx] == 3: # end
|
177 |
+
break
|
178 |
+
try:
|
179 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
180 |
+
except:
|
181 |
+
continue
|
182 |
+
if text_prob is not None:
|
183 |
+
conf_list.append(text_prob[batch_idx][idx])
|
184 |
+
else:
|
185 |
+
conf_list.append(1)
|
186 |
+
text = "".join(char_list)
|
187 |
+
result_list.append((text.lower(), np.mean(conf_list).tolist()))
|
188 |
+
return result_list
|
189 |
+
|
190 |
+
|
191 |
+
class AttnLabelDecode(BaseRecLabelDecode):
|
192 |
+
"""Convert between text-label and text-index"""
|
193 |
+
|
194 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
195 |
+
super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)
|
196 |
+
|
197 |
+
def add_special_char(self, dict_character):
|
198 |
+
self.beg_str = "sos"
|
199 |
+
self.end_str = "eos"
|
200 |
+
dict_character = dict_character
|
201 |
+
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
202 |
+
return dict_character
|
203 |
+
|
204 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
205 |
+
"""convert text-index into text-label."""
|
206 |
+
result_list = []
|
207 |
+
ignored_tokens = self.get_ignored_tokens()
|
208 |
+
[beg_idx, end_idx] = self.get_ignored_tokens()
|
209 |
+
batch_size = len(text_index)
|
210 |
+
for batch_idx in range(batch_size):
|
211 |
+
char_list = []
|
212 |
+
conf_list = []
|
213 |
+
for idx in range(len(text_index[batch_idx])):
|
214 |
+
if text_index[batch_idx][idx] in ignored_tokens:
|
215 |
+
continue
|
216 |
+
if int(text_index[batch_idx][idx]) == int(end_idx):
|
217 |
+
break
|
218 |
+
if is_remove_duplicate:
|
219 |
+
# only for predict
|
220 |
+
if (
|
221 |
+
idx > 0
|
222 |
+
and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
|
223 |
+
):
|
224 |
+
continue
|
225 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
226 |
+
if text_prob is not None:
|
227 |
+
conf_list.append(text_prob[batch_idx][idx])
|
228 |
+
else:
|
229 |
+
conf_list.append(1)
|
230 |
+
text = "".join(char_list)
|
231 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
232 |
+
return result_list
|
233 |
+
|
234 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
235 |
+
"""
|
236 |
+
text = self.decode(text)
|
237 |
+
if label is None:
|
238 |
+
return text
|
239 |
+
else:
|
240 |
+
label = self.decode(label, is_remove_duplicate=False)
|
241 |
+
return text, label
|
242 |
+
"""
|
243 |
+
if isinstance(preds, paddle.Tensor):
|
244 |
+
preds = preds.numpy()
|
245 |
+
|
246 |
+
preds_idx = preds.argmax(axis=2)
|
247 |
+
preds_prob = preds.max(axis=2)
|
248 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
249 |
+
if label is None:
|
250 |
+
return text
|
251 |
+
label = self.decode(label, is_remove_duplicate=False)
|
252 |
+
return text, label
|
253 |
+
|
254 |
+
def get_ignored_tokens(self):
|
255 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
256 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
257 |
+
return [beg_idx, end_idx]
|
258 |
+
|
259 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
260 |
+
if beg_or_end == "beg":
|
261 |
+
idx = np.array(self.dict[self.beg_str])
|
262 |
+
elif beg_or_end == "end":
|
263 |
+
idx = np.array(self.dict[self.end_str])
|
264 |
+
else:
|
265 |
+
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
266 |
+
return idx
|
267 |
+
|
268 |
+
|
269 |
+
class SEEDLabelDecode(BaseRecLabelDecode):
|
270 |
+
"""Convert between text-label and text-index"""
|
271 |
+
|
272 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
273 |
+
super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)
|
274 |
+
|
275 |
+
def add_special_char(self, dict_character):
|
276 |
+
self.padding_str = "padding"
|
277 |
+
self.end_str = "eos"
|
278 |
+
self.unknown = "unknown"
|
279 |
+
dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
|
280 |
+
return dict_character
|
281 |
+
|
282 |
+
def get_ignored_tokens(self):
|
283 |
+
end_idx = self.get_beg_end_flag_idx("eos")
|
284 |
+
return [end_idx]
|
285 |
+
|
286 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
287 |
+
if beg_or_end == "sos":
|
288 |
+
idx = np.array(self.dict[self.beg_str])
|
289 |
+
elif beg_or_end == "eos":
|
290 |
+
idx = np.array(self.dict[self.end_str])
|
291 |
+
else:
|
292 |
+
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
293 |
+
return idx
|
294 |
+
|
295 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
296 |
+
"""convert text-index into text-label."""
|
297 |
+
result_list = []
|
298 |
+
[end_idx] = self.get_ignored_tokens()
|
299 |
+
batch_size = len(text_index)
|
300 |
+
for batch_idx in range(batch_size):
|
301 |
+
char_list = []
|
302 |
+
conf_list = []
|
303 |
+
for idx in range(len(text_index[batch_idx])):
|
304 |
+
if int(text_index[batch_idx][idx]) == int(end_idx):
|
305 |
+
break
|
306 |
+
if is_remove_duplicate:
|
307 |
+
# only for predict
|
308 |
+
if (
|
309 |
+
idx > 0
|
310 |
+
and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
|
311 |
+
):
|
312 |
+
continue
|
313 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
314 |
+
if text_prob is not None:
|
315 |
+
conf_list.append(text_prob[batch_idx][idx])
|
316 |
+
else:
|
317 |
+
conf_list.append(1)
|
318 |
+
text = "".join(char_list)
|
319 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
320 |
+
return result_list
|
321 |
+
|
322 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
323 |
+
"""
|
324 |
+
text = self.decode(text)
|
325 |
+
if label is None:
|
326 |
+
return text
|
327 |
+
else:
|
328 |
+
label = self.decode(label, is_remove_duplicate=False)
|
329 |
+
return text, label
|
330 |
+
"""
|
331 |
+
preds_idx = preds["rec_pred"]
|
332 |
+
if isinstance(preds_idx, paddle.Tensor):
|
333 |
+
preds_idx = preds_idx.numpy()
|
334 |
+
if "rec_pred_scores" in preds:
|
335 |
+
preds_idx = preds["rec_pred"]
|
336 |
+
preds_prob = preds["rec_pred_scores"]
|
337 |
+
else:
|
338 |
+
preds_idx = preds["rec_pred"].argmax(axis=2)
|
339 |
+
preds_prob = preds["rec_pred"].max(axis=2)
|
340 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
341 |
+
if label is None:
|
342 |
+
return text
|
343 |
+
label = self.decode(label, is_remove_duplicate=False)
|
344 |
+
return text, label
|
345 |
+
|
346 |
+
|
347 |
+
class SRNLabelDecode(BaseRecLabelDecode):
|
348 |
+
"""Convert between text-label and text-index"""
|
349 |
+
|
350 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
351 |
+
super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
|
352 |
+
self.max_text_length = kwargs.get("max_text_length", 25)
|
353 |
+
|
354 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
355 |
+
pred = preds["predict"]
|
356 |
+
char_num = len(self.character_str) + 2
|
357 |
+
if isinstance(pred, paddle.Tensor):
|
358 |
+
pred = pred.numpy()
|
359 |
+
pred = np.reshape(pred, [-1, char_num])
|
360 |
+
|
361 |
+
preds_idx = np.argmax(pred, axis=1)
|
362 |
+
preds_prob = np.max(pred, axis=1)
|
363 |
+
|
364 |
+
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
|
365 |
+
|
366 |
+
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
|
367 |
+
|
368 |
+
text = self.decode(preds_idx, preds_prob)
|
369 |
+
|
370 |
+
if label is None:
|
371 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
372 |
+
return text
|
373 |
+
label = self.decode(label)
|
374 |
+
return text, label
|
375 |
+
|
376 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
377 |
+
"""convert text-index into text-label."""
|
378 |
+
result_list = []
|
379 |
+
ignored_tokens = self.get_ignored_tokens()
|
380 |
+
batch_size = len(text_index)
|
381 |
+
|
382 |
+
for batch_idx in range(batch_size):
|
383 |
+
char_list = []
|
384 |
+
conf_list = []
|
385 |
+
for idx in range(len(text_index[batch_idx])):
|
386 |
+
if text_index[batch_idx][idx] in ignored_tokens:
|
387 |
+
continue
|
388 |
+
if is_remove_duplicate:
|
389 |
+
# only for predict
|
390 |
+
if (
|
391 |
+
idx > 0
|
392 |
+
and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
|
393 |
+
):
|
394 |
+
continue
|
395 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
396 |
+
if text_prob is not None:
|
397 |
+
conf_list.append(text_prob[batch_idx][idx])
|
398 |
+
else:
|
399 |
+
conf_list.append(1)
|
400 |
+
|
401 |
+
text = "".join(char_list)
|
402 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
403 |
+
return result_list
|
404 |
+
|
405 |
+
def add_special_char(self, dict_character):
|
406 |
+
dict_character = dict_character + [self.beg_str, self.end_str]
|
407 |
+
return dict_character
|
408 |
+
|
409 |
+
def get_ignored_tokens(self):
|
410 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
411 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
412 |
+
return [beg_idx, end_idx]
|
413 |
+
|
414 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
415 |
+
if beg_or_end == "beg":
|
416 |
+
idx = np.array(self.dict[self.beg_str])
|
417 |
+
elif beg_or_end == "end":
|
418 |
+
idx = np.array(self.dict[self.end_str])
|
419 |
+
else:
|
420 |
+
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
421 |
+
return idx
|
422 |
+
|
423 |
+
|
424 |
+
class TableLabelDecode(object):
|
425 |
+
""" """
|
426 |
+
|
427 |
+
def __init__(self, character_dict_path, **kwargs):
|
428 |
+
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
429 |
+
list_character = self.add_special_char(list_character)
|
430 |
+
list_elem = self.add_special_char(list_elem)
|
431 |
+
self.dict_character = {}
|
432 |
+
self.dict_idx_character = {}
|
433 |
+
for i, char in enumerate(list_character):
|
434 |
+
self.dict_idx_character[i] = char
|
435 |
+
self.dict_character[char] = i
|
436 |
+
self.dict_elem = {}
|
437 |
+
self.dict_idx_elem = {}
|
438 |
+
for i, elem in enumerate(list_elem):
|
439 |
+
self.dict_idx_elem[i] = elem
|
440 |
+
self.dict_elem[elem] = i
|
441 |
+
|
442 |
+
def load_char_elem_dict(self, character_dict_path):
|
443 |
+
list_character = []
|
444 |
+
list_elem = []
|
445 |
+
with open(character_dict_path, "rb") as fin:
|
446 |
+
lines = fin.readlines()
|
447 |
+
substr = lines[0].decode("utf-8").strip("\n").strip("\r\n").split("\t")
|
448 |
+
character_num = int(substr[0])
|
449 |
+
elem_num = int(substr[1])
|
450 |
+
for cno in range(1, 1 + character_num):
|
451 |
+
character = lines[cno].decode("utf-8").strip("\n").strip("\r\n")
|
452 |
+
list_character.append(character)
|
453 |
+
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
454 |
+
elem = lines[eno].decode("utf-8").strip("\n").strip("\r\n")
|
455 |
+
list_elem.append(elem)
|
456 |
+
return list_character, list_elem
|
457 |
+
|
458 |
+
def add_special_char(self, list_character):
|
459 |
+
self.beg_str = "sos"
|
460 |
+
self.end_str = "eos"
|
461 |
+
list_character = [self.beg_str] + list_character + [self.end_str]
|
462 |
+
return list_character
|
463 |
+
|
464 |
+
def __call__(self, preds):
|
465 |
+
structure_probs = preds["structure_probs"]
|
466 |
+
loc_preds = preds["loc_preds"]
|
467 |
+
if isinstance(structure_probs, paddle.Tensor):
|
468 |
+
structure_probs = structure_probs.numpy()
|
469 |
+
if isinstance(loc_preds, paddle.Tensor):
|
470 |
+
loc_preds = loc_preds.numpy()
|
471 |
+
structure_idx = structure_probs.argmax(axis=2)
|
472 |
+
structure_probs = structure_probs.max(axis=2)
|
473 |
+
(
|
474 |
+
structure_str,
|
475 |
+
structure_pos,
|
476 |
+
result_score_list,
|
477 |
+
result_elem_idx_list,
|
478 |
+
) = self.decode(structure_idx, structure_probs, "elem")
|
479 |
+
res_html_code_list = []
|
480 |
+
res_loc_list = []
|
481 |
+
batch_num = len(structure_str)
|
482 |
+
for bno in range(batch_num):
|
483 |
+
res_loc = []
|
484 |
+
for sno in range(len(structure_str[bno])):
|
485 |
+
text = structure_str[bno][sno]
|
486 |
+
if text in ["<td>", "<td"]:
|
487 |
+
pos = structure_pos[bno][sno]
|
488 |
+
res_loc.append(loc_preds[bno, pos])
|
489 |
+
res_html_code = "".join(structure_str[bno])
|
490 |
+
res_loc = np.array(res_loc)
|
491 |
+
res_html_code_list.append(res_html_code)
|
492 |
+
res_loc_list.append(res_loc)
|
493 |
+
return {
|
494 |
+
"res_html_code": res_html_code_list,
|
495 |
+
"res_loc": res_loc_list,
|
496 |
+
"res_score_list": result_score_list,
|
497 |
+
"res_elem_idx_list": result_elem_idx_list,
|
498 |
+
"structure_str_list": structure_str,
|
499 |
+
}
|
500 |
+
|
501 |
+
def decode(self, text_index, structure_probs, char_or_elem):
|
502 |
+
"""convert text-label into text-index."""
|
503 |
+
if char_or_elem == "char":
|
504 |
+
current_dict = self.dict_idx_character
|
505 |
+
else:
|
506 |
+
current_dict = self.dict_idx_elem
|
507 |
+
ignored_tokens = self.get_ignored_tokens("elem")
|
508 |
+
beg_idx, end_idx = ignored_tokens
|
509 |
+
|
510 |
+
result_list = []
|
511 |
+
result_pos_list = []
|
512 |
+
result_score_list = []
|
513 |
+
result_elem_idx_list = []
|
514 |
+
batch_size = len(text_index)
|
515 |
+
for batch_idx in range(batch_size):
|
516 |
+
char_list = []
|
517 |
+
elem_pos_list = []
|
518 |
+
elem_idx_list = []
|
519 |
+
score_list = []
|
520 |
+
for idx in range(len(text_index[batch_idx])):
|
521 |
+
tmp_elem_idx = int(text_index[batch_idx][idx])
|
522 |
+
if idx > 0 and tmp_elem_idx == end_idx:
|
523 |
+
break
|
524 |
+
if tmp_elem_idx in ignored_tokens:
|
525 |
+
continue
|
526 |
+
|
527 |
+
char_list.append(current_dict[tmp_elem_idx])
|
528 |
+
elem_pos_list.append(idx)
|
529 |
+
score_list.append(structure_probs[batch_idx, idx])
|
530 |
+
elem_idx_list.append(tmp_elem_idx)
|
531 |
+
result_list.append(char_list)
|
532 |
+
result_pos_list.append(elem_pos_list)
|
533 |
+
result_score_list.append(score_list)
|
534 |
+
result_elem_idx_list.append(elem_idx_list)
|
535 |
+
return result_list, result_pos_list, result_score_list, result_elem_idx_list
|
536 |
+
|
537 |
+
def get_ignored_tokens(self, char_or_elem):
|
538 |
+
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
539 |
+
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
540 |
+
return [beg_idx, end_idx]
|
541 |
+
|
542 |
+
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
543 |
+
if char_or_elem == "char":
|
544 |
+
if beg_or_end == "beg":
|
545 |
+
idx = self.dict_character[self.beg_str]
|
546 |
+
elif beg_or_end == "end":
|
547 |
+
idx = self.dict_character[self.end_str]
|
548 |
+
else:
|
549 |
+
assert False, (
|
550 |
+
"Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
|
551 |
+
)
|
552 |
+
elif char_or_elem == "elem":
|
553 |
+
if beg_or_end == "beg":
|
554 |
+
idx = self.dict_elem[self.beg_str]
|
555 |
+
elif beg_or_end == "end":
|
556 |
+
idx = self.dict_elem[self.end_str]
|
557 |
+
else:
|
558 |
+
assert False, (
|
559 |
+
"Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
|
560 |
+
)
|
561 |
+
else:
|
562 |
+
assert False, "Unsupport type %s in char_or_elem" % char_or_elem
|
563 |
+
return idx
|
564 |
+
|
565 |
+
|
566 |
+
class SARLabelDecode(BaseRecLabelDecode):
|
567 |
+
"""Convert between text-label and text-index"""
|
568 |
+
|
569 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
570 |
+
super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)
|
571 |
+
|
572 |
+
self.rm_symbol = kwargs.get("rm_symbol", False)
|
573 |
+
|
574 |
+
def add_special_char(self, dict_character):
|
575 |
+
beg_end_str = "<BOS/EOS>"
|
576 |
+
unknown_str = "<UKN>"
|
577 |
+
padding_str = "<PAD>"
|
578 |
+
dict_character = dict_character + [unknown_str]
|
579 |
+
self.unknown_idx = len(dict_character) - 1
|
580 |
+
dict_character = dict_character + [beg_end_str]
|
581 |
+
self.start_idx = len(dict_character) - 1
|
582 |
+
self.end_idx = len(dict_character) - 1
|
583 |
+
dict_character = dict_character + [padding_str]
|
584 |
+
self.padding_idx = len(dict_character) - 1
|
585 |
+
return dict_character
|
586 |
+
|
587 |
+
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
588 |
+
"""convert text-index into text-label."""
|
589 |
+
result_list = []
|
590 |
+
ignored_tokens = self.get_ignored_tokens()
|
591 |
+
|
592 |
+
batch_size = len(text_index)
|
593 |
+
for batch_idx in range(batch_size):
|
594 |
+
char_list = []
|
595 |
+
conf_list = []
|
596 |
+
for idx in range(len(text_index[batch_idx])):
|
597 |
+
if text_index[batch_idx][idx] in ignored_tokens:
|
598 |
+
continue
|
599 |
+
if int(text_index[batch_idx][idx]) == int(self.end_idx):
|
600 |
+
if text_prob is None and idx == 0:
|
601 |
+
continue
|
602 |
+
else:
|
603 |
+
break
|
604 |
+
if is_remove_duplicate:
|
605 |
+
# only for predict
|
606 |
+
if (
|
607 |
+
idx > 0
|
608 |
+
and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
|
609 |
+
):
|
610 |
+
continue
|
611 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
612 |
+
if text_prob is not None:
|
613 |
+
conf_list.append(text_prob[batch_idx][idx])
|
614 |
+
else:
|
615 |
+
conf_list.append(1)
|
616 |
+
text = "".join(char_list)
|
617 |
+
if self.rm_symbol:
|
618 |
+
comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
|
619 |
+
text = text.lower()
|
620 |
+
text = comp.sub("", text)
|
621 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
622 |
+
return result_list
|
623 |
+
|
624 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
625 |
+
if isinstance(preds, paddle.Tensor):
|
626 |
+
preds = preds.numpy()
|
627 |
+
preds_idx = preds.argmax(axis=2)
|
628 |
+
preds_prob = preds.max(axis=2)
|
629 |
+
|
630 |
+
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
631 |
+
|
632 |
+
if label is None:
|
633 |
+
return text
|
634 |
+
label = self.decode(label, is_remove_duplicate=False)
|
635 |
+
return text, label
|
636 |
+
|
637 |
+
def get_ignored_tokens(self):
|
638 |
+
return [self.padding_idx]
|
639 |
+
|
640 |
+
|
641 |
+
class DistillationSARLabelDecode(SARLabelDecode):
|
642 |
+
"""
|
643 |
+
Convert
|
644 |
+
Convert between text-label and text-index
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(
|
648 |
+
self,
|
649 |
+
character_dict_path=None,
|
650 |
+
use_space_char=False,
|
651 |
+
model_name=["student"],
|
652 |
+
key=None,
|
653 |
+
multi_head=False,
|
654 |
+
**kwargs
|
655 |
+
):
|
656 |
+
super(DistillationSARLabelDecode, self).__init__(
|
657 |
+
character_dict_path, use_space_char
|
658 |
+
)
|
659 |
+
if not isinstance(model_name, list):
|
660 |
+
model_name = [model_name]
|
661 |
+
self.model_name = model_name
|
662 |
+
|
663 |
+
self.key = key
|
664 |
+
self.multi_head = multi_head
|
665 |
+
|
666 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
667 |
+
output = dict()
|
668 |
+
for name in self.model_name:
|
669 |
+
pred = preds[name]
|
670 |
+
if self.key is not None:
|
671 |
+
pred = pred[self.key]
|
672 |
+
if self.multi_head and isinstance(pred, dict):
|
673 |
+
pred = pred["sar"]
|
674 |
+
output[name] = super().__call__(pred, label=label, *args, **kwargs)
|
675 |
+
return output
|
676 |
+
|
677 |
+
|
678 |
+
class PRENLabelDecode(BaseRecLabelDecode):
|
679 |
+
"""Convert between text-label and text-index"""
|
680 |
+
|
681 |
+
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
682 |
+
super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)
|
683 |
+
|
684 |
+
def add_special_char(self, dict_character):
|
685 |
+
padding_str = "<PAD>" # 0
|
686 |
+
end_str = "<EOS>" # 1
|
687 |
+
unknown_str = "<UNK>" # 2
|
688 |
+
|
689 |
+
dict_character = [padding_str, end_str, unknown_str] + dict_character
|
690 |
+
self.padding_idx = 0
|
691 |
+
self.end_idx = 1
|
692 |
+
self.unknown_idx = 2
|
693 |
+
|
694 |
+
return dict_character
|
695 |
+
|
696 |
+
def decode(self, text_index, text_prob=None):
|
697 |
+
"""convert text-index into text-label."""
|
698 |
+
result_list = []
|
699 |
+
batch_size = len(text_index)
|
700 |
+
|
701 |
+
for batch_idx in range(batch_size):
|
702 |
+
char_list = []
|
703 |
+
conf_list = []
|
704 |
+
for idx in range(len(text_index[batch_idx])):
|
705 |
+
if text_index[batch_idx][idx] == self.end_idx:
|
706 |
+
break
|
707 |
+
if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
|
708 |
+
continue
|
709 |
+
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
710 |
+
if text_prob is not None:
|
711 |
+
conf_list.append(text_prob[batch_idx][idx])
|
712 |
+
else:
|
713 |
+
conf_list.append(1)
|
714 |
+
|
715 |
+
text = "".join(char_list)
|
716 |
+
if len(text) > 0:
|
717 |
+
result_list.append((text, np.mean(conf_list).tolist()))
|
718 |
+
else:
|
719 |
+
# here confidence of empty recog result is 1
|
720 |
+
result_list.append(("", 1))
|
721 |
+
return result_list
|
722 |
+
|
723 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
724 |
+
preds = preds.numpy()
|
725 |
+
preds_idx = preds.argmax(axis=2)
|
726 |
+
preds_prob = preds.max(axis=2)
|
727 |
+
text = self.decode(preds_idx, preds_prob)
|
728 |
+
if label is None:
|
729 |
+
return text
|
730 |
+
label = self.decode(label)
|
731 |
+
return text, label
|
ocr/postprocess/sast_postprocess.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
__dir__ = os.path.dirname(__file__)
|
7 |
+
sys.path.append(__dir__)
|
8 |
+
sys.path.append(os.path.join(__dir__, ".."))
|
9 |
+
|
10 |
+
import time
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import paddle
|
15 |
+
|
16 |
+
from .locality_aware_nms import nms_locality
|
17 |
+
|
18 |
+
|
19 |
+
class SASTPostProcess(object):
|
20 |
+
"""
|
21 |
+
The post process for SAST.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
score_thresh=0.5,
|
27 |
+
nms_thresh=0.2,
|
28 |
+
sample_pts_num=2,
|
29 |
+
shrink_ratio_of_width=0.3,
|
30 |
+
expand_scale=1.0,
|
31 |
+
tcl_map_thresh=0.5,
|
32 |
+
**kwargs
|
33 |
+
):
|
34 |
+
|
35 |
+
self.score_thresh = score_thresh
|
36 |
+
self.nms_thresh = nms_thresh
|
37 |
+
self.sample_pts_num = sample_pts_num
|
38 |
+
self.shrink_ratio_of_width = shrink_ratio_of_width
|
39 |
+
self.expand_scale = expand_scale
|
40 |
+
self.tcl_map_thresh = tcl_map_thresh
|
41 |
+
|
42 |
+
# c++ la-nms is faster, but only support python 3.5
|
43 |
+
self.is_python35 = False
|
44 |
+
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
45 |
+
self.is_python35 = True
|
46 |
+
|
47 |
+
def point_pair2poly(self, point_pair_list):
|
48 |
+
"""
|
49 |
+
Transfer vertical point_pairs into poly point in clockwise.
|
50 |
+
"""
|
51 |
+
# constract poly
|
52 |
+
point_num = len(point_pair_list) * 2
|
53 |
+
point_list = [0] * point_num
|
54 |
+
for idx, point_pair in enumerate(point_pair_list):
|
55 |
+
point_list[idx] = point_pair[0]
|
56 |
+
point_list[point_num - 1 - idx] = point_pair[1]
|
57 |
+
return np.array(point_list).reshape(-1, 2)
|
58 |
+
|
59 |
+
def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
|
60 |
+
"""
|
61 |
+
Generate shrink_quad_along_width.
|
62 |
+
"""
|
63 |
+
ratio_pair = np.array(
|
64 |
+
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32
|
65 |
+
)
|
66 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
67 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
68 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
69 |
+
|
70 |
+
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
71 |
+
"""
|
72 |
+
expand poly along width.
|
73 |
+
"""
|
74 |
+
point_num = poly.shape[0]
|
75 |
+
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
76 |
+
left_ratio = (
|
77 |
+
-shrink_ratio_of_width
|
78 |
+
* np.linalg.norm(left_quad[0] - left_quad[3])
|
79 |
+
/ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
80 |
+
)
|
81 |
+
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
82 |
+
right_quad = np.array(
|
83 |
+
[
|
84 |
+
poly[point_num // 2 - 2],
|
85 |
+
poly[point_num // 2 - 1],
|
86 |
+
poly[point_num // 2],
|
87 |
+
poly[point_num // 2 + 1],
|
88 |
+
],
|
89 |
+
dtype=np.float32,
|
90 |
+
)
|
91 |
+
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
|
92 |
+
right_quad[0] - right_quad[3]
|
93 |
+
) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
94 |
+
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
95 |
+
poly[0] = left_quad_expand[0]
|
96 |
+
poly[-1] = left_quad_expand[-1]
|
97 |
+
poly[point_num // 2 - 1] = right_quad_expand[1]
|
98 |
+
poly[point_num // 2] = right_quad_expand[2]
|
99 |
+
return poly
|
100 |
+
|
101 |
+
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
102 |
+
"""Restore quad."""
|
103 |
+
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
104 |
+
xy_text = xy_text[:, ::-1] # (n, 2)
|
105 |
+
|
106 |
+
# Sort the text boxes via the y axis
|
107 |
+
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
108 |
+
|
109 |
+
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
110 |
+
scores = scores[:, np.newaxis]
|
111 |
+
|
112 |
+
# Restore
|
113 |
+
point_num = int(tvo_map.shape[-1] / 2)
|
114 |
+
assert point_num == 4
|
115 |
+
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
116 |
+
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
117 |
+
quads = xy_text_tile - tvo_map
|
118 |
+
|
119 |
+
return scores, quads, xy_text
|
120 |
+
|
121 |
+
def quad_area(self, quad):
|
122 |
+
"""
|
123 |
+
compute area of a quad.
|
124 |
+
"""
|
125 |
+
edge = [
|
126 |
+
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
127 |
+
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
128 |
+
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
129 |
+
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
|
130 |
+
]
|
131 |
+
return np.sum(edge) / 2.0
|
132 |
+
|
133 |
+
def nms(self, dets):
|
134 |
+
if self.is_python35:
|
135 |
+
import lanms
|
136 |
+
|
137 |
+
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
|
138 |
+
else:
|
139 |
+
dets = nms_locality(dets, self.nms_thresh)
|
140 |
+
return dets
|
141 |
+
|
142 |
+
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
|
143 |
+
"""
|
144 |
+
Cluster pixels in tcl_map based on quads.
|
145 |
+
"""
|
146 |
+
instance_count = quads.shape[0] + 1 # contain background
|
147 |
+
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
148 |
+
if instance_count == 1:
|
149 |
+
return instance_count, instance_label_map
|
150 |
+
|
151 |
+
# predict text center
|
152 |
+
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
153 |
+
n = xy_text.shape[0]
|
154 |
+
xy_text = xy_text[:, ::-1] # (n, 2)
|
155 |
+
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
156 |
+
pred_tc = xy_text - tco
|
157 |
+
|
158 |
+
# get gt text center
|
159 |
+
m = quads.shape[0]
|
160 |
+
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
161 |
+
|
162 |
+
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
163 |
+
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
164 |
+
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
165 |
+
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
166 |
+
|
167 |
+
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
168 |
+
return instance_count, instance_label_map
|
169 |
+
|
170 |
+
def estimate_sample_pts_num(self, quad, xy_text):
|
171 |
+
"""
|
172 |
+
Estimate sample points number.
|
173 |
+
"""
|
174 |
+
eh = (
|
175 |
+
np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
176 |
+
) / 2.0
|
177 |
+
ew = (
|
178 |
+
np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
|
179 |
+
) / 2.0
|
180 |
+
|
181 |
+
dense_sample_pts_num = max(2, int(ew))
|
182 |
+
dense_xy_center_line = xy_text[
|
183 |
+
np.linspace(
|
184 |
+
0,
|
185 |
+
xy_text.shape[0] - 1,
|
186 |
+
dense_sample_pts_num,
|
187 |
+
endpoint=True,
|
188 |
+
dtype=np.float32,
|
189 |
+
).astype(np.int32)
|
190 |
+
]
|
191 |
+
|
192 |
+
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
193 |
+
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
194 |
+
|
195 |
+
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
196 |
+
return sample_pts_num
|
197 |
+
|
198 |
+
def detect_sast(
|
199 |
+
self,
|
200 |
+
tcl_map,
|
201 |
+
tvo_map,
|
202 |
+
tbo_map,
|
203 |
+
tco_map,
|
204 |
+
ratio_w,
|
205 |
+
ratio_h,
|
206 |
+
src_w,
|
207 |
+
src_h,
|
208 |
+
shrink_ratio_of_width=0.3,
|
209 |
+
tcl_map_thresh=0.5,
|
210 |
+
offset_expand=1.0,
|
211 |
+
out_strid=4.0,
|
212 |
+
):
|
213 |
+
"""
|
214 |
+
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
215 |
+
"""
|
216 |
+
# restore quad
|
217 |
+
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
218 |
+
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
219 |
+
dets = self.nms(dets)
|
220 |
+
if dets.shape[0] == 0:
|
221 |
+
return []
|
222 |
+
quads = dets[:, :-1].reshape(-1, 4, 2)
|
223 |
+
|
224 |
+
# Compute quad area
|
225 |
+
quad_areas = []
|
226 |
+
for quad in quads:
|
227 |
+
quad_areas.append(-self.quad_area(quad))
|
228 |
+
|
229 |
+
# instance segmentation
|
230 |
+
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
231 |
+
instance_count, instance_label_map = self.cluster_by_quads_tco(
|
232 |
+
tcl_map, tcl_map_thresh, quads, tco_map
|
233 |
+
)
|
234 |
+
|
235 |
+
# restore single poly with tcl instance.
|
236 |
+
poly_list = []
|
237 |
+
for instance_idx in range(1, instance_count):
|
238 |
+
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
|
239 |
+
quad = quads[instance_idx - 1]
|
240 |
+
q_area = quad_areas[instance_idx - 1]
|
241 |
+
if q_area < 5:
|
242 |
+
continue
|
243 |
+
|
244 |
+
#
|
245 |
+
len1 = float(np.linalg.norm(quad[0] - quad[1]))
|
246 |
+
len2 = float(np.linalg.norm(quad[1] - quad[2]))
|
247 |
+
min_len = min(len1, len2)
|
248 |
+
if min_len < 3:
|
249 |
+
continue
|
250 |
+
|
251 |
+
# filter small CC
|
252 |
+
if xy_text.shape[0] <= 0:
|
253 |
+
continue
|
254 |
+
|
255 |
+
# filter low confidence instance
|
256 |
+
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
257 |
+
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
258 |
+
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
259 |
+
continue
|
260 |
+
|
261 |
+
# sort xy_text
|
262 |
+
left_center_pt = np.array(
|
263 |
+
[[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]
|
264 |
+
) # (1, 2)
|
265 |
+
right_center_pt = np.array(
|
266 |
+
[[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]]
|
267 |
+
) # (1, 2)
|
268 |
+
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
269 |
+
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
|
270 |
+
)
|
271 |
+
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
272 |
+
xy_text = xy_text[np.argsort(proj_value)]
|
273 |
+
|
274 |
+
# Sample pts in tcl map
|
275 |
+
if self.sample_pts_num == 0:
|
276 |
+
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
277 |
+
else:
|
278 |
+
sample_pts_num = self.sample_pts_num
|
279 |
+
xy_center_line = xy_text[
|
280 |
+
np.linspace(
|
281 |
+
0,
|
282 |
+
xy_text.shape[0] - 1,
|
283 |
+
sample_pts_num,
|
284 |
+
endpoint=True,
|
285 |
+
dtype=np.float32,
|
286 |
+
).astype(np.int32)
|
287 |
+
]
|
288 |
+
|
289 |
+
point_pair_list = []
|
290 |
+
for x, y in xy_center_line:
|
291 |
+
# get corresponding offset
|
292 |
+
offset = tbo_map[y, x, :].reshape(2, 2)
|
293 |
+
if offset_expand != 1.0:
|
294 |
+
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
295 |
+
expand_length = np.clip(
|
296 |
+
offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
|
297 |
+
)
|
298 |
+
offset_detal = offset / offset_length * expand_length
|
299 |
+
offset = offset + offset_detal
|
300 |
+
# original point
|
301 |
+
ori_yx = np.array([y, x], dtype=np.float32)
|
302 |
+
point_pair = (
|
303 |
+
(ori_yx + offset)[:, ::-1]
|
304 |
+
* out_strid
|
305 |
+
/ np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
306 |
+
)
|
307 |
+
point_pair_list.append(point_pair)
|
308 |
+
|
309 |
+
# ndarry: (x, 2), expand poly along width
|
310 |
+
detected_poly = self.point_pair2poly(point_pair_list)
|
311 |
+
detected_poly = self.expand_poly_along_width(
|
312 |
+
detected_poly, shrink_ratio_of_width
|
313 |
+
)
|
314 |
+
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
315 |
+
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
316 |
+
poly_list.append(detected_poly)
|
317 |
+
|
318 |
+
return poly_list
|
319 |
+
|
320 |
+
def __call__(self, outs_dict, shape_list):
|
321 |
+
score_list = outs_dict["f_score"]
|
322 |
+
border_list = outs_dict["f_border"]
|
323 |
+
tvo_list = outs_dict["f_tvo"]
|
324 |
+
tco_list = outs_dict["f_tco"]
|
325 |
+
if isinstance(score_list, paddle.Tensor):
|
326 |
+
score_list = score_list.numpy()
|
327 |
+
border_list = border_list.numpy()
|
328 |
+
tvo_list = tvo_list.numpy()
|
329 |
+
tco_list = tco_list.numpy()
|
330 |
+
|
331 |
+
img_num = len(shape_list)
|
332 |
+
poly_lists = []
|
333 |
+
for ino in range(img_num):
|
334 |
+
p_score = score_list[ino].transpose((1, 2, 0))
|
335 |
+
p_border = border_list[ino].transpose((1, 2, 0))
|
336 |
+
p_tvo = tvo_list[ino].transpose((1, 2, 0))
|
337 |
+
p_tco = tco_list[ino].transpose((1, 2, 0))
|
338 |
+
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
339 |
+
|
340 |
+
poly_list = self.detect_sast(
|
341 |
+
p_score,
|
342 |
+
p_tvo,
|
343 |
+
p_border,
|
344 |
+
p_tco,
|
345 |
+
ratio_w,
|
346 |
+
ratio_h,
|
347 |
+
src_w,
|
348 |
+
src_h,
|
349 |
+
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
350 |
+
tcl_map_thresh=self.tcl_map_thresh,
|
351 |
+
offset_expand=self.expand_scale,
|
352 |
+
)
|
353 |
+
poly_lists.append({"points": np.array(poly_list)})
|
354 |
+
|
355 |
+
return poly_lists
|
ocr/postprocess/vqa_token_re_layoutlm_postprocess.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class VQAReTokenLayoutLMPostProcess(object):
|
2 |
+
"""Convert between text-label and text-index"""
|
3 |
+
|
4 |
+
def __init__(self, **kwargs):
|
5 |
+
super(VQAReTokenLayoutLMPostProcess, self).__init__()
|
6 |
+
|
7 |
+
def __call__(self, preds, label=None, *args, **kwargs):
|
8 |
+
if label is not None:
|
9 |
+
return self._metric(preds, label)
|
10 |
+
else:
|
11 |
+
return self._infer(preds, *args, **kwargs)
|
12 |
+
|
13 |
+
def _metric(self, preds, label):
|
14 |
+
return preds["pred_relations"], label[6], label[5]
|
15 |
+
|
16 |
+
def _infer(self, preds, *args, **kwargs):
|
17 |
+
ser_results = kwargs["ser_results"]
|
18 |
+
entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]
|
19 |
+
pred_relations = preds["pred_relations"]
|
20 |
+
|
21 |
+
# merge relations and ocr info
|
22 |
+
results = []
|
23 |
+
for pred_relation, ser_result, entity_idx_dict in zip(
|
24 |
+
pred_relations, ser_results, entity_idx_dict_batch
|
25 |
+
):
|
26 |
+
result = []
|
27 |
+
used_tail_id = []
|
28 |
+
for relation in pred_relation:
|
29 |
+
if relation["tail_id"] in used_tail_id:
|
30 |
+
continue
|
31 |
+
used_tail_id.append(relation["tail_id"])
|
32 |
+
ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
|
33 |
+
ocr_info_tail = ser_result[entity_idx_dict[relation["tail_id"]]]
|
34 |
+
result.append((ocr_info_head, ocr_info_tail))
|
35 |
+
results.append(result)
|
36 |
+
return results
|
ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import paddle
|
3 |
+
|
4 |
+
|
5 |
+
def load_vqa_bio_label_maps(label_map_path):
|
6 |
+
with open(label_map_path, "r", encoding="utf-8") as fin:
|
7 |
+
lines = fin.readlines()
|
8 |
+
lines = [line.strip() for line in lines]
|
9 |
+
if "O" not in lines:
|
10 |
+
lines.insert(0, "O")
|
11 |
+
labels = []
|
12 |
+
for line in lines:
|
13 |
+
if line == "O":
|
14 |
+
labels.append("O")
|
15 |
+
else:
|
16 |
+
labels.append("B-" + line)
|
17 |
+
labels.append("I-" + line)
|
18 |
+
label2id_map = {label: idx for idx, label in enumerate(labels)}
|
19 |
+
id2label_map = {idx: label for idx, label in enumerate(labels)}
|
20 |
+
return label2id_map, id2label_map
|
21 |
+
|
22 |
+
|
23 |
+
class VQASerTokenLayoutLMPostProcess(object):
|
24 |
+
"""Convert between text-label and text-index"""
|
25 |
+
|
26 |
+
def __init__(self, class_path, **kwargs):
|
27 |
+
super(VQASerTokenLayoutLMPostProcess, self).__init__()
|
28 |
+
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
|
29 |
+
|
30 |
+
self.label2id_map_for_draw = dict()
|
31 |
+
for key in label2id_map:
|
32 |
+
if key.startswith("I-"):
|
33 |
+
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
|
34 |
+
else:
|
35 |
+
self.label2id_map_for_draw[key] = label2id_map[key]
|
36 |
+
|
37 |
+
self.id2label_map_for_show = dict()
|
38 |
+
for key in self.label2id_map_for_draw:
|
39 |
+
val = self.label2id_map_for_draw[key]
|
40 |
+
if key == "O":
|
41 |
+
self.id2label_map_for_show[val] = key
|
42 |
+
if key.startswith("B-") or key.startswith("I-"):
|
43 |
+
self.id2label_map_for_show[val] = key[2:]
|
44 |
+
else:
|
45 |
+
self.id2label_map_for_show[val] = key
|
46 |
+
|
47 |
+
def __call__(self, preds, batch=None, *args, **kwargs):
|
48 |
+
if isinstance(preds, paddle.Tensor):
|
49 |
+
preds = preds.numpy()
|
50 |
+
|
51 |
+
if batch is not None:
|
52 |
+
return self._metric(preds, batch[1])
|
53 |
+
else:
|
54 |
+
return self._infer(preds, **kwargs)
|
55 |
+
|
56 |
+
def _metric(self, preds, label):
|
57 |
+
pred_idxs = preds.argmax(axis=2)
|
58 |
+
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
|
59 |
+
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
|
60 |
+
|
61 |
+
for i in range(pred_idxs.shape[0]):
|
62 |
+
for j in range(pred_idxs.shape[1]):
|
63 |
+
if label[i, j] != -100:
|
64 |
+
label_decode_out_list[i].append(self.id2label_map[label[i, j]])
|
65 |
+
decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
|
66 |
+
return decode_out_list, label_decode_out_list
|
67 |
+
|
68 |
+
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
|
69 |
+
results = []
|
70 |
+
|
71 |
+
for pred, attention_mask, segment_offset_id, ocr_info in zip(
|
72 |
+
preds, attention_masks, segment_offset_ids, ocr_infos
|
73 |
+
):
|
74 |
+
pred = np.argmax(pred, axis=1)
|
75 |
+
pred = [self.id2label_map[idx] for idx in pred]
|
76 |
+
|
77 |
+
for idx in range(len(segment_offset_id)):
|
78 |
+
if idx == 0:
|
79 |
+
start_id = 0
|
80 |
+
else:
|
81 |
+
start_id = segment_offset_id[idx - 1]
|
82 |
+
|
83 |
+
end_id = segment_offset_id[idx]
|
84 |
+
|
85 |
+
curr_pred = pred[start_id:end_id]
|
86 |
+
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
|
87 |
+
|
88 |
+
if len(curr_pred) <= 0:
|
89 |
+
pred_id = 0
|
90 |
+
else:
|
91 |
+
counts = np.bincount(curr_pred)
|
92 |
+
pred_id = np.argmax(counts)
|
93 |
+
ocr_info[idx]["pred_id"] = int(pred_id)
|
94 |
+
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
|
95 |
+
results.append(ocr_info)
|
96 |
+
return results
|
ocr/ppocr/__init__.py
ADDED
File without changes
|
ocr/ppocr/data/__init__.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import os
|
4 |
+
import signal
|
5 |
+
import sys
|
6 |
+
|
7 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
8 |
+
sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))
|
9 |
+
|
10 |
+
import copy
|
11 |
+
|
12 |
+
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
|
13 |
+
|
14 |
+
from .imaug import create_operators, transform
|
15 |
+
|
16 |
+
__all__ = ["build_dataloader", "transform", "create_operators"]
|
17 |
+
|
18 |
+
|
19 |
+
def term_mp(sig_num, frame):
|
20 |
+
"""kill all child processes"""
|
21 |
+
pid = os.getpid()
|
22 |
+
pgid = os.getpgid(os.getpid())
|
23 |
+
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
|
24 |
+
os.killpg(pgid, signal.SIGKILL)
|
25 |
+
|
26 |
+
|
27 |
+
def build_dataloader(config, mode, device, logger, seed=None):
|
28 |
+
config = copy.deepcopy(config)
|
29 |
+
|
30 |
+
support_dict = ["SimpleDataSet", "LMDBDataSet", "PGDataSet", "PubTabDataSet"]
|
31 |
+
module_name = config[mode]["dataset"]["name"]
|
32 |
+
assert module_name in support_dict, Exception(
|
33 |
+
"DataSet only support {}".format(support_dict)
|
34 |
+
)
|
35 |
+
assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
|
36 |
+
|
37 |
+
dataset = eval(module_name)(config, mode, logger, seed)
|
38 |
+
loader_config = config[mode]["loader"]
|
39 |
+
batch_size = loader_config["batch_size_per_card"]
|
40 |
+
drop_last = loader_config["drop_last"]
|
41 |
+
shuffle = loader_config["shuffle"]
|
42 |
+
num_workers = loader_config["num_workers"]
|
43 |
+
if "use_shared_memory" in loader_config.keys():
|
44 |
+
use_shared_memory = loader_config["use_shared_memory"]
|
45 |
+
else:
|
46 |
+
use_shared_memory = True
|
47 |
+
|
48 |
+
if mode == "Train":
|
49 |
+
# Distribute data to multiple cards
|
50 |
+
batch_sampler = DistributedBatchSampler(
|
51 |
+
dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
# Distribute data to single card
|
55 |
+
batch_sampler = BatchSampler(
|
56 |
+
dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
|
57 |
+
)
|
58 |
+
|
59 |
+
if "collate_fn" in loader_config:
|
60 |
+
from . import collate_fn
|
61 |
+
|
62 |
+
collate_fn = getattr(collate_fn, loader_config["collate_fn"])()
|
63 |
+
else:
|
64 |
+
collate_fn = None
|
65 |
+
data_loader = DataLoader(
|
66 |
+
dataset=dataset,
|
67 |
+
batch_sampler=batch_sampler,
|
68 |
+
places=device,
|
69 |
+
num_workers=num_workers,
|
70 |
+
return_list=True,
|
71 |
+
use_shared_memory=use_shared_memory,
|
72 |
+
collate_fn=collate_fn,
|
73 |
+
)
|
74 |
+
|
75 |
+
# support exit using ctrl+c
|
76 |
+
signal.signal(signal.SIGINT, term_mp)
|
77 |
+
signal.signal(signal.SIGTERM, term_mp)
|
78 |
+
|
79 |
+
return data_loader
|
ocr/ppocr/data/collate_fn.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import paddle
|
6 |
+
|
7 |
+
|
8 |
+
class DictCollator(object):
|
9 |
+
"""
|
10 |
+
data batch
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __call__(self, batch):
|
14 |
+
# todo:support batch operators
|
15 |
+
data_dict = defaultdict(list)
|
16 |
+
to_tensor_keys = []
|
17 |
+
for sample in batch:
|
18 |
+
for k, v in sample.items():
|
19 |
+
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
20 |
+
if k not in to_tensor_keys:
|
21 |
+
to_tensor_keys.append(k)
|
22 |
+
data_dict[k].append(v)
|
23 |
+
for k in to_tensor_keys:
|
24 |
+
data_dict[k] = paddle.to_tensor(data_dict[k])
|
25 |
+
return data_dict
|
26 |
+
|
27 |
+
|
28 |
+
class ListCollator(object):
|
29 |
+
"""
|
30 |
+
data batch
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __call__(self, batch):
|
34 |
+
# todo:support batch operators
|
35 |
+
data_dict = defaultdict(list)
|
36 |
+
to_tensor_idxs = []
|
37 |
+
for sample in batch:
|
38 |
+
for idx, v in enumerate(sample):
|
39 |
+
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
40 |
+
if idx not in to_tensor_idxs:
|
41 |
+
to_tensor_idxs.append(idx)
|
42 |
+
data_dict[idx].append(v)
|
43 |
+
for idx in to_tensor_idxs:
|
44 |
+
data_dict[idx] = paddle.to_tensor(data_dict[idx])
|
45 |
+
return list(data_dict.values())
|
46 |
+
|
47 |
+
|
48 |
+
class SSLRotateCollate(object):
|
49 |
+
"""
|
50 |
+
bach: [
|
51 |
+
[(4*3xH*W), (4,)]
|
52 |
+
[(4*3xH*W), (4,)]
|
53 |
+
...
|
54 |
+
]
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __call__(self, batch):
|
58 |
+
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
|
59 |
+
return output
|
ocr/ppocr/data/imaug/ColorJitter.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
|
2 |
+
|
3 |
+
__all__ = ["ColorJitter"]
|
4 |
+
|
5 |
+
|
6 |
+
class ColorJitter(object):
|
7 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
|
8 |
+
self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
|
9 |
+
|
10 |
+
def __call__(self, data):
|
11 |
+
image = data["image"]
|
12 |
+
image = self.aug(image)
|
13 |
+
data["image"] = image
|
14 |
+
return data
|
ocr/ppocr/data/imaug/__init__.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
from .ColorJitter import ColorJitter
|
4 |
+
from .copy_paste import CopyPaste
|
5 |
+
from .east_process import *
|
6 |
+
from .fce_aug import *
|
7 |
+
from .fce_targets import FCENetTargets
|
8 |
+
from .gen_table_mask import *
|
9 |
+
from .iaa_augment import IaaAugment
|
10 |
+
from .label_ops import *
|
11 |
+
from .make_border_map import MakeBorderMap
|
12 |
+
from .make_pse_gt import MakePseGt
|
13 |
+
from .make_shrink_map import MakeShrinkMap
|
14 |
+
from .operators import *
|
15 |
+
from .pg_process import *
|
16 |
+
from .randaugment import RandAugment
|
17 |
+
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
18 |
+
from .rec_img_aug import (
|
19 |
+
ClsResizeImg,
|
20 |
+
NRTRRecResizeImg,
|
21 |
+
PRENResizeImg,
|
22 |
+
RecAug,
|
23 |
+
RecConAug,
|
24 |
+
RecResizeImg,
|
25 |
+
SARRecResizeImg,
|
26 |
+
SRNRecResizeImg,
|
27 |
+
)
|
28 |
+
from .sast_process import *
|
29 |
+
from .ssl_img_aug import SSLRotateResize
|
30 |
+
from .vqa import *
|
31 |
+
|
32 |
+
|
33 |
+
def transform(data, ops=None):
|
34 |
+
"""transform"""
|
35 |
+
if ops is None:
|
36 |
+
ops = []
|
37 |
+
for op in ops:
|
38 |
+
data = op(data)
|
39 |
+
if data is None:
|
40 |
+
return None
|
41 |
+
return data
|
42 |
+
|
43 |
+
|
44 |
+
def create_operators(op_param_list, global_config=None):
|
45 |
+
"""
|
46 |
+
create operators based on the config
|
47 |
+
|
48 |
+
Args:
|
49 |
+
params(list): a dict list, used to create some operators
|
50 |
+
"""
|
51 |
+
assert isinstance(op_param_list, list), "operator config should be a list"
|
52 |
+
ops = []
|
53 |
+
for operator in op_param_list:
|
54 |
+
assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
|
55 |
+
op_name = list(operator)[0]
|
56 |
+
param = {} if operator[op_name] is None else operator[op_name]
|
57 |
+
if global_config is not None:
|
58 |
+
param.update(global_config)
|
59 |
+
op = eval(op_name)(**param)
|
60 |
+
ops.append(op)
|
61 |
+
return ops
|
ocr/ppocr/data/imaug/copy_paste.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from shapely.geometry import Polygon
|
8 |
+
|
9 |
+
from ppocr.data.imaug.iaa_augment import IaaAugment
|
10 |
+
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
|
11 |
+
from utility import get_rotate_crop_image
|
12 |
+
|
13 |
+
|
14 |
+
class CopyPaste(object):
|
15 |
+
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
|
16 |
+
self.ext_data_num = 1
|
17 |
+
self.objects_paste_ratio = objects_paste_ratio
|
18 |
+
self.limit_paste = limit_paste
|
19 |
+
augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}]
|
20 |
+
self.aug = IaaAugment(augmenter_args)
|
21 |
+
|
22 |
+
def __call__(self, data):
|
23 |
+
point_num = data["polys"].shape[1]
|
24 |
+
src_img = data["image"]
|
25 |
+
src_polys = data["polys"].tolist()
|
26 |
+
src_texts = data["texts"]
|
27 |
+
src_ignores = data["ignore_tags"].tolist()
|
28 |
+
ext_data = data["ext_data"][0]
|
29 |
+
ext_image = ext_data["image"]
|
30 |
+
ext_polys = ext_data["polys"]
|
31 |
+
ext_texts = ext_data["texts"]
|
32 |
+
ext_ignores = ext_data["ignore_tags"]
|
33 |
+
|
34 |
+
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
35 |
+
select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
|
36 |
+
|
37 |
+
random.shuffle(indexs)
|
38 |
+
select_idxs = indexs[:select_num]
|
39 |
+
select_polys = ext_polys[select_idxs]
|
40 |
+
select_ignores = ext_ignores[select_idxs]
|
41 |
+
|
42 |
+
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
43 |
+
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
44 |
+
src_img = Image.fromarray(src_img).convert("RGBA")
|
45 |
+
for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
|
46 |
+
box_img = get_rotate_crop_image(ext_image, poly)
|
47 |
+
|
48 |
+
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
49 |
+
if box is not None:
|
50 |
+
box = box.tolist()
|
51 |
+
for _ in range(len(box), point_num):
|
52 |
+
box.append(box[-1])
|
53 |
+
src_polys.append(box)
|
54 |
+
src_texts.append(ext_texts[idx])
|
55 |
+
src_ignores.append(tag)
|
56 |
+
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
57 |
+
h, w = src_img.shape[:2]
|
58 |
+
src_polys = np.array(src_polys)
|
59 |
+
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
|
60 |
+
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
61 |
+
data["image"] = src_img
|
62 |
+
data["polys"] = src_polys
|
63 |
+
data["texts"] = src_texts
|
64 |
+
data["ignore_tags"] = np.array(src_ignores)
|
65 |
+
return data
|
66 |
+
|
67 |
+
def paste_img(self, src_img, box_img, src_polys):
|
68 |
+
box_img_pil = Image.fromarray(box_img).convert("RGBA")
|
69 |
+
src_w, src_h = src_img.size
|
70 |
+
box_w, box_h = box_img_pil.size
|
71 |
+
|
72 |
+
angle = np.random.randint(0, 360)
|
73 |
+
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
|
74 |
+
box = rotate_bbox(box_img, box, angle)[0]
|
75 |
+
box_img_pil = box_img_pil.rotate(angle, expand=1)
|
76 |
+
box_w, box_h = box_img_pil.width, box_img_pil.height
|
77 |
+
if src_w - box_w < 0 or src_h - box_h < 0:
|
78 |
+
return src_img, None
|
79 |
+
|
80 |
+
paste_x, paste_y = self.select_coord(
|
81 |
+
src_polys, box, src_w - box_w, src_h - box_h
|
82 |
+
)
|
83 |
+
if paste_x is None:
|
84 |
+
return src_img, None
|
85 |
+
box[:, 0] += paste_x
|
86 |
+
box[:, 1] += paste_y
|
87 |
+
r, g, b, A = box_img_pil.split()
|
88 |
+
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
|
89 |
+
|
90 |
+
return src_img, box
|
91 |
+
|
92 |
+
def select_coord(self, src_polys, box, endx, endy):
|
93 |
+
if self.limit_paste:
|
94 |
+
xmin, ymin, xmax, ymax = (
|
95 |
+
box[:, 0].min(),
|
96 |
+
box[:, 1].min(),
|
97 |
+
box[:, 0].max(),
|
98 |
+
box[:, 1].max(),
|
99 |
+
)
|
100 |
+
for _ in range(50):
|
101 |
+
paste_x = random.randint(0, endx)
|
102 |
+
paste_y = random.randint(0, endy)
|
103 |
+
xmin1 = xmin + paste_x
|
104 |
+
xmax1 = xmax + paste_x
|
105 |
+
ymin1 = ymin + paste_y
|
106 |
+
ymax1 = ymax + paste_y
|
107 |
+
|
108 |
+
num_poly_in_rect = 0
|
109 |
+
for poly in src_polys:
|
110 |
+
if not is_poly_outside_rect(
|
111 |
+
poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1
|
112 |
+
):
|
113 |
+
num_poly_in_rect += 1
|
114 |
+
break
|
115 |
+
if num_poly_in_rect == 0:
|
116 |
+
return paste_x, paste_y
|
117 |
+
return None, None
|
118 |
+
else:
|
119 |
+
paste_x = random.randint(0, endx)
|
120 |
+
paste_y = random.randint(0, endy)
|
121 |
+
return paste_x, paste_y
|
122 |
+
|
123 |
+
|
124 |
+
def get_union(pD, pG):
|
125 |
+
return Polygon(pD).union(Polygon(pG)).area
|
126 |
+
|
127 |
+
|
128 |
+
def get_intersection_over_union(pD, pG):
|
129 |
+
return get_intersection(pD, pG) / get_union(pD, pG)
|
130 |
+
|
131 |
+
|
132 |
+
def get_intersection(pD, pG):
|
133 |
+
return Polygon(pD).intersection(Polygon(pG)).area
|
134 |
+
|
135 |
+
|
136 |
+
def rotate_bbox(img, text_polys, angle, scale=1):
|
137 |
+
"""
|
138 |
+
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
|
139 |
+
Args:
|
140 |
+
img: np.ndarray
|
141 |
+
text_polys: np.ndarray N*4*2
|
142 |
+
angle: int
|
143 |
+
scale: int
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
|
147 |
+
"""
|
148 |
+
w = img.shape[1]
|
149 |
+
h = img.shape[0]
|
150 |
+
|
151 |
+
rangle = np.deg2rad(angle)
|
152 |
+
nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
|
153 |
+
nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
|
154 |
+
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
|
155 |
+
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
|
156 |
+
rot_mat[0, 2] += rot_move[0]
|
157 |
+
rot_mat[1, 2] += rot_move[1]
|
158 |
+
|
159 |
+
# ---------------------- rotate box ----------------------
|
160 |
+
rot_text_polys = list()
|
161 |
+
for bbox in text_polys:
|
162 |
+
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
|
163 |
+
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
|
164 |
+
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
|
165 |
+
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
|
166 |
+
rot_text_polys.append([point1, point2, point3, point4])
|
167 |
+
return np.array(rot_text_polys, dtype=np.float32)
|
ocr/ppocr/data/imaug/east_process.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
__all__ = ["EASTProcessTrain"]
|
7 |
+
|
8 |
+
|
9 |
+
class EASTProcessTrain(object):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
image_shape=[512, 512],
|
13 |
+
background_ratio=0.125,
|
14 |
+
min_crop_side_ratio=0.1,
|
15 |
+
min_text_size=10,
|
16 |
+
**kwargs
|
17 |
+
):
|
18 |
+
self.input_size = image_shape[1]
|
19 |
+
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
20 |
+
self.background_ratio = background_ratio
|
21 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
22 |
+
self.min_text_size = min_text_size
|
23 |
+
|
24 |
+
def preprocess(self, im):
|
25 |
+
input_size = self.input_size
|
26 |
+
im_shape = im.shape
|
27 |
+
im_size_min = np.min(im_shape[0:2])
|
28 |
+
im_size_max = np.max(im_shape[0:2])
|
29 |
+
im_scale = float(input_size) / float(im_size_max)
|
30 |
+
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
|
31 |
+
img_mean = [0.485, 0.456, 0.406]
|
32 |
+
img_std = [0.229, 0.224, 0.225]
|
33 |
+
# im = im[:, :, ::-1].astype(np.float32)
|
34 |
+
im = im / 255
|
35 |
+
im -= img_mean
|
36 |
+
im /= img_std
|
37 |
+
new_h, new_w, _ = im.shape
|
38 |
+
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
|
39 |
+
im_padded[:new_h, :new_w, :] = im
|
40 |
+
im_padded = im_padded.transpose((2, 0, 1))
|
41 |
+
im_padded = im_padded[np.newaxis, :]
|
42 |
+
return im_padded, im_scale
|
43 |
+
|
44 |
+
def rotate_im_poly(self, im, text_polys):
|
45 |
+
"""
|
46 |
+
rotate image with 90 / 180 / 270 degre
|
47 |
+
"""
|
48 |
+
im_w, im_h = im.shape[1], im.shape[0]
|
49 |
+
dst_im = im.copy()
|
50 |
+
dst_polys = []
|
51 |
+
rand_degree_ratio = np.random.rand()
|
52 |
+
rand_degree_cnt = 1
|
53 |
+
if 0.333 < rand_degree_ratio < 0.666:
|
54 |
+
rand_degree_cnt = 2
|
55 |
+
elif rand_degree_ratio > 0.666:
|
56 |
+
rand_degree_cnt = 3
|
57 |
+
for i in range(rand_degree_cnt):
|
58 |
+
dst_im = np.rot90(dst_im)
|
59 |
+
rot_degree = -90 * rand_degree_cnt
|
60 |
+
rot_angle = rot_degree * math.pi / 180.0
|
61 |
+
n_poly = text_polys.shape[0]
|
62 |
+
cx, cy = 0.5 * im_w, 0.5 * im_h
|
63 |
+
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
64 |
+
for i in range(n_poly):
|
65 |
+
wordBB = text_polys[i]
|
66 |
+
poly = []
|
67 |
+
for j in range(4):
|
68 |
+
sx, sy = wordBB[j][0], wordBB[j][1]
|
69 |
+
dx = (
|
70 |
+
math.cos(rot_angle) * (sx - cx)
|
71 |
+
- math.sin(rot_angle) * (sy - cy)
|
72 |
+
+ ncx
|
73 |
+
)
|
74 |
+
dy = (
|
75 |
+
math.sin(rot_angle) * (sx - cx)
|
76 |
+
+ math.cos(rot_angle) * (sy - cy)
|
77 |
+
+ ncy
|
78 |
+
)
|
79 |
+
poly.append([dx, dy])
|
80 |
+
dst_polys.append(poly)
|
81 |
+
dst_polys = np.array(dst_polys, dtype=np.float32)
|
82 |
+
return dst_im, dst_polys
|
83 |
+
|
84 |
+
def polygon_area(self, poly):
|
85 |
+
"""
|
86 |
+
compute area of a polygon
|
87 |
+
:param poly:
|
88 |
+
:return:
|
89 |
+
"""
|
90 |
+
edge = [
|
91 |
+
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
92 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
93 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
94 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
|
95 |
+
]
|
96 |
+
return np.sum(edge) / 2.0
|
97 |
+
|
98 |
+
def check_and_validate_polys(self, polys, tags, img_height, img_width):
|
99 |
+
"""
|
100 |
+
check so that the text poly is in the same direction,
|
101 |
+
and also filter some invalid polygons
|
102 |
+
:param polys:
|
103 |
+
:param tags:
|
104 |
+
:return:
|
105 |
+
"""
|
106 |
+
h, w = img_height, img_width
|
107 |
+
if polys.shape[0] == 0:
|
108 |
+
return polys
|
109 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
110 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
111 |
+
|
112 |
+
validated_polys = []
|
113 |
+
validated_tags = []
|
114 |
+
for poly, tag in zip(polys, tags):
|
115 |
+
p_area = self.polygon_area(poly)
|
116 |
+
# invalid poly
|
117 |
+
if abs(p_area) < 1:
|
118 |
+
continue
|
119 |
+
if p_area > 0:
|
120 |
+
#'poly in wrong direction'
|
121 |
+
if not tag:
|
122 |
+
tag = True # reversed cases should be ignore
|
123 |
+
poly = poly[(0, 3, 2, 1), :]
|
124 |
+
validated_polys.append(poly)
|
125 |
+
validated_tags.append(tag)
|
126 |
+
return np.array(validated_polys), np.array(validated_tags)
|
127 |
+
|
128 |
+
def draw_img_polys(self, img, polys):
|
129 |
+
if len(img.shape) == 4:
|
130 |
+
img = np.squeeze(img, axis=0)
|
131 |
+
if img.shape[0] == 3:
|
132 |
+
img = img.transpose((1, 2, 0))
|
133 |
+
img[:, :, 2] += 123.68
|
134 |
+
img[:, :, 1] += 116.78
|
135 |
+
img[:, :, 0] += 103.94
|
136 |
+
cv2.imwrite("tmp.jpg", img)
|
137 |
+
img = cv2.imread("tmp.jpg")
|
138 |
+
for box in polys:
|
139 |
+
box = box.astype(np.int32).reshape((-1, 1, 2))
|
140 |
+
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
|
141 |
+
import random
|
142 |
+
|
143 |
+
ino = random.randint(0, 100)
|
144 |
+
cv2.imwrite("tmp_%d.jpg" % ino, img)
|
145 |
+
return
|
146 |
+
|
147 |
+
def shrink_poly(self, poly, r):
|
148 |
+
"""
|
149 |
+
fit a poly inside the origin poly, maybe bugs here...
|
150 |
+
used for generate the score map
|
151 |
+
:param poly: the text poly
|
152 |
+
:param r: r in the paper
|
153 |
+
:return: the shrinked poly
|
154 |
+
"""
|
155 |
+
# shrink ratio
|
156 |
+
R = 0.3
|
157 |
+
# find the longer pair
|
158 |
+
dist0 = np.linalg.norm(poly[0] - poly[1])
|
159 |
+
dist1 = np.linalg.norm(poly[2] - poly[3])
|
160 |
+
dist2 = np.linalg.norm(poly[0] - poly[3])
|
161 |
+
dist3 = np.linalg.norm(poly[1] - poly[2])
|
162 |
+
if dist0 + dist1 > dist2 + dist3:
|
163 |
+
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
|
164 |
+
## p0, p1
|
165 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
|
166 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
167 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
168 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
169 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
170 |
+
## p2, p3
|
171 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
|
172 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
173 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
174 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
175 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
176 |
+
## p0, p3
|
177 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
|
178 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
179 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
180 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
181 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
182 |
+
## p1, p2
|
183 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
|
184 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
185 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
186 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
187 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
188 |
+
else:
|
189 |
+
## p0, p3
|
190 |
+
# print poly
|
191 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
|
192 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
193 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
194 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
195 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
196 |
+
## p1, p2
|
197 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
|
198 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
199 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
200 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
201 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
202 |
+
## p0, p1
|
203 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
|
204 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
205 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
206 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
207 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
208 |
+
## p2, p3
|
209 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
|
210 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
211 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
212 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
213 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
214 |
+
return poly
|
215 |
+
|
216 |
+
def generate_quad(self, im_size, polys, tags):
|
217 |
+
"""
|
218 |
+
Generate quadrangle.
|
219 |
+
"""
|
220 |
+
h, w = im_size
|
221 |
+
poly_mask = np.zeros((h, w), dtype=np.uint8)
|
222 |
+
score_map = np.zeros((h, w), dtype=np.uint8)
|
223 |
+
# (x1, y1, ..., x4, y4, short_edge_norm)
|
224 |
+
geo_map = np.zeros((h, w, 9), dtype=np.float32)
|
225 |
+
# mask used during traning, to ignore some hard areas
|
226 |
+
training_mask = np.ones((h, w), dtype=np.uint8)
|
227 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
228 |
+
poly = poly_tag[0]
|
229 |
+
tag = poly_tag[1]
|
230 |
+
|
231 |
+
r = [None, None, None, None]
|
232 |
+
for i in range(4):
|
233 |
+
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
|
234 |
+
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
|
235 |
+
r[i] = min(dist1, dist2)
|
236 |
+
# score map
|
237 |
+
shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[
|
238 |
+
np.newaxis, :, :
|
239 |
+
]
|
240 |
+
cv2.fillPoly(score_map, shrinked_poly, 1)
|
241 |
+
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
|
242 |
+
# if the poly is too small, then ignore it during training
|
243 |
+
poly_h = min(
|
244 |
+
np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])
|
245 |
+
)
|
246 |
+
poly_w = min(
|
247 |
+
np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])
|
248 |
+
)
|
249 |
+
if min(poly_h, poly_w) < self.min_text_size:
|
250 |
+
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
|
251 |
+
|
252 |
+
if tag:
|
253 |
+
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
|
254 |
+
|
255 |
+
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
|
256 |
+
# geo map.
|
257 |
+
y_in_poly = xy_in_poly[:, 0]
|
258 |
+
x_in_poly = xy_in_poly[:, 1]
|
259 |
+
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
|
260 |
+
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
|
261 |
+
for pno in range(4):
|
262 |
+
geo_channel_beg = pno * 2
|
263 |
+
geo_map[y_in_poly, x_in_poly, geo_channel_beg] = (
|
264 |
+
x_in_poly - poly[pno, 0]
|
265 |
+
)
|
266 |
+
geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = (
|
267 |
+
y_in_poly - poly[pno, 1]
|
268 |
+
)
|
269 |
+
geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0)
|
270 |
+
return score_map, geo_map, training_mask
|
271 |
+
|
272 |
+
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
|
273 |
+
"""
|
274 |
+
make random crop from the input image
|
275 |
+
:param im:
|
276 |
+
:param polys:
|
277 |
+
:param tags:
|
278 |
+
:param crop_background:
|
279 |
+
:param max_tries:
|
280 |
+
:return:
|
281 |
+
"""
|
282 |
+
h, w, _ = im.shape
|
283 |
+
pad_h = h // 10
|
284 |
+
pad_w = w // 10
|
285 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
286 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
287 |
+
for poly in polys:
|
288 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
289 |
+
minx = np.min(poly[:, 0])
|
290 |
+
maxx = np.max(poly[:, 0])
|
291 |
+
w_array[minx + pad_w : maxx + pad_w] = 1
|
292 |
+
miny = np.min(poly[:, 1])
|
293 |
+
maxy = np.max(poly[:, 1])
|
294 |
+
h_array[miny + pad_h : maxy + pad_h] = 1
|
295 |
+
# ensure the cropped area not across a text
|
296 |
+
h_axis = np.where(h_array == 0)[0]
|
297 |
+
w_axis = np.where(w_array == 0)[0]
|
298 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
299 |
+
return im, polys, tags
|
300 |
+
|
301 |
+
for i in range(max_tries):
|
302 |
+
xx = np.random.choice(w_axis, size=2)
|
303 |
+
xmin = np.min(xx) - pad_w
|
304 |
+
xmax = np.max(xx) - pad_w
|
305 |
+
xmin = np.clip(xmin, 0, w - 1)
|
306 |
+
xmax = np.clip(xmax, 0, w - 1)
|
307 |
+
yy = np.random.choice(h_axis, size=2)
|
308 |
+
ymin = np.min(yy) - pad_h
|
309 |
+
ymax = np.max(yy) - pad_h
|
310 |
+
ymin = np.clip(ymin, 0, h - 1)
|
311 |
+
ymax = np.clip(ymax, 0, h - 1)
|
312 |
+
if (
|
313 |
+
xmax - xmin < self.min_crop_side_ratio * w
|
314 |
+
or ymax - ymin < self.min_crop_side_ratio * h
|
315 |
+
):
|
316 |
+
# area too small
|
317 |
+
continue
|
318 |
+
if polys.shape[0] != 0:
|
319 |
+
poly_axis_in_area = (
|
320 |
+
(polys[:, :, 0] >= xmin)
|
321 |
+
& (polys[:, :, 0] <= xmax)
|
322 |
+
& (polys[:, :, 1] >= ymin)
|
323 |
+
& (polys[:, :, 1] <= ymax)
|
324 |
+
)
|
325 |
+
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
326 |
+
else:
|
327 |
+
selected_polys = []
|
328 |
+
|
329 |
+
if len(selected_polys) == 0:
|
330 |
+
# no text in this area
|
331 |
+
if crop_background:
|
332 |
+
im = im[ymin : ymax + 1, xmin : xmax + 1, :]
|
333 |
+
polys = []
|
334 |
+
tags = []
|
335 |
+
return im, polys, tags
|
336 |
+
else:
|
337 |
+
continue
|
338 |
+
|
339 |
+
im = im[ymin : ymax + 1, xmin : xmax + 1, :]
|
340 |
+
polys = polys[selected_polys]
|
341 |
+
tags = tags[selected_polys]
|
342 |
+
polys[:, :, 0] -= xmin
|
343 |
+
polys[:, :, 1] -= ymin
|
344 |
+
return im, polys, tags
|
345 |
+
return im, polys, tags
|
346 |
+
|
347 |
+
def crop_background_infor(self, im, text_polys, text_tags):
|
348 |
+
im, text_polys, text_tags = self.crop_area(
|
349 |
+
im, text_polys, text_tags, crop_background=True
|
350 |
+
)
|
351 |
+
|
352 |
+
if len(text_polys) > 0:
|
353 |
+
return None
|
354 |
+
# pad and resize image
|
355 |
+
input_size = self.input_size
|
356 |
+
im, ratio = self.preprocess(im)
|
357 |
+
score_map = np.zeros((input_size, input_size), dtype=np.float32)
|
358 |
+
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
|
359 |
+
training_mask = np.ones((input_size, input_size), dtype=np.float32)
|
360 |
+
return im, score_map, geo_map, training_mask
|
361 |
+
|
362 |
+
def crop_foreground_infor(self, im, text_polys, text_tags):
|
363 |
+
im, text_polys, text_tags = self.crop_area(
|
364 |
+
im, text_polys, text_tags, crop_background=False
|
365 |
+
)
|
366 |
+
|
367 |
+
if text_polys.shape[0] == 0:
|
368 |
+
return None
|
369 |
+
# continue for all ignore case
|
370 |
+
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
371 |
+
return None
|
372 |
+
# pad and resize image
|
373 |
+
input_size = self.input_size
|
374 |
+
im, ratio = self.preprocess(im)
|
375 |
+
text_polys[:, :, 0] *= ratio
|
376 |
+
text_polys[:, :, 1] *= ratio
|
377 |
+
_, _, new_h, new_w = im.shape
|
378 |
+
# print(im.shape)
|
379 |
+
# self.draw_img_polys(im, text_polys)
|
380 |
+
score_map, geo_map, training_mask = self.generate_quad(
|
381 |
+
(new_h, new_w), text_polys, text_tags
|
382 |
+
)
|
383 |
+
return im, score_map, geo_map, training_mask
|
384 |
+
|
385 |
+
def __call__(self, data):
|
386 |
+
im = data["image"]
|
387 |
+
text_polys = data["polys"]
|
388 |
+
text_tags = data["ignore_tags"]
|
389 |
+
if im is None:
|
390 |
+
return None
|
391 |
+
if text_polys.shape[0] == 0:
|
392 |
+
return None
|
393 |
+
|
394 |
+
# add rotate cases
|
395 |
+
if np.random.rand() < 0.5:
|
396 |
+
im, text_polys = self.rotate_im_poly(im, text_polys)
|
397 |
+
h, w, _ = im.shape
|
398 |
+
text_polys, text_tags = self.check_and_validate_polys(
|
399 |
+
text_polys, text_tags, h, w
|
400 |
+
)
|
401 |
+
if text_polys.shape[0] == 0:
|
402 |
+
return None
|
403 |
+
|
404 |
+
# random scale this image
|
405 |
+
rd_scale = np.random.choice(self.random_scale)
|
406 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
407 |
+
text_polys *= rd_scale
|
408 |
+
if np.random.rand() < self.background_ratio:
|
409 |
+
outs = self.crop_background_infor(im, text_polys, text_tags)
|
410 |
+
else:
|
411 |
+
outs = self.crop_foreground_infor(im, text_polys, text_tags)
|
412 |
+
|
413 |
+
if outs is None:
|
414 |
+
return None
|
415 |
+
im, score_map, geo_map, training_mask = outs
|
416 |
+
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
|
417 |
+
geo_map = np.swapaxes(geo_map, 1, 2)
|
418 |
+
geo_map = np.swapaxes(geo_map, 1, 0)
|
419 |
+
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
|
420 |
+
training_mask = training_mask[np.newaxis, ::4, ::4]
|
421 |
+
training_mask = training_mask.astype(np.float32)
|
422 |
+
|
423 |
+
data["image"] = im[0]
|
424 |
+
data["score_map"] = score_map
|
425 |
+
data["geo_map"] = geo_map
|
426 |
+
data["training_mask"] = training_mask
|
427 |
+
return data
|
ocr/ppocr/data/imaug/fce_aug.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
from shapely.geometry import Polygon
|
8 |
+
|
9 |
+
from postprocess.poly_nms import poly_intersection
|
10 |
+
|
11 |
+
|
12 |
+
class RandomScaling:
|
13 |
+
def __init__(self, size=800, scale=(3.0 / 4, 5.0 / 2), **kwargs):
|
14 |
+
"""Random scale the image while keeping aspect.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
size (int) : Base size before scaling.
|
18 |
+
scale (tuple(float)) : The range of scaling.
|
19 |
+
"""
|
20 |
+
assert isinstance(size, int)
|
21 |
+
assert isinstance(scale, float) or isinstance(scale, tuple)
|
22 |
+
self.size = size
|
23 |
+
self.scale = scale if isinstance(scale, tuple) else (1 - scale, 1 + scale)
|
24 |
+
|
25 |
+
def __call__(self, data):
|
26 |
+
image = data["image"]
|
27 |
+
text_polys = data["polys"]
|
28 |
+
h, w, _ = image.shape
|
29 |
+
|
30 |
+
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
|
31 |
+
scales = self.size * 1.0 / max(h, w) * aspect_ratio
|
32 |
+
scales = np.array([scales, scales])
|
33 |
+
out_size = (int(h * scales[1]), int(w * scales[0]))
|
34 |
+
image = cv2.resize(image, out_size[::-1])
|
35 |
+
|
36 |
+
data["image"] = image
|
37 |
+
text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
|
38 |
+
text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
|
39 |
+
data["polys"] = text_polys
|
40 |
+
|
41 |
+
return data
|
42 |
+
|
43 |
+
|
44 |
+
class RandomCropFlip:
|
45 |
+
def __init__(
|
46 |
+
self, pad_ratio=0.1, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2, **kwargs
|
47 |
+
):
|
48 |
+
"""Random crop and flip a patch of the image.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
crop_ratio (float): The ratio of cropping.
|
52 |
+
iter_num (int): Number of operations.
|
53 |
+
min_area_ratio (float): Minimal area ratio between cropped patch
|
54 |
+
and original image.
|
55 |
+
"""
|
56 |
+
assert isinstance(crop_ratio, float)
|
57 |
+
assert isinstance(iter_num, int)
|
58 |
+
assert isinstance(min_area_ratio, float)
|
59 |
+
|
60 |
+
self.pad_ratio = pad_ratio
|
61 |
+
self.epsilon = 1e-2
|
62 |
+
self.crop_ratio = crop_ratio
|
63 |
+
self.iter_num = iter_num
|
64 |
+
self.min_area_ratio = min_area_ratio
|
65 |
+
|
66 |
+
def __call__(self, results):
|
67 |
+
for i in range(self.iter_num):
|
68 |
+
results = self.random_crop_flip(results)
|
69 |
+
|
70 |
+
return results
|
71 |
+
|
72 |
+
def random_crop_flip(self, results):
|
73 |
+
image = results["image"]
|
74 |
+
polygons = results["polys"]
|
75 |
+
ignore_tags = results["ignore_tags"]
|
76 |
+
if len(polygons) == 0:
|
77 |
+
return results
|
78 |
+
|
79 |
+
if np.random.random() >= self.crop_ratio:
|
80 |
+
return results
|
81 |
+
|
82 |
+
h, w, _ = image.shape
|
83 |
+
area = h * w
|
84 |
+
pad_h = int(h * self.pad_ratio)
|
85 |
+
pad_w = int(w * self.pad_ratio)
|
86 |
+
h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, pad_w)
|
87 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
88 |
+
return results
|
89 |
+
|
90 |
+
attempt = 0
|
91 |
+
while attempt < 50:
|
92 |
+
attempt += 1
|
93 |
+
polys_keep = []
|
94 |
+
polys_new = []
|
95 |
+
ignore_tags_keep = []
|
96 |
+
ignore_tags_new = []
|
97 |
+
xx = np.random.choice(w_axis, size=2)
|
98 |
+
xmin = np.min(xx) - pad_w
|
99 |
+
xmax = np.max(xx) - pad_w
|
100 |
+
xmin = np.clip(xmin, 0, w - 1)
|
101 |
+
xmax = np.clip(xmax, 0, w - 1)
|
102 |
+
yy = np.random.choice(h_axis, size=2)
|
103 |
+
ymin = np.min(yy) - pad_h
|
104 |
+
ymax = np.max(yy) - pad_h
|
105 |
+
ymin = np.clip(ymin, 0, h - 1)
|
106 |
+
ymax = np.clip(ymax, 0, h - 1)
|
107 |
+
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
|
108 |
+
# area too small
|
109 |
+
continue
|
110 |
+
|
111 |
+
pts = np.stack(
|
112 |
+
[[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]
|
113 |
+
).T.astype(np.int32)
|
114 |
+
pp = Polygon(pts)
|
115 |
+
fail_flag = False
|
116 |
+
for polygon, ignore_tag in zip(polygons, ignore_tags):
|
117 |
+
ppi = Polygon(polygon.reshape(-1, 2))
|
118 |
+
ppiou, _ = poly_intersection(ppi, pp, buffer=0)
|
119 |
+
if (
|
120 |
+
np.abs(ppiou - float(ppi.area)) > self.epsilon
|
121 |
+
and np.abs(ppiou) > self.epsilon
|
122 |
+
):
|
123 |
+
fail_flag = True
|
124 |
+
break
|
125 |
+
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
126 |
+
polys_new.append(polygon)
|
127 |
+
ignore_tags_new.append(ignore_tag)
|
128 |
+
else:
|
129 |
+
polys_keep.append(polygon)
|
130 |
+
ignore_tags_keep.append(ignore_tag)
|
131 |
+
|
132 |
+
if fail_flag:
|
133 |
+
continue
|
134 |
+
else:
|
135 |
+
break
|
136 |
+
|
137 |
+
cropped = image[ymin:ymax, xmin:xmax, :]
|
138 |
+
select_type = np.random.randint(3)
|
139 |
+
if select_type == 0:
|
140 |
+
img = np.ascontiguousarray(cropped[:, ::-1])
|
141 |
+
elif select_type == 1:
|
142 |
+
img = np.ascontiguousarray(cropped[::-1, :])
|
143 |
+
else:
|
144 |
+
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
145 |
+
image[ymin:ymax, xmin:xmax, :] = img
|
146 |
+
results["img"] = image
|
147 |
+
|
148 |
+
if len(polys_new) != 0:
|
149 |
+
height, width, _ = cropped.shape
|
150 |
+
if select_type == 0:
|
151 |
+
for idx, polygon in enumerate(polys_new):
|
152 |
+
poly = polygon.reshape(-1, 2)
|
153 |
+
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
154 |
+
polys_new[idx] = poly
|
155 |
+
elif select_type == 1:
|
156 |
+
for idx, polygon in enumerate(polys_new):
|
157 |
+
poly = polygon.reshape(-1, 2)
|
158 |
+
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
159 |
+
polys_new[idx] = poly
|
160 |
+
else:
|
161 |
+
for idx, polygon in enumerate(polys_new):
|
162 |
+
poly = polygon.reshape(-1, 2)
|
163 |
+
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
164 |
+
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
165 |
+
polys_new[idx] = poly
|
166 |
+
polygons = polys_keep + polys_new
|
167 |
+
ignore_tags = ignore_tags_keep + ignore_tags_new
|
168 |
+
results["polys"] = np.array(polygons)
|
169 |
+
results["ignore_tags"] = ignore_tags
|
170 |
+
|
171 |
+
return results
|
172 |
+
|
173 |
+
def generate_crop_target(self, image, all_polys, pad_h, pad_w):
|
174 |
+
"""Generate crop target and make sure not to crop the polygon
|
175 |
+
instances.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
image (ndarray): The image waited to be crop.
|
179 |
+
all_polys (list[list[ndarray]]): All polygons including ground
|
180 |
+
truth polygons and ground truth ignored polygons.
|
181 |
+
pad_h (int): Padding length of height.
|
182 |
+
pad_w (int): Padding length of width.
|
183 |
+
Returns:
|
184 |
+
h_axis (ndarray): Vertical cropping range.
|
185 |
+
w_axis (ndarray): Horizontal cropping range.
|
186 |
+
"""
|
187 |
+
h, w, _ = image.shape
|
188 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
189 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
190 |
+
|
191 |
+
text_polys = []
|
192 |
+
for polygon in all_polys:
|
193 |
+
rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
|
194 |
+
box = cv2.boxPoints(rect)
|
195 |
+
box = np.int0(box)
|
196 |
+
text_polys.append([box[0], box[1], box[2], box[3]])
|
197 |
+
|
198 |
+
polys = np.array(text_polys, dtype=np.int32)
|
199 |
+
for poly in polys:
|
200 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
201 |
+
minx = np.min(poly[:, 0])
|
202 |
+
maxx = np.max(poly[:, 0])
|
203 |
+
w_array[minx + pad_w : maxx + pad_w] = 1
|
204 |
+
miny = np.min(poly[:, 1])
|
205 |
+
maxy = np.max(poly[:, 1])
|
206 |
+
h_array[miny + pad_h : maxy + pad_h] = 1
|
207 |
+
|
208 |
+
h_axis = np.where(h_array == 0)[0]
|
209 |
+
w_axis = np.where(w_array == 0)[0]
|
210 |
+
return h_axis, w_axis
|
211 |
+
|
212 |
+
|
213 |
+
class RandomCropPolyInstances:
|
214 |
+
"""Randomly crop images and make sure to contain at least one intact
|
215 |
+
instance."""
|
216 |
+
|
217 |
+
def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
|
218 |
+
super().__init__()
|
219 |
+
self.crop_ratio = crop_ratio
|
220 |
+
self.min_side_ratio = min_side_ratio
|
221 |
+
|
222 |
+
def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
|
223 |
+
|
224 |
+
assert isinstance(min_len, int)
|
225 |
+
assert len(valid_array) > min_len
|
226 |
+
|
227 |
+
start_array = valid_array.copy()
|
228 |
+
max_start = min(len(start_array) - min_len, max_start)
|
229 |
+
start_array[max_start:] = 0
|
230 |
+
start_array[0] = 1
|
231 |
+
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
|
232 |
+
region_starts = np.where(diff_array < 0)[0]
|
233 |
+
region_ends = np.where(diff_array > 0)[0]
|
234 |
+
region_ind = np.random.randint(0, len(region_starts))
|
235 |
+
start = np.random.randint(region_starts[region_ind], region_ends[region_ind])
|
236 |
+
|
237 |
+
end_array = valid_array.copy()
|
238 |
+
min_end = max(start + min_len, min_end)
|
239 |
+
end_array[:min_end] = 0
|
240 |
+
end_array[-1] = 1
|
241 |
+
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
|
242 |
+
region_starts = np.where(diff_array < 0)[0]
|
243 |
+
region_ends = np.where(diff_array > 0)[0]
|
244 |
+
region_ind = np.random.randint(0, len(region_starts))
|
245 |
+
end = np.random.randint(region_starts[region_ind], region_ends[region_ind])
|
246 |
+
return start, end
|
247 |
+
|
248 |
+
def sample_crop_box(self, img_size, results):
|
249 |
+
"""Generate crop box and make sure not to crop the polygon instances.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
img_size (tuple(int)): The image size (h, w).
|
253 |
+
results (dict): The results dict.
|
254 |
+
"""
|
255 |
+
|
256 |
+
assert isinstance(img_size, tuple)
|
257 |
+
h, w = img_size[:2]
|
258 |
+
|
259 |
+
key_masks = results["polys"]
|
260 |
+
|
261 |
+
x_valid_array = np.ones(w, dtype=np.int32)
|
262 |
+
y_valid_array = np.ones(h, dtype=np.int32)
|
263 |
+
|
264 |
+
selected_mask = key_masks[np.random.randint(0, len(key_masks))]
|
265 |
+
selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
|
266 |
+
max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
|
267 |
+
min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
|
268 |
+
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
|
269 |
+
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
|
270 |
+
|
271 |
+
for mask in key_masks:
|
272 |
+
mask = mask.reshape((-1, 2)).astype(np.int32)
|
273 |
+
clip_x = np.clip(mask[:, 0], 0, w - 1)
|
274 |
+
clip_y = np.clip(mask[:, 1], 0, h - 1)
|
275 |
+
min_x, max_x = np.min(clip_x), np.max(clip_x)
|
276 |
+
min_y, max_y = np.min(clip_y), np.max(clip_y)
|
277 |
+
|
278 |
+
x_valid_array[min_x - 2 : max_x + 3] = 0
|
279 |
+
y_valid_array[min_y - 2 : max_y + 3] = 0
|
280 |
+
|
281 |
+
min_w = int(w * self.min_side_ratio)
|
282 |
+
min_h = int(h * self.min_side_ratio)
|
283 |
+
|
284 |
+
x1, x2 = self.sample_valid_start_end(
|
285 |
+
x_valid_array, min_w, max_x_start, min_x_end
|
286 |
+
)
|
287 |
+
y1, y2 = self.sample_valid_start_end(
|
288 |
+
y_valid_array, min_h, max_y_start, min_y_end
|
289 |
+
)
|
290 |
+
|
291 |
+
return np.array([x1, y1, x2, y2])
|
292 |
+
|
293 |
+
def crop_img(self, img, bbox):
|
294 |
+
assert img.ndim == 3
|
295 |
+
h, w, _ = img.shape
|
296 |
+
assert 0 <= bbox[1] < bbox[3] <= h
|
297 |
+
assert 0 <= bbox[0] < bbox[2] <= w
|
298 |
+
return img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
299 |
+
|
300 |
+
def __call__(self, results):
|
301 |
+
image = results["image"]
|
302 |
+
polygons = results["polys"]
|
303 |
+
ignore_tags = results["ignore_tags"]
|
304 |
+
if len(polygons) < 1:
|
305 |
+
return results
|
306 |
+
|
307 |
+
if np.random.random_sample() < self.crop_ratio:
|
308 |
+
|
309 |
+
crop_box = self.sample_crop_box(image.shape, results)
|
310 |
+
img = self.crop_img(image, crop_box)
|
311 |
+
results["image"] = img
|
312 |
+
# crop and filter masks
|
313 |
+
x1, y1, x2, y2 = crop_box
|
314 |
+
w = max(x2 - x1, 1)
|
315 |
+
h = max(y2 - y1, 1)
|
316 |
+
polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
|
317 |
+
polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
|
318 |
+
|
319 |
+
valid_masks_list = []
|
320 |
+
valid_tags_list = []
|
321 |
+
for ind, polygon in enumerate(polygons):
|
322 |
+
if (
|
323 |
+
(polygon[:, ::2] > -4).all()
|
324 |
+
and (polygon[:, ::2] < w + 4).all()
|
325 |
+
and (polygon[:, 1::2] > -4).all()
|
326 |
+
and (polygon[:, 1::2] < h + 4).all()
|
327 |
+
):
|
328 |
+
polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
|
329 |
+
polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
|
330 |
+
valid_masks_list.append(polygon)
|
331 |
+
valid_tags_list.append(ignore_tags[ind])
|
332 |
+
|
333 |
+
results["polys"] = np.array(valid_masks_list)
|
334 |
+
results["ignore_tags"] = valid_tags_list
|
335 |
+
|
336 |
+
return results
|
337 |
+
|
338 |
+
def __repr__(self):
|
339 |
+
repr_str = self.__class__.__name__
|
340 |
+
return repr_str
|
341 |
+
|
342 |
+
|
343 |
+
class RandomRotatePolyInstances:
|
344 |
+
def __init__(
|
345 |
+
self,
|
346 |
+
rotate_ratio=0.5,
|
347 |
+
max_angle=10,
|
348 |
+
pad_with_fixed_color=False,
|
349 |
+
pad_value=(0, 0, 0),
|
350 |
+
**kwargs
|
351 |
+
):
|
352 |
+
"""Randomly rotate images and polygon masks.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
rotate_ratio (float): The ratio of samples to operate rotation.
|
356 |
+
max_angle (int): The maximum rotation angle.
|
357 |
+
pad_with_fixed_color (bool): The flag for whether to pad rotated
|
358 |
+
image with fixed value. If set to False, the rotated image will
|
359 |
+
be padded onto cropped image.
|
360 |
+
pad_value (tuple(int)): The color value for padding rotated image.
|
361 |
+
"""
|
362 |
+
self.rotate_ratio = rotate_ratio
|
363 |
+
self.max_angle = max_angle
|
364 |
+
self.pad_with_fixed_color = pad_with_fixed_color
|
365 |
+
self.pad_value = pad_value
|
366 |
+
|
367 |
+
def rotate(self, center, points, theta, center_shift=(0, 0)):
|
368 |
+
# rotate points.
|
369 |
+
(center_x, center_y) = center
|
370 |
+
center_y = -center_y
|
371 |
+
x, y = points[:, ::2], points[:, 1::2]
|
372 |
+
y = -y
|
373 |
+
|
374 |
+
theta = theta / 180 * math.pi
|
375 |
+
cos = math.cos(theta)
|
376 |
+
sin = math.sin(theta)
|
377 |
+
|
378 |
+
x = x - center_x
|
379 |
+
y = y - center_y
|
380 |
+
|
381 |
+
_x = center_x + x * cos - y * sin + center_shift[0]
|
382 |
+
_y = -(center_y + x * sin + y * cos) + center_shift[1]
|
383 |
+
|
384 |
+
points[:, ::2], points[:, 1::2] = _x, _y
|
385 |
+
return points
|
386 |
+
|
387 |
+
def cal_canvas_size(self, ori_size, degree):
|
388 |
+
assert isinstance(ori_size, tuple)
|
389 |
+
angle = degree * math.pi / 180.0
|
390 |
+
h, w = ori_size[:2]
|
391 |
+
|
392 |
+
cos = math.cos(angle)
|
393 |
+
sin = math.sin(angle)
|
394 |
+
canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
|
395 |
+
canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
|
396 |
+
|
397 |
+
canvas_size = (canvas_h, canvas_w)
|
398 |
+
return canvas_size
|
399 |
+
|
400 |
+
def sample_angle(self, max_angle):
|
401 |
+
angle = np.random.random_sample() * 2 * max_angle - max_angle
|
402 |
+
return angle
|
403 |
+
|
404 |
+
def rotate_img(self, img, angle, canvas_size):
|
405 |
+
h, w = img.shape[:2]
|
406 |
+
rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
|
407 |
+
rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
|
408 |
+
rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
|
409 |
+
|
410 |
+
if self.pad_with_fixed_color:
|
411 |
+
target_img = cv2.warpAffine(
|
412 |
+
img,
|
413 |
+
rotation_matrix,
|
414 |
+
(canvas_size[1], canvas_size[0]),
|
415 |
+
flags=cv2.INTER_NEAREST,
|
416 |
+
borderValue=self.pad_value,
|
417 |
+
)
|
418 |
+
else:
|
419 |
+
mask = np.zeros_like(img)
|
420 |
+
(h_ind, w_ind) = (
|
421 |
+
np.random.randint(0, h * 7 // 8),
|
422 |
+
np.random.randint(0, w * 7 // 8),
|
423 |
+
)
|
424 |
+
img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
|
425 |
+
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
|
426 |
+
|
427 |
+
mask = cv2.warpAffine(
|
428 |
+
mask,
|
429 |
+
rotation_matrix,
|
430 |
+
(canvas_size[1], canvas_size[0]),
|
431 |
+
borderValue=[1, 1, 1],
|
432 |
+
)
|
433 |
+
target_img = cv2.warpAffine(
|
434 |
+
img,
|
435 |
+
rotation_matrix,
|
436 |
+
(canvas_size[1], canvas_size[0]),
|
437 |
+
borderValue=[0, 0, 0],
|
438 |
+
)
|
439 |
+
target_img = target_img + img_cut * mask
|
440 |
+
|
441 |
+
return target_img
|
442 |
+
|
443 |
+
def __call__(self, results):
|
444 |
+
if np.random.random_sample() < self.rotate_ratio:
|
445 |
+
image = results["image"]
|
446 |
+
polygons = results["polys"]
|
447 |
+
h, w = image.shape[:2]
|
448 |
+
|
449 |
+
angle = self.sample_angle(self.max_angle)
|
450 |
+
canvas_size = self.cal_canvas_size((h, w), angle)
|
451 |
+
center_shift = (
|
452 |
+
int((canvas_size[1] - w) / 2),
|
453 |
+
int((canvas_size[0] - h) / 2),
|
454 |
+
)
|
455 |
+
image = self.rotate_img(image, angle, canvas_size)
|
456 |
+
results["image"] = image
|
457 |
+
# rotate polygons
|
458 |
+
rotated_masks = []
|
459 |
+
for mask in polygons:
|
460 |
+
rotated_mask = self.rotate((w / 2, h / 2), mask, angle, center_shift)
|
461 |
+
rotated_masks.append(rotated_mask)
|
462 |
+
results["polys"] = np.array(rotated_masks)
|
463 |
+
|
464 |
+
return results
|
465 |
+
|
466 |
+
def __repr__(self):
|
467 |
+
repr_str = self.__class__.__name__
|
468 |
+
return repr_str
|
469 |
+
|
470 |
+
|
471 |
+
class SquareResizePad:
|
472 |
+
def __init__(
|
473 |
+
self,
|
474 |
+
target_size,
|
475 |
+
pad_ratio=0.6,
|
476 |
+
pad_with_fixed_color=False,
|
477 |
+
pad_value=(0, 0, 0),
|
478 |
+
**kwargs
|
479 |
+
):
|
480 |
+
"""Resize or pad images to be square shape.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
target_size (int): The target size of square shaped image.
|
484 |
+
pad_with_fixed_color (bool): The flag for whether to pad rotated
|
485 |
+
image with fixed value. If set to False, the rescales image will
|
486 |
+
be padded onto cropped image.
|
487 |
+
pad_value (tuple(int)): The color value for padding rotated image.
|
488 |
+
"""
|
489 |
+
assert isinstance(target_size, int)
|
490 |
+
assert isinstance(pad_ratio, float)
|
491 |
+
assert isinstance(pad_with_fixed_color, bool)
|
492 |
+
assert isinstance(pad_value, tuple)
|
493 |
+
|
494 |
+
self.target_size = target_size
|
495 |
+
self.pad_ratio = pad_ratio
|
496 |
+
self.pad_with_fixed_color = pad_with_fixed_color
|
497 |
+
self.pad_value = pad_value
|
498 |
+
|
499 |
+
def resize_img(self, img, keep_ratio=True):
|
500 |
+
h, w, _ = img.shape
|
501 |
+
if keep_ratio:
|
502 |
+
t_h = self.target_size if h >= w else int(h * self.target_size / w)
|
503 |
+
t_w = self.target_size if h <= w else int(w * self.target_size / h)
|
504 |
+
else:
|
505 |
+
t_h = t_w = self.target_size
|
506 |
+
img = cv2.resize(img, (t_w, t_h))
|
507 |
+
return img, (t_h, t_w)
|
508 |
+
|
509 |
+
def square_pad(self, img):
|
510 |
+
h, w = img.shape[:2]
|
511 |
+
if h == w:
|
512 |
+
return img, (0, 0)
|
513 |
+
pad_size = max(h, w)
|
514 |
+
if self.pad_with_fixed_color:
|
515 |
+
expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
|
516 |
+
expand_img[:] = self.pad_value
|
517 |
+
else:
|
518 |
+
(h_ind, w_ind) = (
|
519 |
+
np.random.randint(0, h * 7 // 8),
|
520 |
+
np.random.randint(0, w * 7 // 8),
|
521 |
+
)
|
522 |
+
img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
|
523 |
+
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
|
524 |
+
if h > w:
|
525 |
+
y0, x0 = 0, (h - w) // 2
|
526 |
+
else:
|
527 |
+
y0, x0 = (w - h) // 2, 0
|
528 |
+
expand_img[y0 : y0 + h, x0 : x0 + w] = img
|
529 |
+
offset = (x0, y0)
|
530 |
+
|
531 |
+
return expand_img, offset
|
532 |
+
|
533 |
+
def square_pad_mask(self, points, offset):
|
534 |
+
x0, y0 = offset
|
535 |
+
pad_points = points.copy()
|
536 |
+
pad_points[::2] = pad_points[::2] + x0
|
537 |
+
pad_points[1::2] = pad_points[1::2] + y0
|
538 |
+
return pad_points
|
539 |
+
|
540 |
+
def __call__(self, results):
|
541 |
+
image = results["image"]
|
542 |
+
polygons = results["polys"]
|
543 |
+
h, w = image.shape[:2]
|
544 |
+
|
545 |
+
if np.random.random_sample() < self.pad_ratio:
|
546 |
+
image, out_size = self.resize_img(image, keep_ratio=True)
|
547 |
+
image, offset = self.square_pad(image)
|
548 |
+
else:
|
549 |
+
image, out_size = self.resize_img(image, keep_ratio=False)
|
550 |
+
offset = (0, 0)
|
551 |
+
results["image"] = image
|
552 |
+
try:
|
553 |
+
polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[0]
|
554 |
+
polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[1]
|
555 |
+
except:
|
556 |
+
pass
|
557 |
+
results["polys"] = polygons
|
558 |
+
|
559 |
+
return results
|
560 |
+
|
561 |
+
def __repr__(self):
|
562 |
+
repr_str = self.__class__.__name__
|
563 |
+
return repr_str
|
ocr/ppocr/data/imaug/fce_targets.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from numpy.fft import fft
|
4 |
+
from numpy.linalg import norm
|
5 |
+
|
6 |
+
|
7 |
+
def vector_slope(vec):
|
8 |
+
assert len(vec) == 2
|
9 |
+
return abs(vec[1] / (vec[0] + 1e-8))
|
10 |
+
|
11 |
+
|
12 |
+
class FCENetTargets:
|
13 |
+
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
|
14 |
+
for Arbitrary-Shaped Text Detection.
|
15 |
+
|
16 |
+
[https://arxiv.org/abs/2104.10442]
|
17 |
+
|
18 |
+
Args:
|
19 |
+
fourier_degree (int): The maximum Fourier transform degree k.
|
20 |
+
resample_step (float): The step size for resampling the text center
|
21 |
+
line (TCL). It's better not to exceed half of the minimum width.
|
22 |
+
center_region_shrink_ratio (float): The shrink ratio of text center
|
23 |
+
region.
|
24 |
+
level_size_divisors (tuple(int)): The downsample ratio on each level.
|
25 |
+
level_proportion_range (tuple(tuple(int))): The range of text sizes
|
26 |
+
assigned to each level.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
fourier_degree=5,
|
32 |
+
resample_step=4.0,
|
33 |
+
center_region_shrink_ratio=0.3,
|
34 |
+
level_size_divisors=(8, 16, 32),
|
35 |
+
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
|
36 |
+
orientation_thr=2.0,
|
37 |
+
**kwargs
|
38 |
+
):
|
39 |
+
|
40 |
+
super().__init__()
|
41 |
+
assert isinstance(level_size_divisors, tuple)
|
42 |
+
assert isinstance(level_proportion_range, tuple)
|
43 |
+
assert len(level_size_divisors) == len(level_proportion_range)
|
44 |
+
self.fourier_degree = fourier_degree
|
45 |
+
self.resample_step = resample_step
|
46 |
+
self.center_region_shrink_ratio = center_region_shrink_ratio
|
47 |
+
self.level_size_divisors = level_size_divisors
|
48 |
+
self.level_proportion_range = level_proportion_range
|
49 |
+
|
50 |
+
self.orientation_thr = orientation_thr
|
51 |
+
|
52 |
+
def vector_angle(self, vec1, vec2):
|
53 |
+
if vec1.ndim > 1:
|
54 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
|
55 |
+
else:
|
56 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
|
57 |
+
if vec2.ndim > 1:
|
58 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
|
59 |
+
else:
|
60 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
|
61 |
+
return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
|
62 |
+
|
63 |
+
def resample_line(self, line, n):
|
64 |
+
"""Resample n points on a line.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
line (ndarray): The points composing a line.
|
68 |
+
n (int): The resampled points number.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
resampled_line (ndarray): The points composing the resampled line.
|
72 |
+
"""
|
73 |
+
|
74 |
+
assert line.ndim == 2
|
75 |
+
assert line.shape[0] >= 2
|
76 |
+
assert line.shape[1] == 2
|
77 |
+
assert isinstance(n, int)
|
78 |
+
assert n > 0
|
79 |
+
|
80 |
+
length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
|
81 |
+
total_length = sum(length_list)
|
82 |
+
length_cumsum = np.cumsum([0.0] + length_list)
|
83 |
+
delta_length = total_length / (float(n) + 1e-8)
|
84 |
+
|
85 |
+
current_edge_ind = 0
|
86 |
+
resampled_line = [line[0]]
|
87 |
+
|
88 |
+
for i in range(1, n):
|
89 |
+
current_line_len = i * delta_length
|
90 |
+
|
91 |
+
while current_line_len >= length_cumsum[current_edge_ind + 1]:
|
92 |
+
current_edge_ind += 1
|
93 |
+
current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
|
94 |
+
end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
|
95 |
+
current_point = (
|
96 |
+
line[current_edge_ind]
|
97 |
+
+ (line[current_edge_ind + 1] - line[current_edge_ind])
|
98 |
+
* end_shift_ratio
|
99 |
+
)
|
100 |
+
resampled_line.append(current_point)
|
101 |
+
|
102 |
+
resampled_line.append(line[-1])
|
103 |
+
resampled_line = np.array(resampled_line)
|
104 |
+
|
105 |
+
return resampled_line
|
106 |
+
|
107 |
+
def reorder_poly_edge(self, points):
|
108 |
+
"""Get the respective points composing head edge, tail edge, top
|
109 |
+
sideline and bottom sideline.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
points (ndarray): The points composing a text polygon.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
head_edge (ndarray): The two points composing the head edge of text
|
116 |
+
polygon.
|
117 |
+
tail_edge (ndarray): The two points composing the tail edge of text
|
118 |
+
polygon.
|
119 |
+
top_sideline (ndarray): The points composing top curved sideline of
|
120 |
+
text polygon.
|
121 |
+
bot_sideline (ndarray): The points composing bottom curved sideline
|
122 |
+
of text polygon.
|
123 |
+
"""
|
124 |
+
|
125 |
+
assert points.ndim == 2
|
126 |
+
assert points.shape[0] >= 4
|
127 |
+
assert points.shape[1] == 2
|
128 |
+
|
129 |
+
head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
|
130 |
+
head_edge, tail_edge = points[head_inds], points[tail_inds]
|
131 |
+
|
132 |
+
pad_points = np.vstack([points, points])
|
133 |
+
if tail_inds[1] < 1:
|
134 |
+
tail_inds[1] = len(points)
|
135 |
+
sideline1 = pad_points[head_inds[1] : tail_inds[1]]
|
136 |
+
sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
|
137 |
+
sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
|
138 |
+
|
139 |
+
if sideline_mean_shift[1] > 0:
|
140 |
+
top_sideline, bot_sideline = sideline2, sideline1
|
141 |
+
else:
|
142 |
+
top_sideline, bot_sideline = sideline1, sideline2
|
143 |
+
|
144 |
+
return head_edge, tail_edge, top_sideline, bot_sideline
|
145 |
+
|
146 |
+
def find_head_tail(self, points, orientation_thr):
|
147 |
+
"""Find the head edge and tail edge of a text polygon.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
points (ndarray): The points composing a text polygon.
|
151 |
+
orientation_thr (float): The threshold for distinguishing between
|
152 |
+
head edge and tail edge among the horizontal and vertical edges
|
153 |
+
of a quadrangle.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
head_inds (list): The indexes of two points composing head edge.
|
157 |
+
tail_inds (list): The indexes of two points composing tail edge.
|
158 |
+
"""
|
159 |
+
|
160 |
+
assert points.ndim == 2
|
161 |
+
assert points.shape[0] >= 4
|
162 |
+
assert points.shape[1] == 2
|
163 |
+
assert isinstance(orientation_thr, float)
|
164 |
+
|
165 |
+
if len(points) > 4:
|
166 |
+
pad_points = np.vstack([points, points[0]])
|
167 |
+
edge_vec = pad_points[1:] - pad_points[:-1]
|
168 |
+
|
169 |
+
theta_sum = []
|
170 |
+
adjacent_vec_theta = []
|
171 |
+
for i, edge_vec1 in enumerate(edge_vec):
|
172 |
+
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
173 |
+
adjacent_edge_vec = edge_vec[adjacent_ind]
|
174 |
+
temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
|
175 |
+
temp_adjacent_theta = self.vector_angle(
|
176 |
+
adjacent_edge_vec[0], adjacent_edge_vec[1]
|
177 |
+
)
|
178 |
+
theta_sum.append(temp_theta_sum)
|
179 |
+
adjacent_vec_theta.append(temp_adjacent_theta)
|
180 |
+
theta_sum_score = np.array(theta_sum) / np.pi
|
181 |
+
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
|
182 |
+
poly_center = np.mean(points, axis=0)
|
183 |
+
edge_dist = np.maximum(
|
184 |
+
norm(pad_points[1:] - poly_center, axis=-1),
|
185 |
+
norm(pad_points[:-1] - poly_center, axis=-1),
|
186 |
+
)
|
187 |
+
dist_score = edge_dist / np.max(edge_dist)
|
188 |
+
position_score = np.zeros(len(edge_vec))
|
189 |
+
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
|
190 |
+
score += 0.35 * dist_score
|
191 |
+
if len(points) % 2 == 0:
|
192 |
+
position_score[(len(score) // 2 - 1)] += 1
|
193 |
+
position_score[-1] += 1
|
194 |
+
score += 0.1 * position_score
|
195 |
+
pad_score = np.concatenate([score, score])
|
196 |
+
score_matrix = np.zeros((len(score), len(score) - 3))
|
197 |
+
x = np.arange(len(score) - 3) / float(len(score) - 4)
|
198 |
+
gaussian = (
|
199 |
+
1.0
|
200 |
+
/ (np.sqrt(2.0 * np.pi) * 0.5)
|
201 |
+
* np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
|
202 |
+
)
|
203 |
+
gaussian = gaussian / np.max(gaussian)
|
204 |
+
for i in range(len(score)):
|
205 |
+
score_matrix[i, :] = (
|
206 |
+
score[i]
|
207 |
+
+ pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
|
208 |
+
)
|
209 |
+
|
210 |
+
head_start, tail_increment = np.unravel_index(
|
211 |
+
score_matrix.argmax(), score_matrix.shape
|
212 |
+
)
|
213 |
+
tail_start = (head_start + tail_increment + 2) % len(points)
|
214 |
+
head_end = (head_start + 1) % len(points)
|
215 |
+
tail_end = (tail_start + 1) % len(points)
|
216 |
+
|
217 |
+
if head_end > tail_end:
|
218 |
+
head_start, tail_start = tail_start, head_start
|
219 |
+
head_end, tail_end = tail_end, head_end
|
220 |
+
head_inds = [head_start, head_end]
|
221 |
+
tail_inds = [tail_start, tail_end]
|
222 |
+
else:
|
223 |
+
if vector_slope(points[1] - points[0]) + vector_slope(
|
224 |
+
points[3] - points[2]
|
225 |
+
) < vector_slope(points[2] - points[1]) + vector_slope(
|
226 |
+
points[0] - points[3]
|
227 |
+
):
|
228 |
+
horizontal_edge_inds = [[0, 1], [2, 3]]
|
229 |
+
vertical_edge_inds = [[3, 0], [1, 2]]
|
230 |
+
else:
|
231 |
+
horizontal_edge_inds = [[3, 0], [1, 2]]
|
232 |
+
vertical_edge_inds = [[0, 1], [2, 3]]
|
233 |
+
|
234 |
+
vertical_len_sum = norm(
|
235 |
+
points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
|
236 |
+
) + norm(
|
237 |
+
points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
|
238 |
+
)
|
239 |
+
horizontal_len_sum = norm(
|
240 |
+
points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
|
241 |
+
) + norm(
|
242 |
+
points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
|
243 |
+
)
|
244 |
+
|
245 |
+
if vertical_len_sum > horizontal_len_sum * orientation_thr:
|
246 |
+
head_inds = horizontal_edge_inds[0]
|
247 |
+
tail_inds = horizontal_edge_inds[1]
|
248 |
+
else:
|
249 |
+
head_inds = vertical_edge_inds[0]
|
250 |
+
tail_inds = vertical_edge_inds[1]
|
251 |
+
|
252 |
+
return head_inds, tail_inds
|
253 |
+
|
254 |
+
def resample_sidelines(self, sideline1, sideline2, resample_step):
|
255 |
+
"""Resample two sidelines to be of the same points number according to
|
256 |
+
step size.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
sideline1 (ndarray): The points composing a sideline of a text
|
260 |
+
polygon.
|
261 |
+
sideline2 (ndarray): The points composing another sideline of a
|
262 |
+
text polygon.
|
263 |
+
resample_step (float): The resampled step size.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
resampled_line1 (ndarray): The resampled line 1.
|
267 |
+
resampled_line2 (ndarray): The resampled line 2.
|
268 |
+
"""
|
269 |
+
|
270 |
+
assert sideline1.ndim == sideline2.ndim == 2
|
271 |
+
assert sideline1.shape[1] == sideline2.shape[1] == 2
|
272 |
+
assert sideline1.shape[0] >= 2
|
273 |
+
assert sideline2.shape[0] >= 2
|
274 |
+
assert isinstance(resample_step, float)
|
275 |
+
|
276 |
+
length1 = sum(
|
277 |
+
[norm(sideline1[i + 1] - sideline1[i]) for i in range(len(sideline1) - 1)]
|
278 |
+
)
|
279 |
+
length2 = sum(
|
280 |
+
[norm(sideline2[i + 1] - sideline2[i]) for i in range(len(sideline2) - 1)]
|
281 |
+
)
|
282 |
+
|
283 |
+
total_length = (length1 + length2) / 2
|
284 |
+
resample_point_num = max(int(float(total_length) / resample_step), 1)
|
285 |
+
|
286 |
+
resampled_line1 = self.resample_line(sideline1, resample_point_num)
|
287 |
+
resampled_line2 = self.resample_line(sideline2, resample_point_num)
|
288 |
+
|
289 |
+
return resampled_line1, resampled_line2
|
290 |
+
|
291 |
+
def generate_center_region_mask(self, img_size, text_polys):
|
292 |
+
"""Generate text center region mask.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
img_size (tuple): The image size of (height, width).
|
296 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
center_region_mask (ndarray): The text center region mask.
|
300 |
+
"""
|
301 |
+
|
302 |
+
assert isinstance(img_size, tuple)
|
303 |
+
# assert check_argument.is_2dlist(text_polys)
|
304 |
+
|
305 |
+
h, w = img_size
|
306 |
+
|
307 |
+
center_region_mask = np.zeros((h, w), np.uint8)
|
308 |
+
|
309 |
+
center_region_boxes = []
|
310 |
+
for poly in text_polys:
|
311 |
+
# assert len(poly) == 1
|
312 |
+
polygon_points = poly.reshape(-1, 2)
|
313 |
+
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
314 |
+
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
315 |
+
top_line, bot_line, self.resample_step
|
316 |
+
)
|
317 |
+
resampled_bot_line = resampled_bot_line[::-1]
|
318 |
+
center_line = (resampled_top_line + resampled_bot_line) / 2
|
319 |
+
|
320 |
+
line_head_shrink_len = (
|
321 |
+
norm(resampled_top_line[0] - resampled_bot_line[0]) / 4.0
|
322 |
+
)
|
323 |
+
line_tail_shrink_len = (
|
324 |
+
norm(resampled_top_line[-1] - resampled_bot_line[-1]) / 4.0
|
325 |
+
)
|
326 |
+
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
327 |
+
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
328 |
+
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
329 |
+
center_line = center_line[
|
330 |
+
head_shrink_num : len(center_line) - tail_shrink_num
|
331 |
+
]
|
332 |
+
resampled_top_line = resampled_top_line[
|
333 |
+
head_shrink_num : len(resampled_top_line) - tail_shrink_num
|
334 |
+
]
|
335 |
+
resampled_bot_line = resampled_bot_line[
|
336 |
+
head_shrink_num : len(resampled_bot_line) - tail_shrink_num
|
337 |
+
]
|
338 |
+
|
339 |
+
for i in range(0, len(center_line) - 1):
|
340 |
+
tl = (
|
341 |
+
center_line[i]
|
342 |
+
+ (resampled_top_line[i] - center_line[i])
|
343 |
+
* self.center_region_shrink_ratio
|
344 |
+
)
|
345 |
+
tr = (
|
346 |
+
center_line[i + 1]
|
347 |
+
+ (resampled_top_line[i + 1] - center_line[i + 1])
|
348 |
+
* self.center_region_shrink_ratio
|
349 |
+
)
|
350 |
+
br = (
|
351 |
+
center_line[i + 1]
|
352 |
+
+ (resampled_bot_line[i + 1] - center_line[i + 1])
|
353 |
+
* self.center_region_shrink_ratio
|
354 |
+
)
|
355 |
+
bl = (
|
356 |
+
center_line[i]
|
357 |
+
+ (resampled_bot_line[i] - center_line[i])
|
358 |
+
* self.center_region_shrink_ratio
|
359 |
+
)
|
360 |
+
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
|
361 |
+
center_region_boxes.append(current_center_box)
|
362 |
+
|
363 |
+
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
|
364 |
+
return center_region_mask
|
365 |
+
|
366 |
+
def resample_polygon(self, polygon, n=400):
|
367 |
+
"""Resample one polygon with n points on its boundary.
|
368 |
+
|
369 |
+
Args:
|
370 |
+
polygon (list[float]): The input polygon.
|
371 |
+
n (int): The number of resampled points.
|
372 |
+
Returns:
|
373 |
+
resampled_polygon (list[float]): The resampled polygon.
|
374 |
+
"""
|
375 |
+
length = []
|
376 |
+
|
377 |
+
for i in range(len(polygon)):
|
378 |
+
p1 = polygon[i]
|
379 |
+
if i == len(polygon) - 1:
|
380 |
+
p2 = polygon[0]
|
381 |
+
else:
|
382 |
+
p2 = polygon[i + 1]
|
383 |
+
length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5)
|
384 |
+
|
385 |
+
total_length = sum(length)
|
386 |
+
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
|
387 |
+
n_on_each_line = n_on_each_line.astype(np.int32)
|
388 |
+
new_polygon = []
|
389 |
+
|
390 |
+
for i in range(len(polygon)):
|
391 |
+
num = n_on_each_line[i]
|
392 |
+
p1 = polygon[i]
|
393 |
+
if i == len(polygon) - 1:
|
394 |
+
p2 = polygon[0]
|
395 |
+
else:
|
396 |
+
p2 = polygon[i + 1]
|
397 |
+
|
398 |
+
if num == 0:
|
399 |
+
continue
|
400 |
+
|
401 |
+
dxdy = (p2 - p1) / num
|
402 |
+
for j in range(num):
|
403 |
+
point = p1 + dxdy * j
|
404 |
+
new_polygon.append(point)
|
405 |
+
|
406 |
+
return np.array(new_polygon)
|
407 |
+
|
408 |
+
def normalize_polygon(self, polygon):
|
409 |
+
"""Normalize one polygon so that its start point is at right most.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
polygon (list[float]): The origin polygon.
|
413 |
+
Returns:
|
414 |
+
new_polygon (lost[float]): The polygon with start point at right.
|
415 |
+
"""
|
416 |
+
temp_polygon = polygon - polygon.mean(axis=0)
|
417 |
+
x = np.abs(temp_polygon[:, 0])
|
418 |
+
y = temp_polygon[:, 1]
|
419 |
+
index_x = np.argsort(x)
|
420 |
+
index_y = np.argmin(y[index_x[:8]])
|
421 |
+
index = index_x[index_y]
|
422 |
+
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
|
423 |
+
return new_polygon
|
424 |
+
|
425 |
+
def poly2fourier(self, polygon, fourier_degree):
|
426 |
+
"""Perform Fourier transformation to generate Fourier coefficients ck
|
427 |
+
from polygon.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
polygon (ndarray): An input polygon.
|
431 |
+
fourier_degree (int): The maximum Fourier degree K.
|
432 |
+
Returns:
|
433 |
+
c (ndarray(complex)): Fourier coefficients.
|
434 |
+
"""
|
435 |
+
points = polygon[:, 0] + polygon[:, 1] * 1j
|
436 |
+
c_fft = fft(points) / len(points)
|
437 |
+
c = np.hstack((c_fft[-fourier_degree:], c_fft[: fourier_degree + 1]))
|
438 |
+
return c
|
439 |
+
|
440 |
+
def clockwise(self, c, fourier_degree):
|
441 |
+
"""Make sure the polygon reconstructed from Fourier coefficients c in
|
442 |
+
the clockwise direction.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
polygon (list[float]): The origin polygon.
|
446 |
+
Returns:
|
447 |
+
new_polygon (lost[float]): The polygon in clockwise point order.
|
448 |
+
"""
|
449 |
+
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
|
450 |
+
return c
|
451 |
+
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
|
452 |
+
return c[::-1]
|
453 |
+
else:
|
454 |
+
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
|
455 |
+
return c
|
456 |
+
else:
|
457 |
+
return c[::-1]
|
458 |
+
|
459 |
+
def cal_fourier_signature(self, polygon, fourier_degree):
|
460 |
+
"""Calculate Fourier signature from input polygon.
|
461 |
+
|
462 |
+
Args:
|
463 |
+
polygon (ndarray): The input polygon.
|
464 |
+
fourier_degree (int): The maximum Fourier degree K.
|
465 |
+
Returns:
|
466 |
+
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
|
467 |
+
real part and image part of 2k+1 Fourier coefficients.
|
468 |
+
"""
|
469 |
+
resampled_polygon = self.resample_polygon(polygon)
|
470 |
+
resampled_polygon = self.normalize_polygon(resampled_polygon)
|
471 |
+
|
472 |
+
fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
|
473 |
+
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
|
474 |
+
|
475 |
+
real_part = np.real(fourier_coeff).reshape((-1, 1))
|
476 |
+
image_part = np.imag(fourier_coeff).reshape((-1, 1))
|
477 |
+
fourier_signature = np.hstack([real_part, image_part])
|
478 |
+
|
479 |
+
return fourier_signature
|
480 |
+
|
481 |
+
def generate_fourier_maps(self, img_size, text_polys):
|
482 |
+
"""Generate Fourier coefficient maps.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
img_size (tuple): The image size of (height, width).
|
486 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
487 |
+
|
488 |
+
Returns:
|
489 |
+
fourier_real_map (ndarray): The Fourier coefficient real part maps.
|
490 |
+
fourier_image_map (ndarray): The Fourier coefficient image part
|
491 |
+
maps.
|
492 |
+
"""
|
493 |
+
|
494 |
+
assert isinstance(img_size, tuple)
|
495 |
+
|
496 |
+
h, w = img_size
|
497 |
+
k = self.fourier_degree
|
498 |
+
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
499 |
+
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
500 |
+
|
501 |
+
for poly in text_polys:
|
502 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
503 |
+
polygon = np.array(poly).reshape((1, -1, 2))
|
504 |
+
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
|
505 |
+
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
|
506 |
+
for i in range(-k, k + 1):
|
507 |
+
if i != 0:
|
508 |
+
real_map[i + k, :, :] = (
|
509 |
+
mask * fourier_coeff[i + k, 0]
|
510 |
+
+ (1 - mask) * real_map[i + k, :, :]
|
511 |
+
)
|
512 |
+
imag_map[i + k, :, :] = (
|
513 |
+
mask * fourier_coeff[i + k, 1]
|
514 |
+
+ (1 - mask) * imag_map[i + k, :, :]
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
yx = np.argwhere(mask > 0.5)
|
518 |
+
k_ind = np.ones((len(yx)), dtype=np.int64) * k
|
519 |
+
y, x = yx[:, 0], yx[:, 1]
|
520 |
+
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
|
521 |
+
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
|
522 |
+
|
523 |
+
return real_map, imag_map
|
524 |
+
|
525 |
+
def generate_text_region_mask(self, img_size, text_polys):
|
526 |
+
"""Generate text center region mask and geometry attribute maps.
|
527 |
+
|
528 |
+
Args:
|
529 |
+
img_size (tuple): The image size (height, width).
|
530 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
531 |
+
|
532 |
+
Returns:
|
533 |
+
text_region_mask (ndarray): The text region mask.
|
534 |
+
"""
|
535 |
+
|
536 |
+
assert isinstance(img_size, tuple)
|
537 |
+
|
538 |
+
h, w = img_size
|
539 |
+
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
540 |
+
|
541 |
+
for poly in text_polys:
|
542 |
+
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
|
543 |
+
cv2.fillPoly(text_region_mask, polygon, 1)
|
544 |
+
|
545 |
+
return text_region_mask
|
546 |
+
|
547 |
+
def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
|
548 |
+
"""Generate effective mask by setting the ineffective regions to 0 and
|
549 |
+
effective regions to 1.
|
550 |
+
|
551 |
+
Args:
|
552 |
+
mask_size (tuple): The mask size.
|
553 |
+
polygons_ignore (list[[ndarray]]: The list of ignored text
|
554 |
+
polygons.
|
555 |
+
|
556 |
+
Returns:
|
557 |
+
mask (ndarray): The effective mask of (height, width).
|
558 |
+
"""
|
559 |
+
|
560 |
+
mask = np.ones(mask_size, dtype=np.uint8)
|
561 |
+
|
562 |
+
for poly in polygons_ignore:
|
563 |
+
instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
|
564 |
+
cv2.fillPoly(mask, instance, 0)
|
565 |
+
|
566 |
+
return mask
|
567 |
+
|
568 |
+
def generate_level_targets(self, img_size, text_polys, ignore_polys):
|
569 |
+
"""Generate ground truth target on each level.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
img_size (list[int]): Shape of input image.
|
573 |
+
text_polys (list[list[ndarray]]): A list of ground truth polygons.
|
574 |
+
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
|
575 |
+
Returns:
|
576 |
+
level_maps (list(ndarray)): A list of ground target on each level.
|
577 |
+
"""
|
578 |
+
h, w = img_size
|
579 |
+
lv_size_divs = self.level_size_divisors
|
580 |
+
lv_proportion_range = self.level_proportion_range
|
581 |
+
lv_text_polys = [[] for i in range(len(lv_size_divs))]
|
582 |
+
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
|
583 |
+
level_maps = []
|
584 |
+
for poly in text_polys:
|
585 |
+
polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
|
586 |
+
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
587 |
+
proportion = max(box_h, box_w) / (h + 1e-8)
|
588 |
+
|
589 |
+
for ind, proportion_range in enumerate(lv_proportion_range):
|
590 |
+
if proportion_range[0] < proportion < proportion_range[1]:
|
591 |
+
lv_text_polys[ind].append(poly / lv_size_divs[ind])
|
592 |
+
|
593 |
+
for ignore_poly in ignore_polys:
|
594 |
+
polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
|
595 |
+
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
596 |
+
proportion = max(box_h, box_w) / (h + 1e-8)
|
597 |
+
|
598 |
+
for ind, proportion_range in enumerate(lv_proportion_range):
|
599 |
+
if proportion_range[0] < proportion < proportion_range[1]:
|
600 |
+
lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
|
601 |
+
|
602 |
+
for ind, size_divisor in enumerate(lv_size_divs):
|
603 |
+
current_level_maps = []
|
604 |
+
level_img_size = (h // size_divisor, w // size_divisor)
|
605 |
+
|
606 |
+
text_region = self.generate_text_region_mask(
|
607 |
+
level_img_size, lv_text_polys[ind]
|
608 |
+
)[None]
|
609 |
+
current_level_maps.append(text_region)
|
610 |
+
|
611 |
+
center_region = self.generate_center_region_mask(
|
612 |
+
level_img_size, lv_text_polys[ind]
|
613 |
+
)[None]
|
614 |
+
current_level_maps.append(center_region)
|
615 |
+
|
616 |
+
effective_mask = self.generate_effective_mask(
|
617 |
+
level_img_size, lv_ignore_polys[ind]
|
618 |
+
)[None]
|
619 |
+
current_level_maps.append(effective_mask)
|
620 |
+
|
621 |
+
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
|
622 |
+
level_img_size, lv_text_polys[ind]
|
623 |
+
)
|
624 |
+
current_level_maps.append(fourier_real_map)
|
625 |
+
current_level_maps.append(fourier_image_maps)
|
626 |
+
|
627 |
+
level_maps.append(np.concatenate(current_level_maps))
|
628 |
+
|
629 |
+
return level_maps
|
630 |
+
|
631 |
+
def generate_targets(self, results):
|
632 |
+
"""Generate the ground truth targets for FCENet.
|
633 |
+
|
634 |
+
Args:
|
635 |
+
results (dict): The input result dictionary.
|
636 |
+
|
637 |
+
Returns:
|
638 |
+
results (dict): The output result dictionary.
|
639 |
+
"""
|
640 |
+
|
641 |
+
assert isinstance(results, dict)
|
642 |
+
image = results["image"]
|
643 |
+
polygons = results["polys"]
|
644 |
+
ignore_tags = results["ignore_tags"]
|
645 |
+
h, w, _ = image.shape
|
646 |
+
|
647 |
+
polygon_masks = []
|
648 |
+
polygon_masks_ignore = []
|
649 |
+
for tag, polygon in zip(ignore_tags, polygons):
|
650 |
+
if tag is True:
|
651 |
+
polygon_masks_ignore.append(polygon)
|
652 |
+
else:
|
653 |
+
polygon_masks.append(polygon)
|
654 |
+
|
655 |
+
level_maps = self.generate_level_targets(
|
656 |
+
(h, w), polygon_masks, polygon_masks_ignore
|
657 |
+
)
|
658 |
+
|
659 |
+
mapping = {
|
660 |
+
"p3_maps": level_maps[0],
|
661 |
+
"p4_maps": level_maps[1],
|
662 |
+
"p5_maps": level_maps[2],
|
663 |
+
}
|
664 |
+
for key, value in mapping.items():
|
665 |
+
results[key] = value
|
666 |
+
|
667 |
+
return results
|
668 |
+
|
669 |
+
def __call__(self, results):
|
670 |
+
results = self.generate_targets(results)
|
671 |
+
return results
|
ocr/ppocr/data/imaug/gen_table_mask.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class GenTableMask(object):
|
8 |
+
"""gen table mask"""
|
9 |
+
|
10 |
+
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
11 |
+
self.shrink_h_max = 5
|
12 |
+
self.shrink_w_max = 5
|
13 |
+
self.mask_type = mask_type
|
14 |
+
|
15 |
+
def projection(self, erosion, h, w, spilt_threshold=0):
|
16 |
+
# 水平投影
|
17 |
+
projection_map = np.ones_like(erosion)
|
18 |
+
project_val_array = [0 for _ in range(0, h)]
|
19 |
+
|
20 |
+
for j in range(0, h):
|
21 |
+
for i in range(0, w):
|
22 |
+
if erosion[j, i] == 255:
|
23 |
+
project_val_array[j] += 1
|
24 |
+
# 根据数组,获取切割点
|
25 |
+
start_idx = 0 # 记录进入字符区的索引
|
26 |
+
end_idx = 0 # 记录进入空白区域的索引
|
27 |
+
in_text = False # 是否遍历到了字符区内
|
28 |
+
box_list = []
|
29 |
+
for i in range(len(project_val_array)):
|
30 |
+
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
31 |
+
in_text = True
|
32 |
+
start_idx = i
|
33 |
+
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
34 |
+
end_idx = i
|
35 |
+
in_text = False
|
36 |
+
if end_idx - start_idx <= 2:
|
37 |
+
continue
|
38 |
+
box_list.append((start_idx, end_idx + 1))
|
39 |
+
|
40 |
+
if in_text:
|
41 |
+
box_list.append((start_idx, h - 1))
|
42 |
+
# 绘制投影直方图
|
43 |
+
for j in range(0, h):
|
44 |
+
for i in range(0, project_val_array[j]):
|
45 |
+
projection_map[j, i] = 0
|
46 |
+
return box_list, projection_map
|
47 |
+
|
48 |
+
def projection_cx(self, box_img):
|
49 |
+
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
50 |
+
h, w = box_gray_img.shape
|
51 |
+
# 灰度图片进行二值化处理
|
52 |
+
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
|
53 |
+
# 纵向腐蚀
|
54 |
+
if h < w:
|
55 |
+
kernel = np.ones((2, 1), np.uint8)
|
56 |
+
erode = cv2.erode(thresh1, kernel, iterations=1)
|
57 |
+
else:
|
58 |
+
erode = thresh1
|
59 |
+
# 水平膨胀
|
60 |
+
kernel = np.ones((1, 5), np.uint8)
|
61 |
+
erosion = cv2.dilate(erode, kernel, iterations=1)
|
62 |
+
# 水平投影
|
63 |
+
projection_map = np.ones_like(erosion)
|
64 |
+
project_val_array = [0 for _ in range(0, h)]
|
65 |
+
|
66 |
+
for j in range(0, h):
|
67 |
+
for i in range(0, w):
|
68 |
+
if erosion[j, i] == 255:
|
69 |
+
project_val_array[j] += 1
|
70 |
+
# 根据数组,获取切割点
|
71 |
+
start_idx = 0 # 记录进入字符区的索引
|
72 |
+
end_idx = 0 # 记录进入空白区域的索引
|
73 |
+
in_text = False # 是否遍历到了字符区内
|
74 |
+
box_list = []
|
75 |
+
spilt_threshold = 0
|
76 |
+
for i in range(len(project_val_array)):
|
77 |
+
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
78 |
+
in_text = True
|
79 |
+
start_idx = i
|
80 |
+
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
81 |
+
end_idx = i
|
82 |
+
in_text = False
|
83 |
+
if end_idx - start_idx <= 2:
|
84 |
+
continue
|
85 |
+
box_list.append((start_idx, end_idx + 1))
|
86 |
+
|
87 |
+
if in_text:
|
88 |
+
box_list.append((start_idx, h - 1))
|
89 |
+
# 绘制投影直方图
|
90 |
+
for j in range(0, h):
|
91 |
+
for i in range(0, project_val_array[j]):
|
92 |
+
projection_map[j, i] = 0
|
93 |
+
split_bbox_list = []
|
94 |
+
if len(box_list) > 1:
|
95 |
+
for i, (h_start, h_end) in enumerate(box_list):
|
96 |
+
if i == 0:
|
97 |
+
h_start = 0
|
98 |
+
if i == len(box_list):
|
99 |
+
h_end = h
|
100 |
+
word_img = erosion[h_start : h_end + 1, :]
|
101 |
+
word_h, word_w = word_img.shape
|
102 |
+
w_split_list, w_projection_map = self.projection(
|
103 |
+
word_img.T, word_w, word_h
|
104 |
+
)
|
105 |
+
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
106 |
+
if h_start > 0:
|
107 |
+
h_start -= 1
|
108 |
+
h_end += 1
|
109 |
+
word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :]
|
110 |
+
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
111 |
+
else:
|
112 |
+
split_bbox_list.append([0, 0, w, h])
|
113 |
+
return split_bbox_list
|
114 |
+
|
115 |
+
def shrink_bbox(self, bbox):
|
116 |
+
left, top, right, bottom = bbox
|
117 |
+
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
118 |
+
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
119 |
+
left_new = left + sh_w
|
120 |
+
right_new = right - sh_w
|
121 |
+
top_new = top + sh_h
|
122 |
+
bottom_new = bottom - sh_h
|
123 |
+
if left_new >= right_new:
|
124 |
+
left_new = left
|
125 |
+
right_new = right
|
126 |
+
if top_new >= bottom_new:
|
127 |
+
top_new = top
|
128 |
+
bottom_new = bottom
|
129 |
+
return [left_new, top_new, right_new, bottom_new]
|
130 |
+
|
131 |
+
def __call__(self, data):
|
132 |
+
img = data["image"]
|
133 |
+
cells = data["cells"]
|
134 |
+
height, width = img.shape[0:2]
|
135 |
+
if self.mask_type == 1:
|
136 |
+
mask_img = np.zeros((height, width), dtype=np.float32)
|
137 |
+
else:
|
138 |
+
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
139 |
+
cell_num = len(cells)
|
140 |
+
for cno in range(cell_num):
|
141 |
+
if "bbox" in cells[cno]:
|
142 |
+
bbox = cells[cno]["bbox"]
|
143 |
+
left, top, right, bottom = bbox
|
144 |
+
box_img = img[top:bottom, left:right, :].copy()
|
145 |
+
split_bbox_list = self.projection_cx(box_img)
|
146 |
+
for sno in range(len(split_bbox_list)):
|
147 |
+
split_bbox_list[sno][0] += left
|
148 |
+
split_bbox_list[sno][1] += top
|
149 |
+
split_bbox_list[sno][2] += left
|
150 |
+
split_bbox_list[sno][3] += top
|
151 |
+
|
152 |
+
for sno in range(len(split_bbox_list)):
|
153 |
+
left, top, right, bottom = split_bbox_list[sno]
|
154 |
+
left, top, right, bottom = self.shrink_bbox(
|
155 |
+
[left, top, right, bottom]
|
156 |
+
)
|
157 |
+
if self.mask_type == 1:
|
158 |
+
mask_img[top:bottom, left:right] = 1.0
|
159 |
+
data["mask_img"] = mask_img
|
160 |
+
else:
|
161 |
+
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
162 |
+
data["image"] = mask_img
|
163 |
+
return data
|
164 |
+
|
165 |
+
|
166 |
+
class ResizeTableImage(object):
|
167 |
+
def __init__(self, max_len, **kwargs):
|
168 |
+
super(ResizeTableImage, self).__init__()
|
169 |
+
self.max_len = max_len
|
170 |
+
|
171 |
+
def get_img_bbox(self, cells):
|
172 |
+
bbox_list = []
|
173 |
+
if len(cells) == 0:
|
174 |
+
return bbox_list
|
175 |
+
cell_num = len(cells)
|
176 |
+
for cno in range(cell_num):
|
177 |
+
if "bbox" in cells[cno]:
|
178 |
+
bbox = cells[cno]["bbox"]
|
179 |
+
bbox_list.append(bbox)
|
180 |
+
return bbox_list
|
181 |
+
|
182 |
+
def resize_img_table(self, img, bbox_list, max_len):
|
183 |
+
height, width = img.shape[0:2]
|
184 |
+
ratio = max_len / (max(height, width) * 1.0)
|
185 |
+
resize_h = int(height * ratio)
|
186 |
+
resize_w = int(width * ratio)
|
187 |
+
img_new = cv2.resize(img, (resize_w, resize_h))
|
188 |
+
bbox_list_new = []
|
189 |
+
for bno in range(len(bbox_list)):
|
190 |
+
left, top, right, bottom = bbox_list[bno].copy()
|
191 |
+
left = int(left * ratio)
|
192 |
+
top = int(top * ratio)
|
193 |
+
right = int(right * ratio)
|
194 |
+
bottom = int(bottom * ratio)
|
195 |
+
bbox_list_new.append([left, top, right, bottom])
|
196 |
+
return img_new, bbox_list_new
|
197 |
+
|
198 |
+
def __call__(self, data):
|
199 |
+
img = data["image"]
|
200 |
+
if "cells" not in data:
|
201 |
+
cells = []
|
202 |
+
else:
|
203 |
+
cells = data["cells"]
|
204 |
+
bbox_list = self.get_img_bbox(cells)
|
205 |
+
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
|
206 |
+
data["image"] = img_new
|
207 |
+
cell_num = len(cells)
|
208 |
+
bno = 0
|
209 |
+
for cno in range(cell_num):
|
210 |
+
if "bbox" in data["cells"][cno]:
|
211 |
+
data["cells"][cno]["bbox"] = bbox_list_new[bno]
|
212 |
+
bno += 1
|
213 |
+
data["max_len"] = self.max_len
|
214 |
+
return data
|
215 |
+
|
216 |
+
|
217 |
+
class PaddingTableImage(object):
|
218 |
+
def __init__(self, **kwargs):
|
219 |
+
super(PaddingTableImage, self).__init__()
|
220 |
+
|
221 |
+
def __call__(self, data):
|
222 |
+
img = data["image"]
|
223 |
+
max_len = data["max_len"]
|
224 |
+
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
|
225 |
+
height, width = img.shape[0:2]
|
226 |
+
padding_img[0:height, 0:width, :] = img.copy()
|
227 |
+
data["image"] = padding_img
|
228 |
+
return data
|
ocr/ppocr/data/imaug/iaa_augment.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import imgaug
|
4 |
+
import imgaug.augmenters as iaa
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class AugmenterBuilder(object):
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def build(self, args, root=True):
|
13 |
+
if args is None or len(args) == 0:
|
14 |
+
return None
|
15 |
+
elif isinstance(args, list):
|
16 |
+
if root:
|
17 |
+
sequence = [self.build(value, root=False) for value in args]
|
18 |
+
return iaa.Sequential(sequence)
|
19 |
+
else:
|
20 |
+
return getattr(iaa, args[0])(
|
21 |
+
*[self.to_tuple_if_list(a) for a in args[1:]]
|
22 |
+
)
|
23 |
+
elif isinstance(args, dict):
|
24 |
+
cls = getattr(iaa, args["type"])
|
25 |
+
return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
|
26 |
+
else:
|
27 |
+
raise RuntimeError("unknown augmenter arg: " + str(args))
|
28 |
+
|
29 |
+
def to_tuple_if_list(self, obj):
|
30 |
+
if isinstance(obj, list):
|
31 |
+
return tuple(obj)
|
32 |
+
return obj
|
33 |
+
|
34 |
+
|
35 |
+
class IaaAugment:
|
36 |
+
def __init__(self, augmenter_args=None, **kwargs):
|
37 |
+
if augmenter_args is None:
|
38 |
+
augmenter_args = [
|
39 |
+
{"type": "Fliplr", "args": {"p": 0.5}},
|
40 |
+
{"type": "Affine", "args": {"rotate": [-10, 10]}},
|
41 |
+
{"type": "Resize", "args": {"size": [0.5, 3]}},
|
42 |
+
]
|
43 |
+
self.augmenter = AugmenterBuilder().build(augmenter_args)
|
44 |
+
|
45 |
+
def __call__(self, data):
|
46 |
+
image = data["image"]
|
47 |
+
shape = image.shape
|
48 |
+
|
49 |
+
if self.augmenter:
|
50 |
+
aug = self.augmenter.to_deterministic()
|
51 |
+
data["image"] = aug.augment_image(image)
|
52 |
+
data = self.may_augment_annotation(aug, data, shape)
|
53 |
+
return data
|
54 |
+
|
55 |
+
def may_augment_annotation(self, aug, data, shape):
|
56 |
+
if aug is None:
|
57 |
+
return data
|
58 |
+
|
59 |
+
line_polys = []
|
60 |
+
for poly in data["polys"]:
|
61 |
+
new_poly = self.may_augment_poly(aug, shape, poly)
|
62 |
+
line_polys.append(new_poly)
|
63 |
+
data["polys"] = np.array(line_polys)
|
64 |
+
return data
|
65 |
+
|
66 |
+
def may_augment_poly(self, aug, img_shape, poly):
|
67 |
+
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
|
68 |
+
keypoints = aug.augment_keypoints(
|
69 |
+
[imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
|
70 |
+
)[0].keypoints
|
71 |
+
poly = [(p.x, p.y) for p in keypoints]
|
72 |
+
return poly
|
ocr/ppocr/data/imaug/label_ops.py
ADDED
@@ -0,0 +1,1046 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from shapely.geometry import LineString, Point, Polygon
|
8 |
+
|
9 |
+
|
10 |
+
class ClsLabelEncode(object):
|
11 |
+
def __init__(self, label_list, **kwargs):
|
12 |
+
self.label_list = label_list
|
13 |
+
|
14 |
+
def __call__(self, data):
|
15 |
+
label = data["label"]
|
16 |
+
if label not in self.label_list:
|
17 |
+
return None
|
18 |
+
label = self.label_list.index(label)
|
19 |
+
data["label"] = label
|
20 |
+
return data
|
21 |
+
|
22 |
+
|
23 |
+
class DetLabelEncode(object):
|
24 |
+
def __init__(self, **kwargs):
|
25 |
+
pass
|
26 |
+
|
27 |
+
def __call__(self, data):
|
28 |
+
label = data["label"]
|
29 |
+
label = json.loads(label)
|
30 |
+
nBox = len(label)
|
31 |
+
boxes, txts, txt_tags = [], [], []
|
32 |
+
for bno in range(0, nBox):
|
33 |
+
box = label[bno]["points"]
|
34 |
+
txt = label[bno]["transcription"]
|
35 |
+
boxes.append(box)
|
36 |
+
txts.append(txt)
|
37 |
+
if txt in ["*", "###"]:
|
38 |
+
txt_tags.append(True)
|
39 |
+
else:
|
40 |
+
txt_tags.append(False)
|
41 |
+
if len(boxes) == 0:
|
42 |
+
return None
|
43 |
+
boxes = self.expand_points_num(boxes)
|
44 |
+
boxes = np.array(boxes, dtype=np.float32)
|
45 |
+
txt_tags = np.array(txt_tags, dtype=np.bool)
|
46 |
+
|
47 |
+
data["polys"] = boxes
|
48 |
+
data["texts"] = txts
|
49 |
+
data["ignore_tags"] = txt_tags
|
50 |
+
return data
|
51 |
+
|
52 |
+
def order_points_clockwise(self, pts):
|
53 |
+
rect = np.zeros((4, 2), dtype="float32")
|
54 |
+
s = pts.sum(axis=1)
|
55 |
+
rect[0] = pts[np.argmin(s)]
|
56 |
+
rect[2] = pts[np.argmax(s)]
|
57 |
+
diff = np.diff(pts, axis=1)
|
58 |
+
rect[1] = pts[np.argmin(diff)]
|
59 |
+
rect[3] = pts[np.argmax(diff)]
|
60 |
+
return rect
|
61 |
+
|
62 |
+
def expand_points_num(self, boxes):
|
63 |
+
max_points_num = 0
|
64 |
+
for box in boxes:
|
65 |
+
if len(box) > max_points_num:
|
66 |
+
max_points_num = len(box)
|
67 |
+
ex_boxes = []
|
68 |
+
for box in boxes:
|
69 |
+
ex_box = box + [box[-1]] * (max_points_num - len(box))
|
70 |
+
ex_boxes.append(ex_box)
|
71 |
+
return ex_boxes
|
72 |
+
|
73 |
+
|
74 |
+
class BaseRecLabelEncode(object):
|
75 |
+
"""Convert between text-label and text-index"""
|
76 |
+
|
77 |
+
def __init__(self, max_text_length, character_dict_path=None, use_space_char=False):
|
78 |
+
|
79 |
+
self.max_text_len = max_text_length
|
80 |
+
self.beg_str = "sos"
|
81 |
+
self.end_str = "eos"
|
82 |
+
self.lower = False
|
83 |
+
|
84 |
+
if character_dict_path is None:
|
85 |
+
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
86 |
+
dict_character = list(self.character_str)
|
87 |
+
self.lower = True
|
88 |
+
else:
|
89 |
+
self.character_str = []
|
90 |
+
with open(character_dict_path, "rb") as fin:
|
91 |
+
lines = fin.readlines()
|
92 |
+
for line in lines:
|
93 |
+
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
94 |
+
self.character_str.append(line)
|
95 |
+
if use_space_char:
|
96 |
+
self.character_str.append(" ")
|
97 |
+
dict_character = list(self.character_str)
|
98 |
+
dict_character = self.add_special_char(dict_character)
|
99 |
+
self.dict = {}
|
100 |
+
for i, char in enumerate(dict_character):
|
101 |
+
self.dict[char] = i
|
102 |
+
self.character = dict_character
|
103 |
+
|
104 |
+
def add_special_char(self, dict_character):
|
105 |
+
return dict_character
|
106 |
+
|
107 |
+
def encode(self, text):
|
108 |
+
"""convert text-label into text-index.
|
109 |
+
input:
|
110 |
+
text: text labels of each image. [batch_size]
|
111 |
+
|
112 |
+
output:
|
113 |
+
text: concatenated text index for CTCLoss.
|
114 |
+
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
115 |
+
length: length of each text. [batch_size]
|
116 |
+
"""
|
117 |
+
if len(text) == 0 or len(text) > self.max_text_len:
|
118 |
+
return None
|
119 |
+
if self.lower:
|
120 |
+
text = text.lower()
|
121 |
+
text_list = []
|
122 |
+
for char in text:
|
123 |
+
if char not in self.dict:
|
124 |
+
continue
|
125 |
+
text_list.append(self.dict[char])
|
126 |
+
if len(text_list) == 0:
|
127 |
+
return None
|
128 |
+
return text_list
|
129 |
+
|
130 |
+
|
131 |
+
class NRTRLabelEncode(BaseRecLabelEncode):
|
132 |
+
"""Convert between text-label and text-index"""
|
133 |
+
|
134 |
+
def __init__(
|
135 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
136 |
+
):
|
137 |
+
|
138 |
+
super(NRTRLabelEncode, self).__init__(
|
139 |
+
max_text_length, character_dict_path, use_space_char
|
140 |
+
)
|
141 |
+
|
142 |
+
def __call__(self, data):
|
143 |
+
text = data["label"]
|
144 |
+
text = self.encode(text)
|
145 |
+
if text is None:
|
146 |
+
return None
|
147 |
+
if len(text) >= self.max_text_len - 1:
|
148 |
+
return None
|
149 |
+
data["length"] = np.array(len(text))
|
150 |
+
text.insert(0, 2)
|
151 |
+
text.append(3)
|
152 |
+
text = text + [0] * (self.max_text_len - len(text))
|
153 |
+
data["label"] = np.array(text)
|
154 |
+
return data
|
155 |
+
|
156 |
+
def add_special_char(self, dict_character):
|
157 |
+
dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
|
158 |
+
return dict_character
|
159 |
+
|
160 |
+
|
161 |
+
class CTCLabelEncode(BaseRecLabelEncode):
|
162 |
+
"""Convert between text-label and text-index"""
|
163 |
+
|
164 |
+
def __init__(
|
165 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
166 |
+
):
|
167 |
+
super(CTCLabelEncode, self).__init__(
|
168 |
+
max_text_length, character_dict_path, use_space_char
|
169 |
+
)
|
170 |
+
|
171 |
+
def __call__(self, data):
|
172 |
+
text = data["label"]
|
173 |
+
text = self.encode(text)
|
174 |
+
if text is None:
|
175 |
+
return None
|
176 |
+
data["length"] = np.array(len(text))
|
177 |
+
text = text + [0] * (self.max_text_len - len(text))
|
178 |
+
data["label"] = np.array(text)
|
179 |
+
|
180 |
+
label = [0] * len(self.character)
|
181 |
+
for x in text:
|
182 |
+
label[x] += 1
|
183 |
+
data["label_ace"] = np.array(label)
|
184 |
+
return data
|
185 |
+
|
186 |
+
def add_special_char(self, dict_character):
|
187 |
+
dict_character = ["blank"] + dict_character
|
188 |
+
return dict_character
|
189 |
+
|
190 |
+
|
191 |
+
class E2ELabelEncodeTest(BaseRecLabelEncode):
|
192 |
+
def __init__(
|
193 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
194 |
+
):
|
195 |
+
super(E2ELabelEncodeTest, self).__init__(
|
196 |
+
max_text_length, character_dict_path, use_space_char
|
197 |
+
)
|
198 |
+
|
199 |
+
def __call__(self, data):
|
200 |
+
import json
|
201 |
+
|
202 |
+
padnum = len(self.dict)
|
203 |
+
label = data["label"]
|
204 |
+
label = json.loads(label)
|
205 |
+
nBox = len(label)
|
206 |
+
boxes, txts, txt_tags = [], [], []
|
207 |
+
for bno in range(0, nBox):
|
208 |
+
box = label[bno]["points"]
|
209 |
+
txt = label[bno]["transcription"]
|
210 |
+
boxes.append(box)
|
211 |
+
txts.append(txt)
|
212 |
+
if txt in ["*", "###"]:
|
213 |
+
txt_tags.append(True)
|
214 |
+
else:
|
215 |
+
txt_tags.append(False)
|
216 |
+
boxes = np.array(boxes, dtype=np.float32)
|
217 |
+
txt_tags = np.array(txt_tags, dtype=np.bool)
|
218 |
+
data["polys"] = boxes
|
219 |
+
data["ignore_tags"] = txt_tags
|
220 |
+
temp_texts = []
|
221 |
+
for text in txts:
|
222 |
+
text = text.lower()
|
223 |
+
text = self.encode(text)
|
224 |
+
if text is None:
|
225 |
+
return None
|
226 |
+
text = text + [padnum] * (self.max_text_len - len(text)) # use 36 to pad
|
227 |
+
temp_texts.append(text)
|
228 |
+
data["texts"] = np.array(temp_texts)
|
229 |
+
return data
|
230 |
+
|
231 |
+
|
232 |
+
class E2ELabelEncodeTrain(object):
|
233 |
+
def __init__(self, **kwargs):
|
234 |
+
pass
|
235 |
+
|
236 |
+
def __call__(self, data):
|
237 |
+
import json
|
238 |
+
|
239 |
+
label = data["label"]
|
240 |
+
label = json.loads(label)
|
241 |
+
nBox = len(label)
|
242 |
+
boxes, txts, txt_tags = [], [], []
|
243 |
+
for bno in range(0, nBox):
|
244 |
+
box = label[bno]["points"]
|
245 |
+
txt = label[bno]["transcription"]
|
246 |
+
boxes.append(box)
|
247 |
+
txts.append(txt)
|
248 |
+
if txt in ["*", "###"]:
|
249 |
+
txt_tags.append(True)
|
250 |
+
else:
|
251 |
+
txt_tags.append(False)
|
252 |
+
boxes = np.array(boxes, dtype=np.float32)
|
253 |
+
txt_tags = np.array(txt_tags, dtype=np.bool)
|
254 |
+
|
255 |
+
data["polys"] = boxes
|
256 |
+
data["texts"] = txts
|
257 |
+
data["ignore_tags"] = txt_tags
|
258 |
+
return data
|
259 |
+
|
260 |
+
|
261 |
+
class KieLabelEncode(object):
|
262 |
+
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
|
263 |
+
super(KieLabelEncode, self).__init__()
|
264 |
+
self.dict = dict({"": 0})
|
265 |
+
with open(character_dict_path, "r", encoding="utf-8") as fr:
|
266 |
+
idx = 1
|
267 |
+
for line in fr:
|
268 |
+
char = line.strip()
|
269 |
+
self.dict[char] = idx
|
270 |
+
idx += 1
|
271 |
+
self.norm = norm
|
272 |
+
self.directed = directed
|
273 |
+
|
274 |
+
def compute_relation(self, boxes):
|
275 |
+
"""Compute relation between every two boxes."""
|
276 |
+
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
|
277 |
+
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
|
278 |
+
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
|
279 |
+
dxs = (x1s[:, 0][None] - x1s) / self.norm
|
280 |
+
dys = (y1s[:, 0][None] - y1s) / self.norm
|
281 |
+
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
|
282 |
+
whs = ws / hs + np.zeros_like(xhhs)
|
283 |
+
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
|
284 |
+
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
|
285 |
+
return relations, bboxes
|
286 |
+
|
287 |
+
def pad_text_indices(self, text_inds):
|
288 |
+
"""Pad text index to same length."""
|
289 |
+
max_len = 300
|
290 |
+
recoder_len = max([len(text_ind) for text_ind in text_inds])
|
291 |
+
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
|
292 |
+
for idx, text_ind in enumerate(text_inds):
|
293 |
+
padded_text_inds[idx, : len(text_ind)] = np.array(text_ind)
|
294 |
+
return padded_text_inds, recoder_len
|
295 |
+
|
296 |
+
def list_to_numpy(self, ann_infos):
|
297 |
+
"""Convert bboxes, relations, texts and labels to ndarray."""
|
298 |
+
boxes, text_inds = ann_infos["points"], ann_infos["text_inds"]
|
299 |
+
boxes = np.array(boxes, np.int32)
|
300 |
+
relations, bboxes = self.compute_relation(boxes)
|
301 |
+
|
302 |
+
labels = ann_infos.get("labels", None)
|
303 |
+
if labels is not None:
|
304 |
+
labels = np.array(labels, np.int32)
|
305 |
+
edges = ann_infos.get("edges", None)
|
306 |
+
if edges is not None:
|
307 |
+
labels = labels[:, None]
|
308 |
+
edges = np.array(edges)
|
309 |
+
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
|
310 |
+
if self.directed:
|
311 |
+
edges = (edges & labels == 1).astype(np.int32)
|
312 |
+
np.fill_diagonal(edges, -1)
|
313 |
+
labels = np.concatenate([labels, edges], -1)
|
314 |
+
padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
|
315 |
+
max_num = 300
|
316 |
+
temp_bboxes = np.zeros([max_num, 4])
|
317 |
+
h, _ = bboxes.shape
|
318 |
+
temp_bboxes[:h, :] = bboxes
|
319 |
+
|
320 |
+
temp_relations = np.zeros([max_num, max_num, 5])
|
321 |
+
temp_relations[:h, :h, :] = relations
|
322 |
+
|
323 |
+
temp_padded_text_inds = np.zeros([max_num, max_num])
|
324 |
+
temp_padded_text_inds[:h, :] = padded_text_inds
|
325 |
+
|
326 |
+
temp_labels = np.zeros([max_num, max_num])
|
327 |
+
temp_labels[:h, : h + 1] = labels
|
328 |
+
|
329 |
+
tag = np.array([h, recoder_len])
|
330 |
+
return dict(
|
331 |
+
image=ann_infos["image"],
|
332 |
+
points=temp_bboxes,
|
333 |
+
relations=temp_relations,
|
334 |
+
texts=temp_padded_text_inds,
|
335 |
+
labels=temp_labels,
|
336 |
+
tag=tag,
|
337 |
+
)
|
338 |
+
|
339 |
+
def convert_canonical(self, points_x, points_y):
|
340 |
+
|
341 |
+
assert len(points_x) == 4
|
342 |
+
assert len(points_y) == 4
|
343 |
+
|
344 |
+
points = [Point(points_x[i], points_y[i]) for i in range(4)]
|
345 |
+
|
346 |
+
polygon = Polygon([(p.x, p.y) for p in points])
|
347 |
+
min_x, min_y, _, _ = polygon.bounds
|
348 |
+
points_to_lefttop = [
|
349 |
+
LineString([points[i], Point(min_x, min_y)]) for i in range(4)
|
350 |
+
]
|
351 |
+
distances = np.array([line.length for line in points_to_lefttop])
|
352 |
+
sort_dist_idx = np.argsort(distances)
|
353 |
+
lefttop_idx = sort_dist_idx[0]
|
354 |
+
|
355 |
+
if lefttop_idx == 0:
|
356 |
+
point_orders = [0, 1, 2, 3]
|
357 |
+
elif lefttop_idx == 1:
|
358 |
+
point_orders = [1, 2, 3, 0]
|
359 |
+
elif lefttop_idx == 2:
|
360 |
+
point_orders = [2, 3, 0, 1]
|
361 |
+
else:
|
362 |
+
point_orders = [3, 0, 1, 2]
|
363 |
+
|
364 |
+
sorted_points_x = [points_x[i] for i in point_orders]
|
365 |
+
sorted_points_y = [points_y[j] for j in point_orders]
|
366 |
+
|
367 |
+
return sorted_points_x, sorted_points_y
|
368 |
+
|
369 |
+
def sort_vertex(self, points_x, points_y):
|
370 |
+
|
371 |
+
assert len(points_x) == 4
|
372 |
+
assert len(points_y) == 4
|
373 |
+
|
374 |
+
x = np.array(points_x)
|
375 |
+
y = np.array(points_y)
|
376 |
+
center_x = np.sum(x) * 0.25
|
377 |
+
center_y = np.sum(y) * 0.25
|
378 |
+
|
379 |
+
x_arr = np.array(x - center_x)
|
380 |
+
y_arr = np.array(y - center_y)
|
381 |
+
|
382 |
+
angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
|
383 |
+
sort_idx = np.argsort(angle)
|
384 |
+
|
385 |
+
sorted_points_x, sorted_points_y = [], []
|
386 |
+
for i in range(4):
|
387 |
+
sorted_points_x.append(points_x[sort_idx[i]])
|
388 |
+
sorted_points_y.append(points_y[sort_idx[i]])
|
389 |
+
|
390 |
+
return self.convert_canonical(sorted_points_x, sorted_points_y)
|
391 |
+
|
392 |
+
def __call__(self, data):
|
393 |
+
import json
|
394 |
+
|
395 |
+
label = data["label"]
|
396 |
+
annotations = json.loads(label)
|
397 |
+
boxes, texts, text_inds, labels, edges = [], [], [], [], []
|
398 |
+
for ann in annotations:
|
399 |
+
box = ann["points"]
|
400 |
+
x_list = [box[i][0] for i in range(4)]
|
401 |
+
y_list = [box[i][1] for i in range(4)]
|
402 |
+
sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
|
403 |
+
sorted_box = []
|
404 |
+
for x, y in zip(sorted_x_list, sorted_y_list):
|
405 |
+
sorted_box.append(x)
|
406 |
+
sorted_box.append(y)
|
407 |
+
boxes.append(sorted_box)
|
408 |
+
text = ann["transcription"]
|
409 |
+
texts.append(ann["transcription"])
|
410 |
+
text_ind = [self.dict[c] for c in text if c in self.dict]
|
411 |
+
text_inds.append(text_ind)
|
412 |
+
if "label" in ann.keys():
|
413 |
+
labels.append(ann["label"])
|
414 |
+
elif "key_cls" in ann.keys():
|
415 |
+
labels.append(ann["key_cls"])
|
416 |
+
else:
|
417 |
+
raise ValueError(
|
418 |
+
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
|
419 |
+
)
|
420 |
+
edges.append(ann.get("edge", 0))
|
421 |
+
ann_infos = dict(
|
422 |
+
image=data["image"],
|
423 |
+
points=boxes,
|
424 |
+
texts=texts,
|
425 |
+
text_inds=text_inds,
|
426 |
+
edges=edges,
|
427 |
+
labels=labels,
|
428 |
+
)
|
429 |
+
|
430 |
+
return self.list_to_numpy(ann_infos)
|
431 |
+
|
432 |
+
|
433 |
+
class AttnLabelEncode(BaseRecLabelEncode):
|
434 |
+
"""Convert between text-label and text-index"""
|
435 |
+
|
436 |
+
def __init__(
|
437 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
438 |
+
):
|
439 |
+
super(AttnLabelEncode, self).__init__(
|
440 |
+
max_text_length, character_dict_path, use_space_char
|
441 |
+
)
|
442 |
+
|
443 |
+
def add_special_char(self, dict_character):
|
444 |
+
self.beg_str = "sos"
|
445 |
+
self.end_str = "eos"
|
446 |
+
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
447 |
+
return dict_character
|
448 |
+
|
449 |
+
def __call__(self, data):
|
450 |
+
text = data["label"]
|
451 |
+
text = self.encode(text)
|
452 |
+
if text is None:
|
453 |
+
return None
|
454 |
+
if len(text) >= self.max_text_len:
|
455 |
+
return None
|
456 |
+
data["length"] = np.array(len(text))
|
457 |
+
text = (
|
458 |
+
[0]
|
459 |
+
+ text
|
460 |
+
+ [len(self.character) - 1]
|
461 |
+
+ [0] * (self.max_text_len - len(text) - 2)
|
462 |
+
)
|
463 |
+
data["label"] = np.array(text)
|
464 |
+
return data
|
465 |
+
|
466 |
+
def get_ignored_tokens(self):
|
467 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
468 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
469 |
+
return [beg_idx, end_idx]
|
470 |
+
|
471 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
472 |
+
if beg_or_end == "beg":
|
473 |
+
idx = np.array(self.dict[self.beg_str])
|
474 |
+
elif beg_or_end == "end":
|
475 |
+
idx = np.array(self.dict[self.end_str])
|
476 |
+
else:
|
477 |
+
assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
478 |
+
return idx
|
479 |
+
|
480 |
+
|
481 |
+
class SEEDLabelEncode(BaseRecLabelEncode):
|
482 |
+
"""Convert between text-label and text-index"""
|
483 |
+
|
484 |
+
def __init__(
|
485 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
486 |
+
):
|
487 |
+
super(SEEDLabelEncode, self).__init__(
|
488 |
+
max_text_length, character_dict_path, use_space_char
|
489 |
+
)
|
490 |
+
|
491 |
+
def add_special_char(self, dict_character):
|
492 |
+
self.padding = "padding"
|
493 |
+
self.end_str = "eos"
|
494 |
+
self.unknown = "unknown"
|
495 |
+
dict_character = dict_character + [self.end_str, self.padding, self.unknown]
|
496 |
+
return dict_character
|
497 |
+
|
498 |
+
def __call__(self, data):
|
499 |
+
text = data["label"]
|
500 |
+
text = self.encode(text)
|
501 |
+
if text is None:
|
502 |
+
return None
|
503 |
+
if len(text) >= self.max_text_len:
|
504 |
+
return None
|
505 |
+
data["length"] = np.array(len(text)) + 1 # conclude eos
|
506 |
+
text = (
|
507 |
+
text
|
508 |
+
+ [len(self.character) - 3]
|
509 |
+
+ [len(self.character) - 2] * (self.max_text_len - len(text) - 1)
|
510 |
+
)
|
511 |
+
data["label"] = np.array(text)
|
512 |
+
return data
|
513 |
+
|
514 |
+
|
515 |
+
class SRNLabelEncode(BaseRecLabelEncode):
|
516 |
+
"""Convert between text-label and text-index"""
|
517 |
+
|
518 |
+
def __init__(
|
519 |
+
self,
|
520 |
+
max_text_length=25,
|
521 |
+
character_dict_path=None,
|
522 |
+
use_space_char=False,
|
523 |
+
**kwargs
|
524 |
+
):
|
525 |
+
super(SRNLabelEncode, self).__init__(
|
526 |
+
max_text_length, character_dict_path, use_space_char
|
527 |
+
)
|
528 |
+
|
529 |
+
def add_special_char(self, dict_character):
|
530 |
+
dict_character = dict_character + [self.beg_str, self.end_str]
|
531 |
+
return dict_character
|
532 |
+
|
533 |
+
def __call__(self, data):
|
534 |
+
text = data["label"]
|
535 |
+
text = self.encode(text)
|
536 |
+
char_num = len(self.character)
|
537 |
+
if text is None:
|
538 |
+
return None
|
539 |
+
if len(text) > self.max_text_len:
|
540 |
+
return None
|
541 |
+
data["length"] = np.array(len(text))
|
542 |
+
text = text + [char_num - 1] * (self.max_text_len - len(text))
|
543 |
+
data["label"] = np.array(text)
|
544 |
+
return data
|
545 |
+
|
546 |
+
def get_ignored_tokens(self):
|
547 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
548 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
549 |
+
return [beg_idx, end_idx]
|
550 |
+
|
551 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
552 |
+
if beg_or_end == "beg":
|
553 |
+
idx = np.array(self.dict[self.beg_str])
|
554 |
+
elif beg_or_end == "end":
|
555 |
+
idx = np.array(self.dict[self.end_str])
|
556 |
+
else:
|
557 |
+
assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
558 |
+
return idx
|
559 |
+
|
560 |
+
|
561 |
+
class TableLabelEncode(object):
|
562 |
+
"""Convert between text-label and text-index"""
|
563 |
+
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
max_text_length,
|
567 |
+
max_elem_length,
|
568 |
+
max_cell_num,
|
569 |
+
character_dict_path,
|
570 |
+
span_weight=1.0,
|
571 |
+
**kwargs
|
572 |
+
):
|
573 |
+
self.max_text_length = max_text_length
|
574 |
+
self.max_elem_length = max_elem_length
|
575 |
+
self.max_cell_num = max_cell_num
|
576 |
+
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
577 |
+
list_character = self.add_special_char(list_character)
|
578 |
+
list_elem = self.add_special_char(list_elem)
|
579 |
+
self.dict_character = {}
|
580 |
+
for i, char in enumerate(list_character):
|
581 |
+
self.dict_character[char] = i
|
582 |
+
self.dict_elem = {}
|
583 |
+
for i, elem in enumerate(list_elem):
|
584 |
+
self.dict_elem[elem] = i
|
585 |
+
self.span_weight = span_weight
|
586 |
+
|
587 |
+
def load_char_elem_dict(self, character_dict_path):
|
588 |
+
list_character = []
|
589 |
+
list_elem = []
|
590 |
+
with open(character_dict_path, "rb") as fin:
|
591 |
+
lines = fin.readlines()
|
592 |
+
substr = lines[0].decode("utf-8").strip("\r\n").split("\t")
|
593 |
+
character_num = int(substr[0])
|
594 |
+
elem_num = int(substr[1])
|
595 |
+
for cno in range(1, 1 + character_num):
|
596 |
+
character = lines[cno].decode("utf-8").strip("\r\n")
|
597 |
+
list_character.append(character)
|
598 |
+
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
599 |
+
elem = lines[eno].decode("utf-8").strip("\r\n")
|
600 |
+
list_elem.append(elem)
|
601 |
+
return list_character, list_elem
|
602 |
+
|
603 |
+
def add_special_char(self, list_character):
|
604 |
+
self.beg_str = "sos"
|
605 |
+
self.end_str = "eos"
|
606 |
+
list_character = [self.beg_str] + list_character + [self.end_str]
|
607 |
+
return list_character
|
608 |
+
|
609 |
+
def get_span_idx_list(self):
|
610 |
+
span_idx_list = []
|
611 |
+
for elem in self.dict_elem:
|
612 |
+
if "span" in elem:
|
613 |
+
span_idx_list.append(self.dict_elem[elem])
|
614 |
+
return span_idx_list
|
615 |
+
|
616 |
+
def __call__(self, data):
|
617 |
+
cells = data["cells"]
|
618 |
+
structure = data["structure"]["tokens"]
|
619 |
+
structure = self.encode(structure, "elem")
|
620 |
+
if structure is None:
|
621 |
+
return None
|
622 |
+
elem_num = len(structure)
|
623 |
+
structure = [0] + structure + [len(self.dict_elem) - 1]
|
624 |
+
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
625 |
+
structure = np.array(structure)
|
626 |
+
data["structure"] = structure
|
627 |
+
elem_char_idx1 = self.dict_elem["<td>"]
|
628 |
+
elem_char_idx2 = self.dict_elem["<td"]
|
629 |
+
span_idx_list = self.get_span_idx_list()
|
630 |
+
td_idx_list = np.logical_or(
|
631 |
+
structure == elem_char_idx1, structure == elem_char_idx2
|
632 |
+
)
|
633 |
+
td_idx_list = np.where(td_idx_list)[0]
|
634 |
+
|
635 |
+
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
|
636 |
+
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
637 |
+
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
|
638 |
+
img_height, img_width, img_ch = data["image"].shape
|
639 |
+
if len(span_idx_list) > 0:
|
640 |
+
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
641 |
+
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
642 |
+
for cno in range(len(cells)):
|
643 |
+
if "bbox" in cells[cno]:
|
644 |
+
bbox = cells[cno]["bbox"].copy()
|
645 |
+
bbox[0] = bbox[0] * 1.0 / img_width
|
646 |
+
bbox[1] = bbox[1] * 1.0 / img_height
|
647 |
+
bbox[2] = bbox[2] * 1.0 / img_width
|
648 |
+
bbox[3] = bbox[3] * 1.0 / img_height
|
649 |
+
td_idx = td_idx_list[cno]
|
650 |
+
bbox_list[td_idx] = bbox
|
651 |
+
bbox_list_mask[td_idx] = 1.0
|
652 |
+
cand_span_idx = td_idx + 1
|
653 |
+
if cand_span_idx < (self.max_elem_length + 2):
|
654 |
+
if structure[cand_span_idx] in span_idx_list:
|
655 |
+
structure_mask[cand_span_idx] = span_weight
|
656 |
+
|
657 |
+
data["bbox_list"] = bbox_list
|
658 |
+
data["bbox_list_mask"] = bbox_list_mask
|
659 |
+
data["structure_mask"] = structure_mask
|
660 |
+
char_beg_idx = self.get_beg_end_flag_idx("beg", "char")
|
661 |
+
char_end_idx = self.get_beg_end_flag_idx("end", "char")
|
662 |
+
elem_beg_idx = self.get_beg_end_flag_idx("beg", "elem")
|
663 |
+
elem_end_idx = self.get_beg_end_flag_idx("end", "elem")
|
664 |
+
data["sp_tokens"] = np.array(
|
665 |
+
[
|
666 |
+
char_beg_idx,
|
667 |
+
char_end_idx,
|
668 |
+
elem_beg_idx,
|
669 |
+
elem_end_idx,
|
670 |
+
elem_char_idx1,
|
671 |
+
elem_char_idx2,
|
672 |
+
self.max_text_length,
|
673 |
+
self.max_elem_length,
|
674 |
+
self.max_cell_num,
|
675 |
+
elem_num,
|
676 |
+
]
|
677 |
+
)
|
678 |
+
return data
|
679 |
+
|
680 |
+
def encode(self, text, char_or_elem):
|
681 |
+
"""convert text-label into text-index."""
|
682 |
+
if char_or_elem == "char":
|
683 |
+
max_len = self.max_text_length
|
684 |
+
current_dict = self.dict_character
|
685 |
+
else:
|
686 |
+
max_len = self.max_elem_length
|
687 |
+
current_dict = self.dict_elem
|
688 |
+
if len(text) > max_len:
|
689 |
+
return None
|
690 |
+
if len(text) == 0:
|
691 |
+
if char_or_elem == "char":
|
692 |
+
return [self.dict_character["space"]]
|
693 |
+
else:
|
694 |
+
return None
|
695 |
+
text_list = []
|
696 |
+
for char in text:
|
697 |
+
if char not in current_dict:
|
698 |
+
return None
|
699 |
+
text_list.append(current_dict[char])
|
700 |
+
if len(text_list) == 0:
|
701 |
+
if char_or_elem == "char":
|
702 |
+
return [self.dict_character["space"]]
|
703 |
+
else:
|
704 |
+
return None
|
705 |
+
return text_list
|
706 |
+
|
707 |
+
def get_ignored_tokens(self, char_or_elem):
|
708 |
+
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
709 |
+
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
710 |
+
return [beg_idx, end_idx]
|
711 |
+
|
712 |
+
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
713 |
+
if char_or_elem == "char":
|
714 |
+
if beg_or_end == "beg":
|
715 |
+
idx = np.array(self.dict_character[self.beg_str])
|
716 |
+
elif beg_or_end == "end":
|
717 |
+
idx = np.array(self.dict_character[self.end_str])
|
718 |
+
else:
|
719 |
+
assert False, (
|
720 |
+
"Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
|
721 |
+
)
|
722 |
+
elif char_or_elem == "elem":
|
723 |
+
if beg_or_end == "beg":
|
724 |
+
idx = np.array(self.dict_elem[self.beg_str])
|
725 |
+
elif beg_or_end == "end":
|
726 |
+
idx = np.array(self.dict_elem[self.end_str])
|
727 |
+
else:
|
728 |
+
assert False, (
|
729 |
+
"Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
|
730 |
+
)
|
731 |
+
else:
|
732 |
+
assert False, "Unsupport type %s in char_or_elem" % char_or_elem
|
733 |
+
return idx
|
734 |
+
|
735 |
+
|
736 |
+
class SARLabelEncode(BaseRecLabelEncode):
|
737 |
+
"""Convert between text-label and text-index"""
|
738 |
+
|
739 |
+
def __init__(
|
740 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
741 |
+
):
|
742 |
+
super(SARLabelEncode, self).__init__(
|
743 |
+
max_text_length, character_dict_path, use_space_char
|
744 |
+
)
|
745 |
+
|
746 |
+
def add_special_char(self, dict_character):
|
747 |
+
beg_end_str = "<BOS/EOS>"
|
748 |
+
unknown_str = "<UKN>"
|
749 |
+
padding_str = "<PAD>"
|
750 |
+
dict_character = dict_character + [unknown_str]
|
751 |
+
self.unknown_idx = len(dict_character) - 1
|
752 |
+
dict_character = dict_character + [beg_end_str]
|
753 |
+
self.start_idx = len(dict_character) - 1
|
754 |
+
self.end_idx = len(dict_character) - 1
|
755 |
+
dict_character = dict_character + [padding_str]
|
756 |
+
self.padding_idx = len(dict_character) - 1
|
757 |
+
|
758 |
+
return dict_character
|
759 |
+
|
760 |
+
def __call__(self, data):
|
761 |
+
text = data["label"]
|
762 |
+
text = self.encode(text)
|
763 |
+
if text is None:
|
764 |
+
return None
|
765 |
+
if len(text) >= self.max_text_len - 1:
|
766 |
+
return None
|
767 |
+
data["length"] = np.array(len(text))
|
768 |
+
target = [self.start_idx] + text + [self.end_idx]
|
769 |
+
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
770 |
+
|
771 |
+
padded_text[: len(target)] = target
|
772 |
+
data["label"] = np.array(padded_text)
|
773 |
+
return data
|
774 |
+
|
775 |
+
def get_ignored_tokens(self):
|
776 |
+
return [self.padding_idx]
|
777 |
+
|
778 |
+
|
779 |
+
class PRENLabelEncode(BaseRecLabelEncode):
|
780 |
+
def __init__(
|
781 |
+
self, max_text_length, character_dict_path, use_space_char=False, **kwargs
|
782 |
+
):
|
783 |
+
super(PRENLabelEncode, self).__init__(
|
784 |
+
max_text_length, character_dict_path, use_space_char
|
785 |
+
)
|
786 |
+
|
787 |
+
def add_special_char(self, dict_character):
|
788 |
+
padding_str = "<PAD>" # 0
|
789 |
+
end_str = "<EOS>" # 1
|
790 |
+
unknown_str = "<UNK>" # 2
|
791 |
+
|
792 |
+
dict_character = [padding_str, end_str, unknown_str] + dict_character
|
793 |
+
self.padding_idx = 0
|
794 |
+
self.end_idx = 1
|
795 |
+
self.unknown_idx = 2
|
796 |
+
|
797 |
+
return dict_character
|
798 |
+
|
799 |
+
def encode(self, text):
|
800 |
+
if len(text) == 0 or len(text) >= self.max_text_len:
|
801 |
+
return None
|
802 |
+
if self.lower:
|
803 |
+
text = text.lower()
|
804 |
+
text_list = []
|
805 |
+
for char in text:
|
806 |
+
if char not in self.dict:
|
807 |
+
text_list.append(self.unknown_idx)
|
808 |
+
else:
|
809 |
+
text_list.append(self.dict[char])
|
810 |
+
text_list.append(self.end_idx)
|
811 |
+
if len(text_list) < self.max_text_len:
|
812 |
+
text_list += [self.padding_idx] * (self.max_text_len - len(text_list))
|
813 |
+
return text_list
|
814 |
+
|
815 |
+
def __call__(self, data):
|
816 |
+
text = data["label"]
|
817 |
+
encoded_text = self.encode(text)
|
818 |
+
if encoded_text is None:
|
819 |
+
return None
|
820 |
+
data["label"] = np.array(encoded_text)
|
821 |
+
return data
|
822 |
+
|
823 |
+
|
824 |
+
class VQATokenLabelEncode(object):
|
825 |
+
"""
|
826 |
+
Label encode for NLP VQA methods
|
827 |
+
"""
|
828 |
+
|
829 |
+
def __init__(
|
830 |
+
self,
|
831 |
+
class_path,
|
832 |
+
contains_re=False,
|
833 |
+
add_special_ids=False,
|
834 |
+
algorithm="LayoutXLM",
|
835 |
+
infer_mode=False,
|
836 |
+
ocr_engine=None,
|
837 |
+
**kwargs
|
838 |
+
):
|
839 |
+
super(VQATokenLabelEncode, self).__init__()
|
840 |
+
from paddlenlp.transformers import (
|
841 |
+
LayoutLMTokenizer,
|
842 |
+
LayoutLMv2Tokenizer,
|
843 |
+
LayoutXLMTokenizer,
|
844 |
+
)
|
845 |
+
|
846 |
+
from ppocr.utils.utility import load_vqa_bio_label_maps
|
847 |
+
|
848 |
+
tokenizer_dict = {
|
849 |
+
"LayoutXLM": {
|
850 |
+
"class": LayoutXLMTokenizer,
|
851 |
+
"pretrained_model": "layoutxlm-base-uncased",
|
852 |
+
},
|
853 |
+
"LayoutLM": {
|
854 |
+
"class": LayoutLMTokenizer,
|
855 |
+
"pretrained_model": "layoutlm-base-uncased",
|
856 |
+
},
|
857 |
+
"LayoutLMv2": {
|
858 |
+
"class": LayoutLMv2Tokenizer,
|
859 |
+
"pretrained_model": "layoutlmv2-base-uncased",
|
860 |
+
},
|
861 |
+
}
|
862 |
+
self.contains_re = contains_re
|
863 |
+
tokenizer_config = tokenizer_dict[algorithm]
|
864 |
+
self.tokenizer = tokenizer_config["class"].from_pretrained(
|
865 |
+
tokenizer_config["pretrained_model"]
|
866 |
+
)
|
867 |
+
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
|
868 |
+
self.add_special_ids = add_special_ids
|
869 |
+
self.infer_mode = infer_mode
|
870 |
+
self.ocr_engine = ocr_engine
|
871 |
+
|
872 |
+
def __call__(self, data):
|
873 |
+
# load bbox and label info
|
874 |
+
ocr_info = self._load_ocr_info(data)
|
875 |
+
|
876 |
+
height, width, _ = data["image"].shape
|
877 |
+
|
878 |
+
words_list = []
|
879 |
+
bbox_list = []
|
880 |
+
input_ids_list = []
|
881 |
+
token_type_ids_list = []
|
882 |
+
segment_offset_id = []
|
883 |
+
gt_label_list = []
|
884 |
+
|
885 |
+
entities = []
|
886 |
+
|
887 |
+
# for re
|
888 |
+
train_re = self.contains_re and not self.infer_mode
|
889 |
+
if train_re:
|
890 |
+
relations = []
|
891 |
+
id2label = {}
|
892 |
+
entity_id_to_index_map = {}
|
893 |
+
empty_entity = set()
|
894 |
+
|
895 |
+
data["ocr_info"] = copy.deepcopy(ocr_info)
|
896 |
+
|
897 |
+
for info in ocr_info:
|
898 |
+
if train_re:
|
899 |
+
# for re
|
900 |
+
if len(info["text"]) == 0:
|
901 |
+
empty_entity.add(info["id"])
|
902 |
+
continue
|
903 |
+
id2label[info["id"]] = info["label"]
|
904 |
+
relations.extend([tuple(sorted(l)) for l in info["linking"]])
|
905 |
+
# smooth_box
|
906 |
+
bbox = self._smooth_box(info["bbox"], height, width)
|
907 |
+
|
908 |
+
text = info["text"]
|
909 |
+
encode_res = self.tokenizer.encode(
|
910 |
+
text, pad_to_max_seq_len=False, return_attention_mask=True
|
911 |
+
)
|
912 |
+
|
913 |
+
if not self.add_special_ids:
|
914 |
+
# TODO: use tok.all_special_ids to remove
|
915 |
+
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
|
916 |
+
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
|
917 |
+
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
|
918 |
+
# parse label
|
919 |
+
if not self.infer_mode:
|
920 |
+
label = info["label"]
|
921 |
+
gt_label = self._parse_label(label, encode_res)
|
922 |
+
|
923 |
+
# construct entities for re
|
924 |
+
if train_re:
|
925 |
+
if gt_label[0] != self.label2id_map["O"]:
|
926 |
+
entity_id_to_index_map[info["id"]] = len(entities)
|
927 |
+
label = label.upper()
|
928 |
+
entities.append(
|
929 |
+
{
|
930 |
+
"start": len(input_ids_list),
|
931 |
+
"end": len(input_ids_list) + len(encode_res["input_ids"]),
|
932 |
+
"label": label.upper(),
|
933 |
+
}
|
934 |
+
)
|
935 |
+
else:
|
936 |
+
entities.append(
|
937 |
+
{
|
938 |
+
"start": len(input_ids_list),
|
939 |
+
"end": len(input_ids_list) + len(encode_res["input_ids"]),
|
940 |
+
"label": "O",
|
941 |
+
}
|
942 |
+
)
|
943 |
+
input_ids_list.extend(encode_res["input_ids"])
|
944 |
+
token_type_ids_list.extend(encode_res["token_type_ids"])
|
945 |
+
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
|
946 |
+
words_list.append(text)
|
947 |
+
segment_offset_id.append(len(input_ids_list))
|
948 |
+
if not self.infer_mode:
|
949 |
+
gt_label_list.extend(gt_label)
|
950 |
+
|
951 |
+
data["input_ids"] = input_ids_list
|
952 |
+
data["token_type_ids"] = token_type_ids_list
|
953 |
+
data["bbox"] = bbox_list
|
954 |
+
data["attention_mask"] = [1] * len(input_ids_list)
|
955 |
+
data["labels"] = gt_label_list
|
956 |
+
data["segment_offset_id"] = segment_offset_id
|
957 |
+
data["tokenizer_params"] = dict(
|
958 |
+
padding_side=self.tokenizer.padding_side,
|
959 |
+
pad_token_type_id=self.tokenizer.pad_token_type_id,
|
960 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
961 |
+
)
|
962 |
+
data["entities"] = entities
|
963 |
+
|
964 |
+
if train_re:
|
965 |
+
data["relations"] = relations
|
966 |
+
data["id2label"] = id2label
|
967 |
+
data["empty_entity"] = empty_entity
|
968 |
+
data["entity_id_to_index_map"] = entity_id_to_index_map
|
969 |
+
return data
|
970 |
+
|
971 |
+
def _load_ocr_info(self, data):
|
972 |
+
def trans_poly_to_bbox(poly):
|
973 |
+
x1 = np.min([p[0] for p in poly])
|
974 |
+
x2 = np.max([p[0] for p in poly])
|
975 |
+
y1 = np.min([p[1] for p in poly])
|
976 |
+
y2 = np.max([p[1] for p in poly])
|
977 |
+
return [x1, y1, x2, y2]
|
978 |
+
|
979 |
+
if self.infer_mode:
|
980 |
+
ocr_result = self.ocr_engine.ocr(data["image"], cls=False)
|
981 |
+
ocr_info = []
|
982 |
+
for res in ocr_result:
|
983 |
+
ocr_info.append(
|
984 |
+
{
|
985 |
+
"text": res[1][0],
|
986 |
+
"bbox": trans_poly_to_bbox(res[0]),
|
987 |
+
"poly": res[0],
|
988 |
+
}
|
989 |
+
)
|
990 |
+
return ocr_info
|
991 |
+
else:
|
992 |
+
info = data["label"]
|
993 |
+
# read text info
|
994 |
+
info_dict = json.loads(info)
|
995 |
+
return info_dict["ocr_info"]
|
996 |
+
|
997 |
+
def _smooth_box(self, bbox, height, width):
|
998 |
+
bbox[0] = int(bbox[0] * 1000.0 / width)
|
999 |
+
bbox[2] = int(bbox[2] * 1000.0 / width)
|
1000 |
+
bbox[1] = int(bbox[1] * 1000.0 / height)
|
1001 |
+
bbox[3] = int(bbox[3] * 1000.0 / height)
|
1002 |
+
return bbox
|
1003 |
+
|
1004 |
+
def _parse_label(self, label, encode_res):
|
1005 |
+
gt_label = []
|
1006 |
+
if label.lower() == "other":
|
1007 |
+
gt_label.extend([0] * len(encode_res["input_ids"]))
|
1008 |
+
else:
|
1009 |
+
gt_label.append(self.label2id_map[("b-" + label).upper()])
|
1010 |
+
gt_label.extend(
|
1011 |
+
[self.label2id_map[("i-" + label).upper()]]
|
1012 |
+
* (len(encode_res["input_ids"]) - 1)
|
1013 |
+
)
|
1014 |
+
return gt_label
|
1015 |
+
|
1016 |
+
|
1017 |
+
class MultiLabelEncode(BaseRecLabelEncode):
|
1018 |
+
def __init__(
|
1019 |
+
self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
|
1020 |
+
):
|
1021 |
+
super(MultiLabelEncode, self).__init__(
|
1022 |
+
max_text_length, character_dict_path, use_space_char
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
self.ctc_encode = CTCLabelEncode(
|
1026 |
+
max_text_length, character_dict_path, use_space_char, **kwargs
|
1027 |
+
)
|
1028 |
+
self.sar_encode = SARLabelEncode(
|
1029 |
+
max_text_length, character_dict_path, use_space_char, **kwargs
|
1030 |
+
)
|
1031 |
+
|
1032 |
+
def __call__(self, data):
|
1033 |
+
|
1034 |
+
data_ctc = copy.deepcopy(data)
|
1035 |
+
data_sar = copy.deepcopy(data)
|
1036 |
+
data_out = dict()
|
1037 |
+
data_out["img_path"] = data.get("img_path", None)
|
1038 |
+
data_out["image"] = data["image"]
|
1039 |
+
ctc = self.ctc_encode.__call__(data_ctc)
|
1040 |
+
sar = self.sar_encode.__call__(data_sar)
|
1041 |
+
if ctc is None or sar is None:
|
1042 |
+
return None
|
1043 |
+
data_out["label_ctc"] = ctc["label"]
|
1044 |
+
data_out["label_sar"] = sar["label"]
|
1045 |
+
data_out["length"] = ctc["length"]
|
1046 |
+
return data_out
|
ocr/ppocr/data/imaug/make_border_map.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
np.seterr(divide="ignore", invalid="ignore")
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import pyclipper
|
10 |
+
from shapely.geometry import Polygon
|
11 |
+
|
12 |
+
warnings.simplefilter("ignore")
|
13 |
+
|
14 |
+
__all__ = ["MakeBorderMap"]
|
15 |
+
|
16 |
+
|
17 |
+
class MakeBorderMap(object):
|
18 |
+
def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs):
|
19 |
+
self.shrink_ratio = shrink_ratio
|
20 |
+
self.thresh_min = thresh_min
|
21 |
+
self.thresh_max = thresh_max
|
22 |
+
|
23 |
+
def __call__(self, data):
|
24 |
+
|
25 |
+
img = data["image"]
|
26 |
+
text_polys = data["polys"]
|
27 |
+
ignore_tags = data["ignore_tags"]
|
28 |
+
|
29 |
+
canvas = np.zeros(img.shape[:2], dtype=np.float32)
|
30 |
+
mask = np.zeros(img.shape[:2], dtype=np.float32)
|
31 |
+
|
32 |
+
for i in range(len(text_polys)):
|
33 |
+
if ignore_tags[i]:
|
34 |
+
continue
|
35 |
+
self.draw_border_map(text_polys[i], canvas, mask=mask)
|
36 |
+
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
|
37 |
+
|
38 |
+
data["threshold_map"] = canvas
|
39 |
+
data["threshold_mask"] = mask
|
40 |
+
return data
|
41 |
+
|
42 |
+
def draw_border_map(self, polygon, canvas, mask):
|
43 |
+
polygon = np.array(polygon)
|
44 |
+
assert polygon.ndim == 2
|
45 |
+
assert polygon.shape[1] == 2
|
46 |
+
|
47 |
+
polygon_shape = Polygon(polygon)
|
48 |
+
if polygon_shape.area <= 0:
|
49 |
+
return
|
50 |
+
distance = (
|
51 |
+
polygon_shape.area
|
52 |
+
* (1 - np.power(self.shrink_ratio, 2))
|
53 |
+
/ polygon_shape.length
|
54 |
+
)
|
55 |
+
subject = [tuple(l) for l in polygon]
|
56 |
+
padding = pyclipper.PyclipperOffset()
|
57 |
+
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
58 |
+
|
59 |
+
padded_polygon = np.array(padding.Execute(distance)[0])
|
60 |
+
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
|
61 |
+
|
62 |
+
xmin = padded_polygon[:, 0].min()
|
63 |
+
xmax = padded_polygon[:, 0].max()
|
64 |
+
ymin = padded_polygon[:, 1].min()
|
65 |
+
ymax = padded_polygon[:, 1].max()
|
66 |
+
width = xmax - xmin + 1
|
67 |
+
height = ymax - ymin + 1
|
68 |
+
|
69 |
+
polygon[:, 0] = polygon[:, 0] - xmin
|
70 |
+
polygon[:, 1] = polygon[:, 1] - ymin
|
71 |
+
|
72 |
+
xs = np.broadcast_to(
|
73 |
+
np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
|
74 |
+
)
|
75 |
+
ys = np.broadcast_to(
|
76 |
+
np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
|
77 |
+
)
|
78 |
+
|
79 |
+
distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
|
80 |
+
for i in range(polygon.shape[0]):
|
81 |
+
j = (i + 1) % polygon.shape[0]
|
82 |
+
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
|
83 |
+
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
|
84 |
+
distance_map = distance_map.min(axis=0)
|
85 |
+
|
86 |
+
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
|
87 |
+
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
|
88 |
+
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
|
89 |
+
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
|
90 |
+
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
|
91 |
+
1
|
92 |
+
- distance_map[
|
93 |
+
ymin_valid - ymin : ymax_valid - ymax + height,
|
94 |
+
xmin_valid - xmin : xmax_valid - xmax + width,
|
95 |
+
],
|
96 |
+
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
|
97 |
+
)
|
98 |
+
|
99 |
+
def _distance(self, xs, ys, point_1, point_2):
|
100 |
+
"""
|
101 |
+
compute the distance from point to a line
|
102 |
+
ys: coordinates in the first axis
|
103 |
+
xs: coordinates in the second axis
|
104 |
+
point_1, point_2: (x, y), the end of the line
|
105 |
+
"""
|
106 |
+
height, width = xs.shape[:2]
|
107 |
+
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
|
108 |
+
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
|
109 |
+
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
|
110 |
+
point_1[1] - point_2[1]
|
111 |
+
)
|
112 |
+
|
113 |
+
cosin = (square_distance - square_distance_1 - square_distance_2) / (
|
114 |
+
2 * np.sqrt(square_distance_1 * square_distance_2)
|
115 |
+
)
|
116 |
+
square_sin = 1 - np.square(cosin)
|
117 |
+
square_sin = np.nan_to_num(square_sin)
|
118 |
+
result = np.sqrt(
|
119 |
+
square_distance_1 * square_distance_2 * square_sin / square_distance
|
120 |
+
)
|
121 |
+
|
122 |
+
result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
|
123 |
+
cosin < 0
|
124 |
+
]
|
125 |
+
# self.extend_line(point_1, point_2, result)
|
126 |
+
return result
|
127 |
+
|
128 |
+
def extend_line(self, point_1, point_2, result, shrink_ratio):
|
129 |
+
ex_point_1 = (
|
130 |
+
int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
|
131 |
+
int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))),
|
132 |
+
)
|
133 |
+
cv2.line(
|
134 |
+
result,
|
135 |
+
tuple(ex_point_1),
|
136 |
+
tuple(point_1),
|
137 |
+
4096.0,
|
138 |
+
1,
|
139 |
+
lineType=cv2.LINE_AA,
|
140 |
+
shift=0,
|
141 |
+
)
|
142 |
+
ex_point_2 = (
|
143 |
+
int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
|
144 |
+
int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))),
|
145 |
+
)
|
146 |
+
cv2.line(
|
147 |
+
result,
|
148 |
+
tuple(ex_point_2),
|
149 |
+
tuple(point_2),
|
150 |
+
4096.0,
|
151 |
+
1,
|
152 |
+
lineType=cv2.LINE_AA,
|
153 |
+
shift=0,
|
154 |
+
)
|
155 |
+
return ex_point_1, ex_point_2
|
ocr/ppocr/data/imaug/make_pse_gt.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import pyclipper
|
6 |
+
from shapely.geometry import Polygon
|
7 |
+
|
8 |
+
__all__ = ["MakePseGt"]
|
9 |
+
|
10 |
+
|
11 |
+
class MakePseGt(object):
|
12 |
+
def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
|
13 |
+
self.kernel_num = kernel_num
|
14 |
+
self.min_shrink_ratio = min_shrink_ratio
|
15 |
+
self.size = size
|
16 |
+
|
17 |
+
def __call__(self, data):
|
18 |
+
|
19 |
+
image = data["image"]
|
20 |
+
text_polys = data["polys"]
|
21 |
+
ignore_tags = data["ignore_tags"]
|
22 |
+
|
23 |
+
h, w, _ = image.shape
|
24 |
+
short_edge = min(h, w)
|
25 |
+
if short_edge < self.size:
|
26 |
+
# keep short_size >= self.size
|
27 |
+
scale = self.size / short_edge
|
28 |
+
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
29 |
+
text_polys *= scale
|
30 |
+
|
31 |
+
gt_kernels = []
|
32 |
+
for i in range(1, self.kernel_num + 1):
|
33 |
+
# s1->sn, from big to small
|
34 |
+
rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
|
35 |
+
text_kernel, ignore_tags = self.generate_kernel(
|
36 |
+
image.shape[0:2], rate, text_polys, ignore_tags
|
37 |
+
)
|
38 |
+
gt_kernels.append(text_kernel)
|
39 |
+
|
40 |
+
training_mask = np.ones(image.shape[0:2], dtype="uint8")
|
41 |
+
for i in range(text_polys.shape[0]):
|
42 |
+
if ignore_tags[i]:
|
43 |
+
cv2.fillPoly(
|
44 |
+
training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0
|
45 |
+
)
|
46 |
+
|
47 |
+
gt_kernels = np.array(gt_kernels)
|
48 |
+
gt_kernels[gt_kernels > 0] = 1
|
49 |
+
|
50 |
+
data["image"] = image
|
51 |
+
data["polys"] = text_polys
|
52 |
+
data["gt_kernels"] = gt_kernels[0:]
|
53 |
+
data["gt_text"] = gt_kernels[0]
|
54 |
+
data["mask"] = training_mask.astype("float32")
|
55 |
+
return data
|
56 |
+
|
57 |
+
def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
|
58 |
+
"""
|
59 |
+
Refer to part of the code:
|
60 |
+
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
|
61 |
+
"""
|
62 |
+
|
63 |
+
h, w = img_size
|
64 |
+
text_kernel = np.zeros((h, w), dtype=np.float32)
|
65 |
+
for i, poly in enumerate(text_polys):
|
66 |
+
polygon = Polygon(poly)
|
67 |
+
distance = (
|
68 |
+
polygon.area
|
69 |
+
* (1 - shrink_ratio * shrink_ratio)
|
70 |
+
/ (polygon.length + 1e-6)
|
71 |
+
)
|
72 |
+
subject = [tuple(l) for l in poly]
|
73 |
+
pco = pyclipper.PyclipperOffset()
|
74 |
+
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
75 |
+
shrinked = np.array(pco.Execute(-distance))
|
76 |
+
|
77 |
+
if len(shrinked) == 0 or shrinked.size == 0:
|
78 |
+
if ignore_tags is not None:
|
79 |
+
ignore_tags[i] = True
|
80 |
+
continue
|
81 |
+
try:
|
82 |
+
shrinked = np.array(shrinked[0]).reshape(-1, 2)
|
83 |
+
except:
|
84 |
+
if ignore_tags is not None:
|
85 |
+
ignore_tags[i] = True
|
86 |
+
continue
|
87 |
+
cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
|
88 |
+
return text_kernel, ignore_tags
|
ocr/ppocr/data/imaug/make_shrink_map.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import pyclipper
|
6 |
+
from shapely.geometry import Polygon
|
7 |
+
|
8 |
+
__all__ = ["MakeShrinkMap"]
|
9 |
+
|
10 |
+
|
11 |
+
class MakeShrinkMap(object):
|
12 |
+
r"""
|
13 |
+
Making binary mask from detection data with ICDAR format.
|
14 |
+
Typically following the process of class `MakeICDARData`.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
|
18 |
+
self.min_text_size = min_text_size
|
19 |
+
self.shrink_ratio = shrink_ratio
|
20 |
+
|
21 |
+
def __call__(self, data):
|
22 |
+
image = data["image"]
|
23 |
+
text_polys = data["polys"]
|
24 |
+
ignore_tags = data["ignore_tags"]
|
25 |
+
|
26 |
+
h, w = image.shape[:2]
|
27 |
+
text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
|
28 |
+
gt = np.zeros((h, w), dtype=np.float32)
|
29 |
+
mask = np.ones((h, w), dtype=np.float32)
|
30 |
+
for i in range(len(text_polys)):
|
31 |
+
polygon = text_polys[i]
|
32 |
+
height = max(polygon[:, 1]) - min(polygon[:, 1])
|
33 |
+
width = max(polygon[:, 0]) - min(polygon[:, 0])
|
34 |
+
if ignore_tags[i] or min(height, width) < self.min_text_size:
|
35 |
+
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
36 |
+
ignore_tags[i] = True
|
37 |
+
else:
|
38 |
+
polygon_shape = Polygon(polygon)
|
39 |
+
subject = [tuple(l) for l in polygon]
|
40 |
+
padding = pyclipper.PyclipperOffset()
|
41 |
+
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
42 |
+
shrinked = []
|
43 |
+
|
44 |
+
# Increase the shrink ratio every time we get multiple polygon returned back
|
45 |
+
possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
|
46 |
+
np.append(possible_ratios, 1)
|
47 |
+
# print(possible_ratios)
|
48 |
+
for ratio in possible_ratios:
|
49 |
+
# print(f"Change shrink ratio to {ratio}")
|
50 |
+
distance = (
|
51 |
+
polygon_shape.area
|
52 |
+
* (1 - np.power(ratio, 2))
|
53 |
+
/ polygon_shape.length
|
54 |
+
)
|
55 |
+
shrinked = padding.Execute(-distance)
|
56 |
+
if len(shrinked) == 1:
|
57 |
+
break
|
58 |
+
|
59 |
+
if shrinked == []:
|
60 |
+
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
61 |
+
ignore_tags[i] = True
|
62 |
+
continue
|
63 |
+
|
64 |
+
for each_shirnk in shrinked:
|
65 |
+
shirnk = np.array(each_shirnk).reshape(-1, 2)
|
66 |
+
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
|
67 |
+
|
68 |
+
data["shrink_map"] = gt
|
69 |
+
data["shrink_mask"] = mask
|
70 |
+
return data
|
71 |
+
|
72 |
+
def validate_polygons(self, polygons, ignore_tags, h, w):
|
73 |
+
"""
|
74 |
+
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
|
75 |
+
"""
|
76 |
+
if len(polygons) == 0:
|
77 |
+
return polygons, ignore_tags
|
78 |
+
assert len(polygons) == len(ignore_tags)
|
79 |
+
for polygon in polygons:
|
80 |
+
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
|
81 |
+
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
|
82 |
+
|
83 |
+
for i in range(len(polygons)):
|
84 |
+
area = self.polygon_area(polygons[i])
|
85 |
+
if abs(area) < 1:
|
86 |
+
ignore_tags[i] = True
|
87 |
+
if area > 0:
|
88 |
+
polygons[i] = polygons[i][::-1, :]
|
89 |
+
return polygons, ignore_tags
|
90 |
+
|
91 |
+
def polygon_area(self, polygon):
|
92 |
+
"""
|
93 |
+
compute polygon area
|
94 |
+
"""
|
95 |
+
area = 0
|
96 |
+
q = polygon[-1]
|
97 |
+
for p in polygon:
|
98 |
+
area += p[0] * q[1] - p[1] * q[0]
|
99 |
+
q = p
|
100 |
+
return area / 2.0
|
ocr/ppocr/data/imaug/operators.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
|
3 |
+
import math
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import six
|
9 |
+
|
10 |
+
|
11 |
+
class DecodeImage(object):
|
12 |
+
"""decode image"""
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs
|
16 |
+
):
|
17 |
+
self.img_mode = img_mode
|
18 |
+
self.channel_first = channel_first
|
19 |
+
self.ignore_orientation = ignore_orientation
|
20 |
+
|
21 |
+
def __call__(self, data):
|
22 |
+
img = data["image"]
|
23 |
+
if six.PY2:
|
24 |
+
assert (
|
25 |
+
type(img) is str and len(img) > 0
|
26 |
+
), "invalid input 'img' in DecodeImage"
|
27 |
+
else:
|
28 |
+
assert (
|
29 |
+
type(img) is bytes and len(img) > 0
|
30 |
+
), "invalid input 'img' in DecodeImage"
|
31 |
+
img = np.frombuffer(img, dtype="uint8")
|
32 |
+
if self.ignore_orientation:
|
33 |
+
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
|
34 |
+
else:
|
35 |
+
img = cv2.imdecode(img, 1)
|
36 |
+
if img is None:
|
37 |
+
return None
|
38 |
+
if self.img_mode == "GRAY":
|
39 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
40 |
+
elif self.img_mode == "RGB":
|
41 |
+
assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
|
42 |
+
img = img[:, :, ::-1]
|
43 |
+
|
44 |
+
if self.channel_first:
|
45 |
+
img = img.transpose((2, 0, 1))
|
46 |
+
|
47 |
+
data["image"] = img
|
48 |
+
return data
|
49 |
+
|
50 |
+
|
51 |
+
class NRTRDecodeImage(object):
|
52 |
+
"""decode image"""
|
53 |
+
|
54 |
+
def __init__(self, img_mode="RGB", channel_first=False, **kwargs):
|
55 |
+
self.img_mode = img_mode
|
56 |
+
self.channel_first = channel_first
|
57 |
+
|
58 |
+
def __call__(self, data):
|
59 |
+
img = data["image"]
|
60 |
+
if six.PY2:
|
61 |
+
assert (
|
62 |
+
type(img) is str and len(img) > 0
|
63 |
+
), "invalid input 'img' in DecodeImage"
|
64 |
+
else:
|
65 |
+
assert (
|
66 |
+
type(img) is bytes and len(img) > 0
|
67 |
+
), "invalid input 'img' in DecodeImage"
|
68 |
+
img = np.frombuffer(img, dtype="uint8")
|
69 |
+
|
70 |
+
img = cv2.imdecode(img, 1)
|
71 |
+
|
72 |
+
if img is None:
|
73 |
+
return None
|
74 |
+
if self.img_mode == "GRAY":
|
75 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
76 |
+
elif self.img_mode == "RGB":
|
77 |
+
assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
|
78 |
+
img = img[:, :, ::-1]
|
79 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
80 |
+
if self.channel_first:
|
81 |
+
img = img.transpose((2, 0, 1))
|
82 |
+
data["image"] = img
|
83 |
+
return data
|
84 |
+
|
85 |
+
|
86 |
+
class NormalizeImage(object):
|
87 |
+
"""normalize image such as substract mean, divide std"""
|
88 |
+
|
89 |
+
def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
|
90 |
+
if isinstance(scale, str):
|
91 |
+
scale = eval(scale)
|
92 |
+
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
93 |
+
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
94 |
+
std = std if std is not None else [0.229, 0.224, 0.225]
|
95 |
+
|
96 |
+
shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
|
97 |
+
self.mean = np.array(mean).reshape(shape).astype("float32")
|
98 |
+
self.std = np.array(std).reshape(shape).astype("float32")
|
99 |
+
|
100 |
+
def __call__(self, data):
|
101 |
+
img = data["image"]
|
102 |
+
from PIL import Image
|
103 |
+
|
104 |
+
if isinstance(img, Image.Image):
|
105 |
+
img = np.array(img)
|
106 |
+
assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
|
107 |
+
data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
|
108 |
+
return data
|
109 |
+
|
110 |
+
|
111 |
+
class ToCHWImage(object):
|
112 |
+
"""convert hwc image to chw image"""
|
113 |
+
|
114 |
+
def __init__(self, **kwargs):
|
115 |
+
pass
|
116 |
+
|
117 |
+
def __call__(self, data):
|
118 |
+
img = data["image"]
|
119 |
+
from PIL import Image
|
120 |
+
|
121 |
+
if isinstance(img, Image.Image):
|
122 |
+
img = np.array(img)
|
123 |
+
data["image"] = img.transpose((2, 0, 1))
|
124 |
+
return data
|
125 |
+
|
126 |
+
|
127 |
+
class Fasttext(object):
|
128 |
+
def __init__(self, path="None", **kwargs):
|
129 |
+
import fasttext
|
130 |
+
|
131 |
+
self.fast_model = fasttext.load_model(path)
|
132 |
+
|
133 |
+
def __call__(self, data):
|
134 |
+
label = data["label"]
|
135 |
+
fast_label = self.fast_model[label]
|
136 |
+
data["fast_label"] = fast_label
|
137 |
+
return data
|
138 |
+
|
139 |
+
|
140 |
+
class KeepKeys(object):
|
141 |
+
def __init__(self, keep_keys, **kwargs):
|
142 |
+
self.keep_keys = keep_keys
|
143 |
+
|
144 |
+
def __call__(self, data):
|
145 |
+
data_list = []
|
146 |
+
for key in self.keep_keys:
|
147 |
+
data_list.append(data[key])
|
148 |
+
return data_list
|
149 |
+
|
150 |
+
|
151 |
+
class Pad(object):
|
152 |
+
def __init__(self, size=None, size_div=32, **kwargs):
|
153 |
+
if size is not None and not isinstance(size, (int, list, tuple)):
|
154 |
+
raise TypeError(
|
155 |
+
"Type of target_size is invalid. Now is {}".format(type(size))
|
156 |
+
)
|
157 |
+
if isinstance(size, int):
|
158 |
+
size = [size, size]
|
159 |
+
self.size = size
|
160 |
+
self.size_div = size_div
|
161 |
+
|
162 |
+
def __call__(self, data):
|
163 |
+
|
164 |
+
img = data["image"]
|
165 |
+
img_h, img_w = img.shape[0], img.shape[1]
|
166 |
+
if self.size:
|
167 |
+
resize_h2, resize_w2 = self.size
|
168 |
+
assert (
|
169 |
+
img_h < resize_h2 and img_w < resize_w2
|
170 |
+
), "(h, w) of target size should be greater than (img_h, img_w)"
|
171 |
+
else:
|
172 |
+
resize_h2 = max(
|
173 |
+
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
174 |
+
self.size_div,
|
175 |
+
)
|
176 |
+
resize_w2 = max(
|
177 |
+
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
178 |
+
self.size_div,
|
179 |
+
)
|
180 |
+
img = cv2.copyMakeBorder(
|
181 |
+
img,
|
182 |
+
0,
|
183 |
+
resize_h2 - img_h,
|
184 |
+
0,
|
185 |
+
resize_w2 - img_w,
|
186 |
+
cv2.BORDER_CONSTANT,
|
187 |
+
value=0,
|
188 |
+
)
|
189 |
+
data["image"] = img
|
190 |
+
return data
|
191 |
+
|
192 |
+
|
193 |
+
class Resize(object):
|
194 |
+
def __init__(self, size=(640, 640), **kwargs):
|
195 |
+
self.size = size
|
196 |
+
|
197 |
+
def resize_image(self, img):
|
198 |
+
resize_h, resize_w = self.size
|
199 |
+
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
200 |
+
ratio_h = float(resize_h) / ori_h
|
201 |
+
ratio_w = float(resize_w) / ori_w
|
202 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
203 |
+
return img, [ratio_h, ratio_w]
|
204 |
+
|
205 |
+
def __call__(self, data):
|
206 |
+
img = data["image"]
|
207 |
+
if "polys" in data:
|
208 |
+
text_polys = data["polys"]
|
209 |
+
|
210 |
+
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
211 |
+
if "polys" in data:
|
212 |
+
new_boxes = []
|
213 |
+
for box in text_polys:
|
214 |
+
new_box = []
|
215 |
+
for cord in box:
|
216 |
+
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
217 |
+
new_boxes.append(new_box)
|
218 |
+
data["polys"] = np.array(new_boxes, dtype=np.float32)
|
219 |
+
data["image"] = img_resize
|
220 |
+
return data
|
221 |
+
|
222 |
+
|
223 |
+
class DetResizeForTest(object):
|
224 |
+
def __init__(self, **kwargs):
|
225 |
+
super(DetResizeForTest, self).__init__()
|
226 |
+
self.resize_type = 0
|
227 |
+
if "image_shape" in kwargs:
|
228 |
+
self.image_shape = kwargs["image_shape"]
|
229 |
+
self.resize_type = 1
|
230 |
+
elif "limit_side_len" in kwargs:
|
231 |
+
self.limit_side_len = kwargs["limit_side_len"]
|
232 |
+
self.limit_type = kwargs.get("limit_type", "min")
|
233 |
+
elif "resize_long" in kwargs:
|
234 |
+
self.resize_type = 2
|
235 |
+
self.resize_long = kwargs.get("resize_long", 960)
|
236 |
+
else:
|
237 |
+
self.limit_side_len = 736
|
238 |
+
self.limit_type = "min"
|
239 |
+
|
240 |
+
def __call__(self, data):
|
241 |
+
img = data["image"]
|
242 |
+
src_h, src_w, _ = img.shape
|
243 |
+
|
244 |
+
if self.resize_type == 0:
|
245 |
+
# img, shape = self.resize_image_type0(img)
|
246 |
+
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
247 |
+
elif self.resize_type == 2:
|
248 |
+
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
249 |
+
else:
|
250 |
+
# img, shape = self.resize_image_type1(img)
|
251 |
+
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
252 |
+
data["image"] = img
|
253 |
+
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
|
254 |
+
return data
|
255 |
+
|
256 |
+
def resize_image_type1(self, img):
|
257 |
+
resize_h, resize_w = self.image_shape
|
258 |
+
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
259 |
+
ratio_h = float(resize_h) / ori_h
|
260 |
+
ratio_w = float(resize_w) / ori_w
|
261 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
262 |
+
# return img, np.array([ori_h, ori_w])
|
263 |
+
return img, [ratio_h, ratio_w]
|
264 |
+
|
265 |
+
def resize_image_type0(self, img):
|
266 |
+
"""
|
267 |
+
resize image to a size multiple of 32 which is required by the network
|
268 |
+
args:
|
269 |
+
img(array): array with shape [h, w, c]
|
270 |
+
return(tuple):
|
271 |
+
img, (ratio_h, ratio_w)
|
272 |
+
"""
|
273 |
+
limit_side_len = self.limit_side_len
|
274 |
+
h, w, c = img.shape
|
275 |
+
|
276 |
+
# limit the max side
|
277 |
+
if self.limit_type == "max":
|
278 |
+
if max(h, w) > limit_side_len:
|
279 |
+
if h > w:
|
280 |
+
ratio = float(limit_side_len) / h
|
281 |
+
else:
|
282 |
+
ratio = float(limit_side_len) / w
|
283 |
+
else:
|
284 |
+
ratio = 1.0
|
285 |
+
elif self.limit_type == "min":
|
286 |
+
if min(h, w) < limit_side_len:
|
287 |
+
if h < w:
|
288 |
+
ratio = float(limit_side_len) / h
|
289 |
+
else:
|
290 |
+
ratio = float(limit_side_len) / w
|
291 |
+
else:
|
292 |
+
ratio = 1.0
|
293 |
+
elif self.limit_type == "resize_long":
|
294 |
+
ratio = float(limit_side_len) / max(h, w)
|
295 |
+
else:
|
296 |
+
raise Exception("not support limit type, image ")
|
297 |
+
resize_h = int(h * ratio)
|
298 |
+
resize_w = int(w * ratio)
|
299 |
+
|
300 |
+
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
301 |
+
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
302 |
+
|
303 |
+
try:
|
304 |
+
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
305 |
+
return None, (None, None)
|
306 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
307 |
+
except:
|
308 |
+
print(img.shape, resize_w, resize_h)
|
309 |
+
sys.exit(0)
|
310 |
+
ratio_h = resize_h / float(h)
|
311 |
+
ratio_w = resize_w / float(w)
|
312 |
+
return img, [ratio_h, ratio_w]
|
313 |
+
|
314 |
+
def resize_image_type2(self, img):
|
315 |
+
h, w, _ = img.shape
|
316 |
+
|
317 |
+
resize_w = w
|
318 |
+
resize_h = h
|
319 |
+
|
320 |
+
if resize_h > resize_w:
|
321 |
+
ratio = float(self.resize_long) / resize_h
|
322 |
+
else:
|
323 |
+
ratio = float(self.resize_long) / resize_w
|
324 |
+
|
325 |
+
resize_h = int(resize_h * ratio)
|
326 |
+
resize_w = int(resize_w * ratio)
|
327 |
+
|
328 |
+
max_stride = 128
|
329 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
330 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
331 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
332 |
+
ratio_h = resize_h / float(h)
|
333 |
+
ratio_w = resize_w / float(w)
|
334 |
+
|
335 |
+
return img, [ratio_h, ratio_w]
|
336 |
+
|
337 |
+
|
338 |
+
class E2EResizeForTest(object):
|
339 |
+
def __init__(self, **kwargs):
|
340 |
+
super(E2EResizeForTest, self).__init__()
|
341 |
+
self.max_side_len = kwargs["max_side_len"]
|
342 |
+
self.valid_set = kwargs["valid_set"]
|
343 |
+
|
344 |
+
def __call__(self, data):
|
345 |
+
img = data["image"]
|
346 |
+
src_h, src_w, _ = img.shape
|
347 |
+
if self.valid_set == "totaltext":
|
348 |
+
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
349 |
+
img, max_side_len=self.max_side_len
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
353 |
+
img, max_side_len=self.max_side_len
|
354 |
+
)
|
355 |
+
data["image"] = im_resized
|
356 |
+
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
|
357 |
+
return data
|
358 |
+
|
359 |
+
def resize_image_for_totaltext(self, im, max_side_len=512):
|
360 |
+
|
361 |
+
h, w, _ = im.shape
|
362 |
+
resize_w = w
|
363 |
+
resize_h = h
|
364 |
+
ratio = 1.25
|
365 |
+
if h * ratio > max_side_len:
|
366 |
+
ratio = float(max_side_len) / resize_h
|
367 |
+
resize_h = int(resize_h * ratio)
|
368 |
+
resize_w = int(resize_w * ratio)
|
369 |
+
|
370 |
+
max_stride = 128
|
371 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
372 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
373 |
+
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
374 |
+
ratio_h = resize_h / float(h)
|
375 |
+
ratio_w = resize_w / float(w)
|
376 |
+
return im, (ratio_h, ratio_w)
|
377 |
+
|
378 |
+
def resize_image(self, im, max_side_len=512):
|
379 |
+
"""
|
380 |
+
resize image to a size multiple of max_stride which is required by the network
|
381 |
+
:param im: the resized image
|
382 |
+
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
383 |
+
:return: the resized image and the resize ratio
|
384 |
+
"""
|
385 |
+
h, w, _ = im.shape
|
386 |
+
|
387 |
+
resize_w = w
|
388 |
+
resize_h = h
|
389 |
+
|
390 |
+
# Fix the longer side
|
391 |
+
if resize_h > resize_w:
|
392 |
+
ratio = float(max_side_len) / resize_h
|
393 |
+
else:
|
394 |
+
ratio = float(max_side_len) / resize_w
|
395 |
+
|
396 |
+
resize_h = int(resize_h * ratio)
|
397 |
+
resize_w = int(resize_w * ratio)
|
398 |
+
|
399 |
+
max_stride = 128
|
400 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
401 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
402 |
+
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
403 |
+
ratio_h = resize_h / float(h)
|
404 |
+
ratio_w = resize_w / float(w)
|
405 |
+
|
406 |
+
return im, (ratio_h, ratio_w)
|
407 |
+
|
408 |
+
|
409 |
+
class KieResize(object):
|
410 |
+
def __init__(self, **kwargs):
|
411 |
+
super(KieResize, self).__init__()
|
412 |
+
self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1]
|
413 |
+
|
414 |
+
def __call__(self, data):
|
415 |
+
img = data["image"]
|
416 |
+
points = data["points"]
|
417 |
+
src_h, src_w, _ = img.shape
|
418 |
+
(
|
419 |
+
im_resized,
|
420 |
+
scale_factor,
|
421 |
+
[ratio_h, ratio_w],
|
422 |
+
[new_h, new_w],
|
423 |
+
) = self.resize_image(img)
|
424 |
+
resize_points = self.resize_boxes(img, points, scale_factor)
|
425 |
+
data["ori_image"] = img
|
426 |
+
data["ori_boxes"] = points
|
427 |
+
data["points"] = resize_points
|
428 |
+
data["image"] = im_resized
|
429 |
+
data["shape"] = np.array([new_h, new_w])
|
430 |
+
return data
|
431 |
+
|
432 |
+
def resize_image(self, img):
|
433 |
+
norm_img = np.zeros([1024, 1024, 3], dtype="float32")
|
434 |
+
scale = [512, 1024]
|
435 |
+
h, w = img.shape[:2]
|
436 |
+
max_long_edge = max(scale)
|
437 |
+
max_short_edge = min(scale)
|
438 |
+
scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
|
439 |
+
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(
|
440 |
+
h * float(scale_factor) + 0.5
|
441 |
+
)
|
442 |
+
max_stride = 32
|
443 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
444 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
445 |
+
im = cv2.resize(img, (resize_w, resize_h))
|
446 |
+
new_h, new_w = im.shape[:2]
|
447 |
+
w_scale = new_w / w
|
448 |
+
h_scale = new_h / h
|
449 |
+
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
450 |
+
norm_img[:new_h, :new_w, :] = im
|
451 |
+
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
452 |
+
|
453 |
+
def resize_boxes(self, im, points, scale_factor):
|
454 |
+
points = points * scale_factor
|
455 |
+
img_shape = im.shape[:2]
|
456 |
+
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
457 |
+
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
458 |
+
return points
|
ocr/ppocr/data/imaug/pg_process.py
ADDED
@@ -0,0 +1,961 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
__all__ = ["PGProcessTrain"]
|
7 |
+
|
8 |
+
|
9 |
+
class PGProcessTrain(object):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
character_dict_path,
|
13 |
+
max_text_length,
|
14 |
+
max_text_nums,
|
15 |
+
tcl_len,
|
16 |
+
batch_size=14,
|
17 |
+
min_crop_size=24,
|
18 |
+
min_text_size=4,
|
19 |
+
max_text_size=512,
|
20 |
+
**kwargs
|
21 |
+
):
|
22 |
+
self.tcl_len = tcl_len
|
23 |
+
self.max_text_length = max_text_length
|
24 |
+
self.max_text_nums = max_text_nums
|
25 |
+
self.batch_size = batch_size
|
26 |
+
self.min_crop_size = min_crop_size
|
27 |
+
self.min_text_size = min_text_size
|
28 |
+
self.max_text_size = max_text_size
|
29 |
+
self.Lexicon_Table = self.get_dict(character_dict_path)
|
30 |
+
self.pad_num = len(self.Lexicon_Table)
|
31 |
+
self.img_id = 0
|
32 |
+
|
33 |
+
def get_dict(self, character_dict_path):
|
34 |
+
character_str = ""
|
35 |
+
with open(character_dict_path, "rb") as fin:
|
36 |
+
lines = fin.readlines()
|
37 |
+
for line in lines:
|
38 |
+
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
39 |
+
character_str += line
|
40 |
+
dict_character = list(character_str)
|
41 |
+
return dict_character
|
42 |
+
|
43 |
+
def quad_area(self, poly):
|
44 |
+
"""
|
45 |
+
compute area of a polygon
|
46 |
+
:param poly:
|
47 |
+
:return:
|
48 |
+
"""
|
49 |
+
edge = [
|
50 |
+
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
51 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
52 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
53 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
|
54 |
+
]
|
55 |
+
return np.sum(edge) / 2.0
|
56 |
+
|
57 |
+
def gen_quad_from_poly(self, poly):
|
58 |
+
"""
|
59 |
+
Generate min area quad from poly.
|
60 |
+
"""
|
61 |
+
point_num = poly.shape[0]
|
62 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
63 |
+
rect = cv2.minAreaRect(
|
64 |
+
poly.astype(np.int32)
|
65 |
+
) # (center (x,y), (width, height), angle of rotation)
|
66 |
+
box = np.array(cv2.boxPoints(rect))
|
67 |
+
|
68 |
+
first_point_idx = 0
|
69 |
+
min_dist = 1e4
|
70 |
+
for i in range(4):
|
71 |
+
dist = (
|
72 |
+
np.linalg.norm(box[(i + 0) % 4] - poly[0])
|
73 |
+
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
|
74 |
+
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
|
75 |
+
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
76 |
+
)
|
77 |
+
if dist < min_dist:
|
78 |
+
min_dist = dist
|
79 |
+
first_point_idx = i
|
80 |
+
for i in range(4):
|
81 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
82 |
+
|
83 |
+
return min_area_quad
|
84 |
+
|
85 |
+
def check_and_validate_polys(self, polys, tags, im_size):
|
86 |
+
"""
|
87 |
+
check so that the text poly is in the same direction,
|
88 |
+
and also filter some invalid polygons
|
89 |
+
:param polys:
|
90 |
+
:param tags:
|
91 |
+
:return:
|
92 |
+
"""
|
93 |
+
(h, w) = im_size
|
94 |
+
if polys.shape[0] == 0:
|
95 |
+
return polys, np.array([]), np.array([])
|
96 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
97 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
98 |
+
|
99 |
+
validated_polys = []
|
100 |
+
validated_tags = []
|
101 |
+
hv_tags = []
|
102 |
+
for poly, tag in zip(polys, tags):
|
103 |
+
quad = self.gen_quad_from_poly(poly)
|
104 |
+
p_area = self.quad_area(quad)
|
105 |
+
if abs(p_area) < 1:
|
106 |
+
print("invalid poly")
|
107 |
+
continue
|
108 |
+
if p_area > 0:
|
109 |
+
if tag == False:
|
110 |
+
print("poly in wrong direction")
|
111 |
+
tag = True # reversed cases should be ignore
|
112 |
+
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
|
113 |
+
quad = quad[(0, 3, 2, 1), :]
|
114 |
+
|
115 |
+
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(
|
116 |
+
quad[3] - quad[2]
|
117 |
+
)
|
118 |
+
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(
|
119 |
+
quad[1] - quad[2]
|
120 |
+
)
|
121 |
+
hv_tag = 1
|
122 |
+
|
123 |
+
if len_w * 2.0 < len_h:
|
124 |
+
hv_tag = 0
|
125 |
+
|
126 |
+
validated_polys.append(poly)
|
127 |
+
validated_tags.append(tag)
|
128 |
+
hv_tags.append(hv_tag)
|
129 |
+
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
|
130 |
+
|
131 |
+
def crop_area(
|
132 |
+
self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
make random crop from the input image
|
136 |
+
:param im:
|
137 |
+
:param polys: [b,4,2]
|
138 |
+
:param tags:
|
139 |
+
:param crop_background:
|
140 |
+
:param max_tries: 50 -> 25
|
141 |
+
:return:
|
142 |
+
"""
|
143 |
+
h, w, _ = im.shape
|
144 |
+
pad_h = h // 10
|
145 |
+
pad_w = w // 10
|
146 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
147 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
148 |
+
for poly in polys:
|
149 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
150 |
+
minx = np.min(poly[:, 0])
|
151 |
+
maxx = np.max(poly[:, 0])
|
152 |
+
w_array[minx + pad_w : maxx + pad_w] = 1
|
153 |
+
miny = np.min(poly[:, 1])
|
154 |
+
maxy = np.max(poly[:, 1])
|
155 |
+
h_array[miny + pad_h : maxy + pad_h] = 1
|
156 |
+
# ensure the cropped area not across a text
|
157 |
+
h_axis = np.where(h_array == 0)[0]
|
158 |
+
w_axis = np.where(w_array == 0)[0]
|
159 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
160 |
+
return im, polys, tags, hv_tags, txts
|
161 |
+
for i in range(max_tries):
|
162 |
+
xx = np.random.choice(w_axis, size=2)
|
163 |
+
xmin = np.min(xx) - pad_w
|
164 |
+
xmax = np.max(xx) - pad_w
|
165 |
+
xmin = np.clip(xmin, 0, w - 1)
|
166 |
+
xmax = np.clip(xmax, 0, w - 1)
|
167 |
+
yy = np.random.choice(h_axis, size=2)
|
168 |
+
ymin = np.min(yy) - pad_h
|
169 |
+
ymax = np.max(yy) - pad_h
|
170 |
+
ymin = np.clip(ymin, 0, h - 1)
|
171 |
+
ymax = np.clip(ymax, 0, h - 1)
|
172 |
+
if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size:
|
173 |
+
continue
|
174 |
+
if polys.shape[0] != 0:
|
175 |
+
poly_axis_in_area = (
|
176 |
+
(polys[:, :, 0] >= xmin)
|
177 |
+
& (polys[:, :, 0] <= xmax)
|
178 |
+
& (polys[:, :, 1] >= ymin)
|
179 |
+
& (polys[:, :, 1] <= ymax)
|
180 |
+
)
|
181 |
+
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
182 |
+
else:
|
183 |
+
selected_polys = []
|
184 |
+
if len(selected_polys) == 0:
|
185 |
+
# no text in this area
|
186 |
+
if crop_background:
|
187 |
+
txts_tmp = []
|
188 |
+
for selected_poly in selected_polys:
|
189 |
+
txts_tmp.append(txts[selected_poly])
|
190 |
+
txts = txts_tmp
|
191 |
+
return (
|
192 |
+
im[ymin : ymax + 1, xmin : xmax + 1, :],
|
193 |
+
polys[selected_polys],
|
194 |
+
tags[selected_polys],
|
195 |
+
hv_tags[selected_polys],
|
196 |
+
txts,
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
continue
|
200 |
+
im = im[ymin : ymax + 1, xmin : xmax + 1, :]
|
201 |
+
polys = polys[selected_polys]
|
202 |
+
tags = tags[selected_polys]
|
203 |
+
hv_tags = hv_tags[selected_polys]
|
204 |
+
txts_tmp = []
|
205 |
+
for selected_poly in selected_polys:
|
206 |
+
txts_tmp.append(txts[selected_poly])
|
207 |
+
txts = txts_tmp
|
208 |
+
polys[:, :, 0] -= xmin
|
209 |
+
polys[:, :, 1] -= ymin
|
210 |
+
return im, polys, tags, hv_tags, txts
|
211 |
+
|
212 |
+
return im, polys, tags, hv_tags, txts
|
213 |
+
|
214 |
+
def fit_and_gather_tcl_points_v2(
|
215 |
+
self,
|
216 |
+
min_area_quad,
|
217 |
+
poly,
|
218 |
+
max_h,
|
219 |
+
max_w,
|
220 |
+
fixed_point_num=64,
|
221 |
+
img_id=0,
|
222 |
+
reference_height=3,
|
223 |
+
):
|
224 |
+
"""
|
225 |
+
Find the center point of poly as key_points, then fit and gather.
|
226 |
+
"""
|
227 |
+
key_point_xys = []
|
228 |
+
point_num = poly.shape[0]
|
229 |
+
for idx in range(point_num // 2):
|
230 |
+
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
|
231 |
+
key_point_xys.append(center_point)
|
232 |
+
|
233 |
+
tmp_image = np.zeros(
|
234 |
+
shape=(
|
235 |
+
max_h,
|
236 |
+
max_w,
|
237 |
+
),
|
238 |
+
dtype="float32",
|
239 |
+
)
|
240 |
+
cv2.polylines(tmp_image, [np.array(key_point_xys).astype("int32")], False, 1.0)
|
241 |
+
ys, xs = np.where(tmp_image > 0)
|
242 |
+
xy_text = np.array(list(zip(xs, ys)), dtype="float32")
|
243 |
+
|
244 |
+
left_center_pt = ((min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
|
245 |
+
right_center_pt = ((min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
|
246 |
+
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
247 |
+
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
|
248 |
+
)
|
249 |
+
proj_unit_vec_tile = np.tile(proj_unit_vec, (xy_text.shape[0], 1)) # (n, 2)
|
250 |
+
left_center_pt_tile = np.tile(left_center_pt, (xy_text.shape[0], 1)) # (n, 2)
|
251 |
+
xy_text_to_left_center = xy_text - left_center_pt_tile
|
252 |
+
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
|
253 |
+
xy_text = xy_text[np.argsort(proj_value)]
|
254 |
+
|
255 |
+
# convert to np and keep the num of point not greater then fixed_point_num
|
256 |
+
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
|
257 |
+
point_num = len(pos_info)
|
258 |
+
if point_num > fixed_point_num:
|
259 |
+
keep_ids = [
|
260 |
+
int((point_num * 1.0 / fixed_point_num) * x)
|
261 |
+
for x in range(fixed_point_num)
|
262 |
+
]
|
263 |
+
pos_info = pos_info[keep_ids, :]
|
264 |
+
|
265 |
+
keep = int(min(len(pos_info), fixed_point_num))
|
266 |
+
if np.random.rand() < 0.2 and reference_height >= 3:
|
267 |
+
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
|
268 |
+
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape([keep, 1])
|
269 |
+
pos_info += random_float
|
270 |
+
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
271 |
+
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
272 |
+
|
273 |
+
# padding to fixed length
|
274 |
+
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
275 |
+
pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id
|
276 |
+
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
277 |
+
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
278 |
+
pos_m[:keep] = 1.0
|
279 |
+
return pos_l, pos_m
|
280 |
+
|
281 |
+
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
282 |
+
""" """
|
283 |
+
width_list = []
|
284 |
+
height_list = []
|
285 |
+
for quad in poly_quads:
|
286 |
+
quad_w = (
|
287 |
+
np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
|
288 |
+
) / 2.0
|
289 |
+
quad_h = (
|
290 |
+
np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
|
291 |
+
) / 2.0
|
292 |
+
width_list.append(quad_w)
|
293 |
+
height_list.append(quad_h)
|
294 |
+
norm_width = max(sum(width_list) / n_char, 1.0)
|
295 |
+
average_height = max(sum(height_list) / len(height_list), 1.0)
|
296 |
+
k = 1
|
297 |
+
for quad in poly_quads:
|
298 |
+
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
299 |
+
direct_vector = (
|
300 |
+
direct_vector_full
|
301 |
+
/ (np.linalg.norm(direct_vector_full) + 1e-6)
|
302 |
+
* norm_width
|
303 |
+
)
|
304 |
+
direction_label = tuple(
|
305 |
+
map(float, [direct_vector[0], direct_vector[1], 1.0 / average_height])
|
306 |
+
)
|
307 |
+
cv2.fillPoly(
|
308 |
+
direction_map,
|
309 |
+
quad.round().astype(np.int32)[np.newaxis, :, :],
|
310 |
+
direction_label,
|
311 |
+
)
|
312 |
+
k += 1
|
313 |
+
return direction_map
|
314 |
+
|
315 |
+
def calculate_average_height(self, poly_quads):
|
316 |
+
""" """
|
317 |
+
height_list = []
|
318 |
+
for quad in poly_quads:
|
319 |
+
quad_h = (
|
320 |
+
np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
|
321 |
+
) / 2.0
|
322 |
+
height_list.append(quad_h)
|
323 |
+
average_height = max(sum(height_list) / len(height_list), 1.0)
|
324 |
+
return average_height
|
325 |
+
|
326 |
+
def generate_tcl_ctc_label(
|
327 |
+
self,
|
328 |
+
h,
|
329 |
+
w,
|
330 |
+
polys,
|
331 |
+
tags,
|
332 |
+
text_strs,
|
333 |
+
ds_ratio,
|
334 |
+
tcl_ratio=0.3,
|
335 |
+
shrink_ratio_of_width=0.15,
|
336 |
+
):
|
337 |
+
"""
|
338 |
+
Generate polygon.
|
339 |
+
"""
|
340 |
+
score_map_big = np.zeros(
|
341 |
+
(
|
342 |
+
h,
|
343 |
+
w,
|
344 |
+
),
|
345 |
+
dtype=np.float32,
|
346 |
+
)
|
347 |
+
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
348 |
+
polys = polys * ds_ratio
|
349 |
+
|
350 |
+
score_map = np.zeros(
|
351 |
+
(
|
352 |
+
h,
|
353 |
+
w,
|
354 |
+
),
|
355 |
+
dtype=np.float32,
|
356 |
+
)
|
357 |
+
score_label_map = np.zeros(
|
358 |
+
(
|
359 |
+
h,
|
360 |
+
w,
|
361 |
+
),
|
362 |
+
dtype=np.float32,
|
363 |
+
)
|
364 |
+
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
365 |
+
training_mask = np.ones(
|
366 |
+
(
|
367 |
+
h,
|
368 |
+
w,
|
369 |
+
),
|
370 |
+
dtype=np.float32,
|
371 |
+
)
|
372 |
+
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
373 |
+
[1, 1, 3]
|
374 |
+
).astype(np.float32)
|
375 |
+
|
376 |
+
label_idx = 0
|
377 |
+
score_label_map_text_label_list = []
|
378 |
+
pos_list, pos_mask, label_list = [], [], []
|
379 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
380 |
+
poly = poly_tag[0]
|
381 |
+
tag = poly_tag[1]
|
382 |
+
|
383 |
+
# generate min_area_quad
|
384 |
+
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
385 |
+
min_area_quad_h = 0.5 * (
|
386 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[3])
|
387 |
+
+ np.linalg.norm(min_area_quad[1] - min_area_quad[2])
|
388 |
+
)
|
389 |
+
min_area_quad_w = 0.5 * (
|
390 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[1])
|
391 |
+
+ np.linalg.norm(min_area_quad[2] - min_area_quad[3])
|
392 |
+
)
|
393 |
+
|
394 |
+
if (
|
395 |
+
min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio
|
396 |
+
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio
|
397 |
+
):
|
398 |
+
continue
|
399 |
+
|
400 |
+
if tag:
|
401 |
+
cv2.fillPoly(
|
402 |
+
training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15
|
403 |
+
)
|
404 |
+
else:
|
405 |
+
text_label = text_strs[poly_idx]
|
406 |
+
text_label = self.prepare_text_label(text_label, self.Lexicon_Table)
|
407 |
+
|
408 |
+
text_label_index_list = [
|
409 |
+
[self.Lexicon_Table.index(c_)]
|
410 |
+
for c_ in text_label
|
411 |
+
if c_ in self.Lexicon_Table
|
412 |
+
]
|
413 |
+
if len(text_label_index_list) < 1:
|
414 |
+
continue
|
415 |
+
|
416 |
+
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
417 |
+
tcl_quads = self.poly2quads(tcl_poly)
|
418 |
+
poly_quads = self.poly2quads(poly)
|
419 |
+
|
420 |
+
stcl_quads, quad_index = self.shrink_poly_along_width(
|
421 |
+
tcl_quads,
|
422 |
+
shrink_ratio_of_width=shrink_ratio_of_width,
|
423 |
+
expand_height_ratio=1.0 / tcl_ratio,
|
424 |
+
)
|
425 |
+
|
426 |
+
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
|
427 |
+
cv2.fillPoly(
|
428 |
+
score_map_big, np.round(stcl_quads / ds_ratio).astype(np.int32), 1.0
|
429 |
+
)
|
430 |
+
|
431 |
+
for idx, quad in enumerate(stcl_quads):
|
432 |
+
quad_mask = np.zeros((h, w), dtype=np.float32)
|
433 |
+
quad_mask = cv2.fillPoly(
|
434 |
+
quad_mask,
|
435 |
+
np.round(quad[np.newaxis, :, :]).astype(np.int32),
|
436 |
+
1.0,
|
437 |
+
)
|
438 |
+
tbo_map = self.gen_quad_tbo(
|
439 |
+
poly_quads[quad_index[idx]], quad_mask, tbo_map
|
440 |
+
)
|
441 |
+
|
442 |
+
# score label map and score_label_map_text_label_list for refine
|
443 |
+
if label_idx == 0:
|
444 |
+
text_pos_list_ = [
|
445 |
+
[len(self.Lexicon_Table)],
|
446 |
+
]
|
447 |
+
score_label_map_text_label_list.append(text_pos_list_)
|
448 |
+
|
449 |
+
label_idx += 1
|
450 |
+
cv2.fillPoly(
|
451 |
+
score_label_map, np.round(poly_quads).astype(np.int32), label_idx
|
452 |
+
)
|
453 |
+
score_label_map_text_label_list.append(text_label_index_list)
|
454 |
+
|
455 |
+
# direction info, fix-me
|
456 |
+
n_char = len(text_label_index_list)
|
457 |
+
direction_map = self.generate_direction_map(
|
458 |
+
poly_quads, n_char, direction_map
|
459 |
+
)
|
460 |
+
|
461 |
+
# pos info
|
462 |
+
average_shrink_height = self.calculate_average_height(stcl_quads)
|
463 |
+
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
464 |
+
min_area_quad,
|
465 |
+
poly,
|
466 |
+
max_h=h,
|
467 |
+
max_w=w,
|
468 |
+
fixed_point_num=64,
|
469 |
+
img_id=self.img_id,
|
470 |
+
reference_height=average_shrink_height,
|
471 |
+
)
|
472 |
+
|
473 |
+
label_l = text_label_index_list
|
474 |
+
if len(text_label_index_list) < 2:
|
475 |
+
continue
|
476 |
+
|
477 |
+
pos_list.append(pos_l)
|
478 |
+
pos_mask.append(pos_m)
|
479 |
+
label_list.append(label_l)
|
480 |
+
|
481 |
+
# use big score_map for smooth tcl lines
|
482 |
+
score_map_big_resized = cv2.resize(
|
483 |
+
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio
|
484 |
+
)
|
485 |
+
score_map = np.array(score_map_big_resized > 1e-3, dtype="float32")
|
486 |
+
|
487 |
+
return (
|
488 |
+
score_map,
|
489 |
+
score_label_map,
|
490 |
+
tbo_map,
|
491 |
+
direction_map,
|
492 |
+
training_mask,
|
493 |
+
pos_list,
|
494 |
+
pos_mask,
|
495 |
+
label_list,
|
496 |
+
score_label_map_text_label_list,
|
497 |
+
)
|
498 |
+
|
499 |
+
def adjust_point(self, poly):
|
500 |
+
"""
|
501 |
+
adjust point order.
|
502 |
+
"""
|
503 |
+
point_num = poly.shape[0]
|
504 |
+
if point_num == 4:
|
505 |
+
len_1 = np.linalg.norm(poly[0] - poly[1])
|
506 |
+
len_2 = np.linalg.norm(poly[1] - poly[2])
|
507 |
+
len_3 = np.linalg.norm(poly[2] - poly[3])
|
508 |
+
len_4 = np.linalg.norm(poly[3] - poly[0])
|
509 |
+
|
510 |
+
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
511 |
+
poly = poly[[1, 2, 3, 0], :]
|
512 |
+
|
513 |
+
elif point_num > 4:
|
514 |
+
vector_1 = poly[0] - poly[1]
|
515 |
+
vector_2 = poly[1] - poly[2]
|
516 |
+
cos_theta = np.dot(vector_1, vector_2) / (
|
517 |
+
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6
|
518 |
+
)
|
519 |
+
theta = np.arccos(np.round(cos_theta, decimals=4))
|
520 |
+
|
521 |
+
if abs(theta) > (70 / 180 * math.pi):
|
522 |
+
index = list(range(1, point_num)) + [0]
|
523 |
+
poly = poly[np.array(index), :]
|
524 |
+
return poly
|
525 |
+
|
526 |
+
def gen_min_area_quad_from_poly(self, poly):
|
527 |
+
"""
|
528 |
+
Generate min area quad from poly.
|
529 |
+
"""
|
530 |
+
point_num = poly.shape[0]
|
531 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
532 |
+
if point_num == 4:
|
533 |
+
min_area_quad = poly
|
534 |
+
center_point = np.sum(poly, axis=0) / 4
|
535 |
+
else:
|
536 |
+
rect = cv2.minAreaRect(
|
537 |
+
poly.astype(np.int32)
|
538 |
+
) # (center (x,y), (width, height), angle of rotation)
|
539 |
+
center_point = rect[0]
|
540 |
+
box = np.array(cv2.boxPoints(rect))
|
541 |
+
|
542 |
+
first_point_idx = 0
|
543 |
+
min_dist = 1e4
|
544 |
+
for i in range(4):
|
545 |
+
dist = (
|
546 |
+
np.linalg.norm(box[(i + 0) % 4] - poly[0])
|
547 |
+
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
|
548 |
+
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
|
549 |
+
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
550 |
+
)
|
551 |
+
if dist < min_dist:
|
552 |
+
min_dist = dist
|
553 |
+
first_point_idx = i
|
554 |
+
|
555 |
+
for i in range(4):
|
556 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
557 |
+
|
558 |
+
return min_area_quad, center_point
|
559 |
+
|
560 |
+
def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
|
561 |
+
"""
|
562 |
+
Generate shrink_quad_along_width.
|
563 |
+
"""
|
564 |
+
ratio_pair = np.array(
|
565 |
+
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32
|
566 |
+
)
|
567 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
568 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
569 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
570 |
+
|
571 |
+
def shrink_poly_along_width(
|
572 |
+
self, quads, shrink_ratio_of_width, expand_height_ratio=1.0
|
573 |
+
):
|
574 |
+
"""
|
575 |
+
shrink poly with given length.
|
576 |
+
"""
|
577 |
+
upper_edge_list = []
|
578 |
+
|
579 |
+
def get_cut_info(edge_len_list, cut_len):
|
580 |
+
for idx, edge_len in enumerate(edge_len_list):
|
581 |
+
cut_len -= edge_len
|
582 |
+
if cut_len <= 0.000001:
|
583 |
+
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
584 |
+
return idx, ratio
|
585 |
+
|
586 |
+
for quad in quads:
|
587 |
+
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
588 |
+
upper_edge_list.append(upper_edge_len)
|
589 |
+
|
590 |
+
# length of left edge and right edge.
|
591 |
+
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
|
592 |
+
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
|
593 |
+
|
594 |
+
shrink_length = (
|
595 |
+
min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
|
596 |
+
)
|
597 |
+
# shrinking length
|
598 |
+
upper_len_left = shrink_length
|
599 |
+
upper_len_right = sum(upper_edge_list) - shrink_length
|
600 |
+
|
601 |
+
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
602 |
+
left_quad = self.shrink_quad_along_width(
|
603 |
+
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1
|
604 |
+
)
|
605 |
+
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
606 |
+
right_quad = self.shrink_quad_along_width(
|
607 |
+
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio
|
608 |
+
)
|
609 |
+
|
610 |
+
out_quad_list = []
|
611 |
+
if left_idx == right_idx:
|
612 |
+
out_quad_list.append(
|
613 |
+
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]]
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
out_quad_list.append(left_quad)
|
617 |
+
for idx in range(left_idx + 1, right_idx):
|
618 |
+
out_quad_list.append(quads[idx])
|
619 |
+
out_quad_list.append(right_quad)
|
620 |
+
|
621 |
+
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
622 |
+
|
623 |
+
def prepare_text_label(self, label_str, Lexicon_Table):
|
624 |
+
"""
|
625 |
+
Prepare text lablel by given Lexicon_Table.
|
626 |
+
"""
|
627 |
+
if len(Lexicon_Table) == 36:
|
628 |
+
return label_str.lower()
|
629 |
+
else:
|
630 |
+
return label_str
|
631 |
+
|
632 |
+
def vector_angle(self, A, B):
|
633 |
+
"""
|
634 |
+
Calculate the angle between vector AB and x-axis positive direction.
|
635 |
+
"""
|
636 |
+
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
637 |
+
return np.arctan2(*AB)
|
638 |
+
|
639 |
+
def theta_line_cross_point(self, theta, point):
|
640 |
+
"""
|
641 |
+
Calculate the line through given point and angle in ax + by + c =0 form.
|
642 |
+
"""
|
643 |
+
x, y = point
|
644 |
+
cos = np.cos(theta)
|
645 |
+
sin = np.sin(theta)
|
646 |
+
return [sin, -cos, cos * y - sin * x]
|
647 |
+
|
648 |
+
def line_cross_two_point(self, A, B):
|
649 |
+
"""
|
650 |
+
Calculate the line through given point A and B in ax + by + c =0 form.
|
651 |
+
"""
|
652 |
+
angle = self.vector_angle(A, B)
|
653 |
+
return self.theta_line_cross_point(angle, A)
|
654 |
+
|
655 |
+
def average_angle(self, poly):
|
656 |
+
"""
|
657 |
+
Calculate the average angle between left and right edge in given poly.
|
658 |
+
"""
|
659 |
+
p0, p1, p2, p3 = poly
|
660 |
+
angle30 = self.vector_angle(p3, p0)
|
661 |
+
angle21 = self.vector_angle(p2, p1)
|
662 |
+
return (angle30 + angle21) / 2
|
663 |
+
|
664 |
+
def line_cross_point(self, line1, line2):
|
665 |
+
"""
|
666 |
+
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
667 |
+
"""
|
668 |
+
a1, b1, c1 = line1
|
669 |
+
a2, b2, c2 = line2
|
670 |
+
d = a1 * b2 - a2 * b1
|
671 |
+
|
672 |
+
if d == 0:
|
673 |
+
print("Cross point does not exist")
|
674 |
+
return np.array([0, 0], dtype=np.float32)
|
675 |
+
else:
|
676 |
+
x = (b1 * c2 - b2 * c1) / d
|
677 |
+
y = (a2 * c1 - a1 * c2) / d
|
678 |
+
|
679 |
+
return np.array([x, y], dtype=np.float32)
|
680 |
+
|
681 |
+
def quad2tcl(self, poly, ratio):
|
682 |
+
"""
|
683 |
+
Generate center line by poly clock-wise point. (4, 2)
|
684 |
+
"""
|
685 |
+
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
686 |
+
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
687 |
+
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
688 |
+
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
689 |
+
|
690 |
+
def poly2tcl(self, poly, ratio):
|
691 |
+
"""
|
692 |
+
Generate center line by poly clock-wise point.
|
693 |
+
"""
|
694 |
+
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
695 |
+
tcl_poly = np.zeros_like(poly)
|
696 |
+
point_num = poly.shape[0]
|
697 |
+
|
698 |
+
for idx in range(point_num // 2):
|
699 |
+
point_pair = (
|
700 |
+
poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
|
701 |
+
)
|
702 |
+
tcl_poly[idx] = point_pair[0]
|
703 |
+
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
704 |
+
return tcl_poly
|
705 |
+
|
706 |
+
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
707 |
+
"""
|
708 |
+
Generate tbo_map for give quad.
|
709 |
+
"""
|
710 |
+
# upper and lower line function: ax + by + c = 0;
|
711 |
+
up_line = self.line_cross_two_point(quad[0], quad[1])
|
712 |
+
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
713 |
+
|
714 |
+
quad_h = 0.5 * (
|
715 |
+
np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
716 |
+
)
|
717 |
+
quad_w = 0.5 * (
|
718 |
+
np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
|
719 |
+
)
|
720 |
+
|
721 |
+
# average angle of left and right line.
|
722 |
+
angle = self.average_angle(quad)
|
723 |
+
|
724 |
+
xy_in_poly = np.argwhere(tcl_mask == 1)
|
725 |
+
for y, x in xy_in_poly:
|
726 |
+
point = (x, y)
|
727 |
+
line = self.theta_line_cross_point(angle, point)
|
728 |
+
cross_point_upper = self.line_cross_point(up_line, line)
|
729 |
+
cross_point_lower = self.line_cross_point(lower_line, line)
|
730 |
+
##FIX, offset reverse
|
731 |
+
upper_offset_x, upper_offset_y = cross_point_upper - point
|
732 |
+
lower_offset_x, lower_offset_y = cross_point_lower - point
|
733 |
+
tbo_map[y, x, 0] = upper_offset_y
|
734 |
+
tbo_map[y, x, 1] = upper_offset_x
|
735 |
+
tbo_map[y, x, 2] = lower_offset_y
|
736 |
+
tbo_map[y, x, 3] = lower_offset_x
|
737 |
+
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
738 |
+
return tbo_map
|
739 |
+
|
740 |
+
def poly2quads(self, poly):
|
741 |
+
"""
|
742 |
+
Split poly into quads.
|
743 |
+
"""
|
744 |
+
quad_list = []
|
745 |
+
point_num = poly.shape[0]
|
746 |
+
|
747 |
+
# point pair
|
748 |
+
point_pair_list = []
|
749 |
+
for idx in range(point_num // 2):
|
750 |
+
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
751 |
+
point_pair_list.append(point_pair)
|
752 |
+
|
753 |
+
quad_num = point_num // 2 - 1
|
754 |
+
for idx in range(quad_num):
|
755 |
+
# reshape and adjust to clock-wise
|
756 |
+
quad_list.append(
|
757 |
+
(np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]
|
758 |
+
)
|
759 |
+
|
760 |
+
return np.array(quad_list)
|
761 |
+
|
762 |
+
def rotate_im_poly(self, im, text_polys):
|
763 |
+
"""
|
764 |
+
rotate image with 90 / 180 / 270 degre
|
765 |
+
"""
|
766 |
+
im_w, im_h = im.shape[1], im.shape[0]
|
767 |
+
dst_im = im.copy()
|
768 |
+
dst_polys = []
|
769 |
+
rand_degree_ratio = np.random.rand()
|
770 |
+
rand_degree_cnt = 1
|
771 |
+
if rand_degree_ratio > 0.5:
|
772 |
+
rand_degree_cnt = 3
|
773 |
+
for i in range(rand_degree_cnt):
|
774 |
+
dst_im = np.rot90(dst_im)
|
775 |
+
rot_degree = -90 * rand_degree_cnt
|
776 |
+
rot_angle = rot_degree * math.pi / 180.0
|
777 |
+
n_poly = text_polys.shape[0]
|
778 |
+
cx, cy = 0.5 * im_w, 0.5 * im_h
|
779 |
+
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
780 |
+
for i in range(n_poly):
|
781 |
+
wordBB = text_polys[i]
|
782 |
+
poly = []
|
783 |
+
for j in range(4): # 16->4
|
784 |
+
sx, sy = wordBB[j][0], wordBB[j][1]
|
785 |
+
dx = (
|
786 |
+
math.cos(rot_angle) * (sx - cx)
|
787 |
+
- math.sin(rot_angle) * (sy - cy)
|
788 |
+
+ ncx
|
789 |
+
)
|
790 |
+
dy = (
|
791 |
+
math.sin(rot_angle) * (sx - cx)
|
792 |
+
+ math.cos(rot_angle) * (sy - cy)
|
793 |
+
+ ncy
|
794 |
+
)
|
795 |
+
poly.append([dx, dy])
|
796 |
+
dst_polys.append(poly)
|
797 |
+
return dst_im, np.array(dst_polys, dtype=np.float32)
|
798 |
+
|
799 |
+
def __call__(self, data):
|
800 |
+
input_size = 512
|
801 |
+
im = data["image"]
|
802 |
+
text_polys = data["polys"]
|
803 |
+
text_tags = data["ignore_tags"]
|
804 |
+
text_strs = data["texts"]
|
805 |
+
h, w, _ = im.shape
|
806 |
+
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
807 |
+
text_polys, text_tags, (h, w)
|
808 |
+
)
|
809 |
+
if text_polys.shape[0] <= 0:
|
810 |
+
return None
|
811 |
+
# set aspect ratio and keep area fix
|
812 |
+
asp_scales = np.arange(1.0, 1.55, 0.1)
|
813 |
+
asp_scale = np.random.choice(asp_scales)
|
814 |
+
if np.random.rand() < 0.5:
|
815 |
+
asp_scale = 1.0 / asp_scale
|
816 |
+
asp_scale = math.sqrt(asp_scale)
|
817 |
+
|
818 |
+
asp_wx = asp_scale
|
819 |
+
asp_hy = 1.0 / asp_scale
|
820 |
+
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
821 |
+
text_polys[:, :, 0] *= asp_wx
|
822 |
+
text_polys[:, :, 1] *= asp_hy
|
823 |
+
|
824 |
+
h, w, _ = im.shape
|
825 |
+
if max(h, w) > 2048:
|
826 |
+
rd_scale = 2048.0 / max(h, w)
|
827 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
828 |
+
text_polys *= rd_scale
|
829 |
+
h, w, _ = im.shape
|
830 |
+
if min(h, w) < 16:
|
831 |
+
return None
|
832 |
+
|
833 |
+
# no background
|
834 |
+
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
835 |
+
im, text_polys, text_tags, hv_tags, text_strs, crop_background=False
|
836 |
+
)
|
837 |
+
|
838 |
+
if text_polys.shape[0] == 0:
|
839 |
+
return None
|
840 |
+
# # continue for all ignore case
|
841 |
+
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
842 |
+
return None
|
843 |
+
new_h, new_w, _ = im.shape
|
844 |
+
if (new_h is None) or (new_w is None):
|
845 |
+
return None
|
846 |
+
# resize image
|
847 |
+
std_ratio = float(input_size) / max(new_w, new_h)
|
848 |
+
rand_scales = np.array(
|
849 |
+
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]
|
850 |
+
)
|
851 |
+
rz_scale = std_ratio * np.random.choice(rand_scales)
|
852 |
+
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
853 |
+
text_polys[:, :, 0] *= rz_scale
|
854 |
+
text_polys[:, :, 1] *= rz_scale
|
855 |
+
|
856 |
+
# add gaussian blur
|
857 |
+
if np.random.rand() < 0.1 * 0.5:
|
858 |
+
ks = np.random.permutation(5)[0] + 1
|
859 |
+
ks = int(ks / 2) * 2 + 1
|
860 |
+
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
861 |
+
# add brighter
|
862 |
+
if np.random.rand() < 0.1 * 0.5:
|
863 |
+
im = im * (1.0 + np.random.rand() * 0.5)
|
864 |
+
im = np.clip(im, 0.0, 255.0)
|
865 |
+
# add darker
|
866 |
+
if np.random.rand() < 0.1 * 0.5:
|
867 |
+
im = im * (1.0 - np.random.rand() * 0.5)
|
868 |
+
im = np.clip(im, 0.0, 255.0)
|
869 |
+
|
870 |
+
# Padding the im to [input_size, input_size]
|
871 |
+
new_h, new_w, _ = im.shape
|
872 |
+
if min(new_w, new_h) < input_size * 0.5:
|
873 |
+
return None
|
874 |
+
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
|
875 |
+
im_padded[:, :, 2] = 0.485 * 255
|
876 |
+
im_padded[:, :, 1] = 0.456 * 255
|
877 |
+
im_padded[:, :, 0] = 0.406 * 255
|
878 |
+
|
879 |
+
# Random the start position
|
880 |
+
del_h = input_size - new_h
|
881 |
+
del_w = input_size - new_w
|
882 |
+
sh, sw = 0, 0
|
883 |
+
if del_h > 1:
|
884 |
+
sh = int(np.random.rand() * del_h)
|
885 |
+
if del_w > 1:
|
886 |
+
sw = int(np.random.rand() * del_w)
|
887 |
+
|
888 |
+
# Padding
|
889 |
+
im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy()
|
890 |
+
text_polys[:, :, 0] += sw
|
891 |
+
text_polys[:, :, 1] += sh
|
892 |
+
|
893 |
+
(
|
894 |
+
score_map,
|
895 |
+
score_label_map,
|
896 |
+
border_map,
|
897 |
+
direction_map,
|
898 |
+
training_mask,
|
899 |
+
pos_list,
|
900 |
+
pos_mask,
|
901 |
+
label_list,
|
902 |
+
score_label_map_text_label,
|
903 |
+
) = self.generate_tcl_ctc_label(
|
904 |
+
input_size, input_size, text_polys, text_tags, text_strs, 0.25
|
905 |
+
)
|
906 |
+
if len(label_list) <= 0: # eliminate negative samples
|
907 |
+
return None
|
908 |
+
pos_list_temp = np.zeros([64, 3])
|
909 |
+
pos_mask_temp = np.zeros([64, 1])
|
910 |
+
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
|
911 |
+
|
912 |
+
for i, label in enumerate(label_list):
|
913 |
+
n = len(label)
|
914 |
+
if n > self.max_text_length:
|
915 |
+
label_list[i] = label[: self.max_text_length]
|
916 |
+
continue
|
917 |
+
while n < self.max_text_length:
|
918 |
+
label.append([self.pad_num])
|
919 |
+
n += 1
|
920 |
+
|
921 |
+
for i in range(len(label_list)):
|
922 |
+
label_list[i] = np.array(label_list[i])
|
923 |
+
|
924 |
+
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
|
925 |
+
return None
|
926 |
+
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
|
927 |
+
pos_list.append(pos_list_temp)
|
928 |
+
pos_mask.append(pos_mask_temp)
|
929 |
+
label_list.append(label_list_temp)
|
930 |
+
|
931 |
+
if self.img_id == self.batch_size - 1:
|
932 |
+
self.img_id = 0
|
933 |
+
else:
|
934 |
+
self.img_id += 1
|
935 |
+
|
936 |
+
im_padded[:, :, 2] -= 0.485 * 255
|
937 |
+
im_padded[:, :, 1] -= 0.456 * 255
|
938 |
+
im_padded[:, :, 0] -= 0.406 * 255
|
939 |
+
im_padded[:, :, 2] /= 255.0 * 0.229
|
940 |
+
im_padded[:, :, 1] /= 255.0 * 0.224
|
941 |
+
im_padded[:, :, 0] /= 255.0 * 0.225
|
942 |
+
im_padded = im_padded.transpose((2, 0, 1))
|
943 |
+
images = im_padded[::-1, :, :]
|
944 |
+
tcl_maps = score_map[np.newaxis, :, :]
|
945 |
+
tcl_label_maps = score_label_map[np.newaxis, :, :]
|
946 |
+
border_maps = border_map.transpose((2, 0, 1))
|
947 |
+
direction_maps = direction_map.transpose((2, 0, 1))
|
948 |
+
training_masks = training_mask[np.newaxis, :, :]
|
949 |
+
pos_list = np.array(pos_list)
|
950 |
+
pos_mask = np.array(pos_mask)
|
951 |
+
label_list = np.array(label_list)
|
952 |
+
data["images"] = images
|
953 |
+
data["tcl_maps"] = tcl_maps
|
954 |
+
data["tcl_label_maps"] = tcl_label_maps
|
955 |
+
data["border_maps"] = border_maps
|
956 |
+
data["direction_maps"] = direction_maps
|
957 |
+
data["training_masks"] = training_masks
|
958 |
+
data["label_list"] = label_list
|
959 |
+
data["pos_list"] = pos_list
|
960 |
+
data["pos_mask"] = pos_mask
|
961 |
+
return data
|