|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import json |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
import traceback |
|
|
|
import requests |
|
import numpy as np |
|
import gradio as gr |
|
import cv2 |
|
|
|
from paddlenlp import Taskflow |
|
from paddlenlp.utils.doc_parser import DocParser |
|
|
|
# Shared document helper: process_path() reads preview images through it and
# process_doc() uses it to draw extraction results onto the document.
doc_parser = DocParser()

# One information-extraction pipeline reused by every request. The schema is
# left empty here because run_taskflow() replaces it via set_schema() per call.
task_instance = Taskflow("information_extraction", model="uie-x-base", schema="")
|
|
|
# Rows offered by the gr.Examples widget: [example image file, schema string].
# Each schema string here is also a key of `example_files`.
examples = [
    [
        "invoice.jpeg",
        "名称;纳税人识别号;开票日期",
    ],
    [
        "custom.jpeg",
        "收发货人;进口口岸;进口日期;运输方式;征免性质;境内目的地;运输工具名称;包装种类;件数;合同协议号",
    ],
    [
        "resume.png",
        "职位;年龄;学校|时间;学校|专业",
    ],
]
|
|
|
# Reverse lookup used by load_example_document(): schema string -> the example
# document file that schema belongs to.
example_files = {
    "Name;Title;Web Link;Email;Address": "business_card.png",
    "Name;DOB;ISS;EXP": "license.jpeg",
    "职位;年龄;学校|时间;学校|专业": "resume.png",
    "收发货人;进口口岸;进口日期;运输方式;征免性质;境内目的地;运输工具名称;包装种类;件数;合同协议号": "custom.jpeg",
    "名称;纳税人识别号;开票日期": "invoice.jpeg",
}
|
|
|
# Per-example OCR setting. load_example_document() splits the value on "-":
# the first part is the OCR language, and the presence of a second part turns
# layout analysis on (none of the current entries has one).
lang_map = {
    "resume.png": "ch",
    "custom.jpeg": "ch",
    "business_card.png": "en",
    "invoice.jpeg": "ch",
    "license.jpeg": "en",
}
|
|
|
# Translation table built once at import time: U+3000 (ideographic space)
# becomes an ASCII space, and the full-width forms U+FF01..U+FF5E become their
# ASCII counterparts U+0021..U+007E (a constant offset of 0xFEE0).
_FULLWIDTH_TO_HALFWIDTH = {0x3000: 0x0020}
_FULLWIDTH_TO_HALFWIDTH.update(
    (code, code - 0xFEE0) for code in range(0xFF01, 0xFF5F)
)


def dbc2sbc(s):
    """Convert full-width (double-byte) characters in *s* to half-width ASCII.

    Lets a schema typed with a CJK input method (full-width ``；`` or ``｜``)
    be parsed with the ASCII separators ``;`` and ``|``.

    Args:
        s: Input string.

    Returns:
        A new string in which U+3000 is replaced by a plain space and
        U+FF01..U+FF5E are replaced by U+0021..U+007E. Every other character
        is left unchanged.
    """
    # A single C-level pass; the ideographic space is handled by the table
    # itself, so it cannot fall through an ASCII-range check unconverted.
    return s.translate(_FULLWIDTH_TO_HALFWIDTH)
|
|
|
|
|
def process_path(path):
    """Load the document at *path* and build the UI updates for its preview.

    This is the handler for the URL "Get" button and is also called by
    process_upload(). Both are wired to five outputs, so exactly five values
    are returned, in this order:

        (document state, preview gallery, image clear button,
         JSON output, URL error box)

    Args:
        path: File path or URL entered by the user; may be empty or None.

    Returns:
        On success: the path as the new document state, the gallery shown
        with the loaded image, the clear button shown, and the output and
        error boxes hidden and emptied.
        On failure or empty input: everything reset and hidden; if reading
        raised, the error box is shown with the exception text.
    """
    error = None
    if path:
        try:
            images_list = [doc_parser.read_image(path)]
        except Exception as e:
            # UI boundary: log the traceback and surface the message to the
            # user instead of failing the request.
            traceback.print_exc()
            error = str(e)
        else:
            return (
                path,
                gr.update(visible=True, value=images_list),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
            )

    if error is not None:
        error_update = gr.update(visible=True, value=error)
    else:
        error_update = gr.update(visible=False, value=None)
    # The error update must occupy the fifth slot: that is the position of
    # url_error in the wired outputs.
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        error_update,
    )
|
|
|
|
|
def process_upload(file):
    """Handle a change of the upload widget.

    Args:
        file: The uploaded file object (its ``name`` attribute is passed on
            as the path), or a falsy value when the upload was cleared.

    Returns:
        Whatever process_path() returns for an uploaded file. When nothing is
        uploaded, a reset for the five wired outputs, in order: document
        state, preview gallery, image clear button, JSON output, URL error
        box.
    """
    if file:
        return process_path(file.name)
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
    )
|
|
|
|
|
def BGR2RGB(img):
    """Return a copy of *img* with channels 0 and 2 of the last axis swapped.

    Args:
        img: A 3-D array indexed as ``[row, column, channel]`` with at least
            three channels. It is not modified.

    Returns:
        A new array of the same shape. Channels other than 0 and 2 (for
        example a fourth channel) are copied through unchanged.
    """
    swapped = img.copy()
    # Read from the untouched input so the two assignments cannot interfere.
    swapped[:, :, [0, 2]] = img[:, :, [2, 0]]
    return swapped
|
|
|
|
|
def np2base64(image_np):
    """Encode an image array as a base64 string of its JPEG bytes.

    Args:
        image_np: Image array accepted by BGR2RGB() and ``cv2.imencode``.

    Returns:
        The base64 text (ASCII ``str``) of the JPEG-encoded image.
    """
    # NOTE(review): channels 0 and 2 are swapped before encoding; presumably
    # the input is RGB and OpenCV expects BGR - confirm against the callers.
    swapped = BGR2RGB(image_np)
    _, encoded = cv2.imencode('.jpg', swapped)
    # base64 output is pure ASCII, so decoding yields exactly the text that
    # slicing the bytes repr used to produce.
    return base64.b64encode(encoded).decode("ascii")
|
|
|
|
|
def get_schema(schema_str):
    """Parse a user-typed schema string into the extractor's schema list.

    Items are separated by ``;``. A plain item is an entity label. An item of
    the form ``subject|relation`` declares a relation; items sharing a subject
    are merged into a single ``{subject: [relation, ...]}`` dict, placed where
    that subject first appeared.

    Whitespace around labels is stripped, and empty items (for example from a
    trailing ``;``) are skipped so no empty label is handed to the model.

    Args:
        schema_str: Schema text such as ``"Name;Person|Email;Person|Date"``.

    Returns:
        A tuple ``(schema_list, schema_lang)``. ``schema_lang`` is ``"ch"``
        when *schema_str* contains a character in U+4E00..U+9FFF, otherwise
        ``"en"``.
    """
    has_chinese = any("\u4e00" <= ch <= "\u9fff" for ch in schema_str)
    schema_lang = "ch" if has_chinese else "en"

    schema_list = []
    # subject -> the relation list already stored in schema_list, so repeated
    # subjects are merged without rescanning the list.
    relations_by_subject = {}
    for item in schema_str.split(";"):
        subject, *relations = (part.strip() for part in item.split("|"))
        relations = [relation for relation in relations if relation]
        if not subject:
            continue
        if not relations:
            schema_list.append(subject)
        elif subject in relations_by_subject:
            relations_by_subject[subject].extend(relations)
        else:
            relations_by_subject[subject] = relations
            schema_list.append({subject: relations})
    return schema_list, schema_lang
|
|
|
|
|
def run_taskflow(document, schema, argument):
    """Run the shared extraction pipeline on one document.

    Args:
        document: Document path or URL, passed to the pipeline as ``doc``.
        schema: Parsed schema list, as produced by get_schema().
        argument: Dict of run options built by process_doc().

    Returns:
        The pipeline's result for the single submitted document.
    """
    # NOTE(review): `argument` (ocr_lang / schema_lang / layout_analysis) is
    # accepted but never applied, so the UI's OCR-language and layout-analysis
    # choices currently have no effect - confirm whether the pipeline should
    # be configured with it.
    task_instance.set_schema(schema)
    payload = {'doc': document}
    return task_instance(payload)
|
|
|
|
|
def process_doc(document, schema, ocr_lang, layout_analysis):
    """Extract information from *document* and build the result updates.

    This is the handler for the Submit button and for pressing Enter in the
    schema box; both are wired to the outputs (preview gallery, JSON output).

    Args:
        document: Path or URL of the loaded document, or None if none is
            loaded.
        schema: Raw schema text from the UI; a default schema is used when it
            is empty.
        ocr_lang: OCR language choice from the UI.
        layout_analysis: Layout-analysis choice from the UI.

    Returns:
        ``(None, None)`` when no document is loaded; otherwise a pair of
        updates showing the annotated image in the gallery and the prediction
        in the JSON output.
    """
    if document is None:
        return None, None
    if not schema:
        schema = '时间;组织机构;人物'

    # Normalise full-width punctuation first so ';' and '|' split correctly.
    parsed_schema, schema_lang = get_schema(dbc2sbc(schema))
    argument = {
        "ocr_lang": ocr_lang,
        "schema_lang": schema_lang,
        "layout_analysis": layout_analysis,
    }
    # One document is submitted, so only the first result is used.
    prediction = run_taskflow(document, parsed_schema, argument)[0]

    img_show = doc_parser.write_image_with_results(
        document,
        result=prediction,
        return_image=True,
    )

    return (
        gr.update(visible=True, value=[img_show]),
        gr.update(visible=True, value=prediction),
    )
|
|
|
|
|
def load_example_document(img, schema, ocr_lang, layout_analysis):
    """Load and process the example that was picked in the examples table.

    Handler for changes of the hidden example image. The wired outputs are,
    in order: document state, schema box, preview gallery, image clear
    button, JSON output.

    Args:
        img: Value of the hidden example image; None when it was cleared.
        schema: Schema string of the picked example; must be a key of
            ``example_files``.
        ocr_lang: Current UI value; replaced by the example's own setting
            from ``lang_map``.
        layout_analysis: Current UI value; replaced by the example's own
            setting from ``lang_map``.

    Returns:
        A 5-tuple for the outputs listed above. When *img* is None everything
        is reset and the clear button is hidden.
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None

    document = example_files[schema]
    # lang_map values look like "<lang>" or "<lang>-<something>"; a second
    # part switches layout analysis on.
    choice = lang_map[document].split("-")
    ocr_lang = choice[0]
    layout_analysis = len(choice) != 1
    preview, answer = process_doc(document, schema, ocr_lang, layout_analysis)
    return document, schema, preview, gr.update(visible=True), answer
|
|
|
|
|
def read_content(file_path: str) -> str:
    """Return the whole text of *file_path*, decoded as UTF-8.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's content as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the content is not valid UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return handle.read()
|
|
|
|
|
CSS = """ |
|
#prompt input { |
|
font-size: 16px; |
|
} |
|
#url-textbox { |
|
padding: 0 !important; |
|
} |
|
#short-upload-box .w-full { |
|
min-height: 10rem !important; |
|
} |
|
/* I think something like this can be used to re-shape |
|
* the table |
|
*/ |
|
/* |
|
.gr-samples-table tr { |
|
display: inline; |
|
} |
|
.gr-samples-table .p-2 { |
|
width: 100px; |
|
} |
|
*/ |
|
#select-a-file { |
|
width: 100%; |
|
} |
|
#file-clear { |
|
padding-top: 2px !important; |
|
padding-bottom: 2px !important; |
|
padding-left: 8px !important; |
|
padding-right: 8px !important; |
|
margin-top: 10px; |
|
} |
|
.gradio-container .gr-button-primary { |
|
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%); |
|
border: 1px solid #B0DCCC; |
|
border-radius: 8px; |
|
color: #1B8700; |
|
} |
|
.gradio-container.dark button#submit-button { |
|
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%); |
|
border: 1px solid #B0DCCC; |
|
border-radius: 8px; |
|
color: #1B8700 |
|
} |
|
table.gr-samples-table tr td { |
|
border: none; |
|
outline: none; |
|
} |
|
table.gr-samples-table tr td:first-of-type { |
|
width: 0%; |
|
} |
|
div#short-upload-box div.absolute { |
|
display: none !important; |
|
} |
|
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div { |
|
gap: 0px 2%; |
|
} |
|
gradio-app div div div div.w-full, .gradio-app div div div div.w-full { |
|
gap: 0px; |
|
} |
|
gradio-app h2, .gradio-app h2 { |
|
padding-top: 10px; |
|
} |
|
#answer { |
|
overflow-y: scroll; |
|
color: white; |
|
background: #666; |
|
border-color: #666; |
|
font-size: 20px; |
|
font-weight: bold; |
|
} |
|
#answer span { |
|
color: white; |
|
} |
|
#answer textarea { |
|
color:white; |
|
background: #777; |
|
border-color: #777; |
|
font-size: 18px; |
|
} |
|
#url-error input { |
|
color: red; |
|
} |
|
""" |
|
|
|
# UI definition: page layout first, then the event wiring.
# NOTE(review): the nesting of the Row/Column containers below was inferred
# from the order of the statements - confirm it against the rendered page.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(read_content("header.html"))
    gr.Markdown(
        "**UIE-X 🧾 🎓** is a universal information extraction engine which supports both document and text inputs. It is powered by BAIDU and released on PaddleNLP. "
        "Our extraction target(schema) can be set in natural language without limitation, and it also supports most extraction tasks. "
        "The model performs well on zero-shot and few-shot settings. Moreover, on PaddleNLP, we provide a comprehensive and easy-to-use fine-tuning customization workflow. "
        "For more details, please visit the [GitHub](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction)"
    )

    # Per-session state: path/URL of the currently loaded document.
    document = gr.Variable()
    # Not referenced by any handler in this file.
    is_text = gr.Variable()
    # Hidden components filled in by gr.Examples; a change of example_image
    # triggers load_example_document().
    example_schema = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        # Left column: choose a document (URL, upload, or example).
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. 选择文件 / Select a file 📄", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
                gr.Markdown("— or —")
                upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_schema],
            )

        # Right column: schema, options, and results.
        with gr.Column():
            gr.Markdown("## 2. 信息抽取 / Information extraction ℹ️ ")
            gr.Markdown("### 👉 设置schema")
            gr.Markdown("实体抽取:实体类别之间以';'分割,例如 **人物;组织机构**")
            gr.Markdown("关系抽取:需配置主体和关系类别,中间以'|'分割,例如 **人物|出生时间;人物|邮箱**")
            gr.Markdown("### 👉 Set a schema")
            gr.Markdown("Entity extraction: entity label should be separated by ';', e.g. **Person;Organization**")
            gr.Markdown("Relation extraction: set the subject and relation type, separated by '|', e.g. **Person|Date;Person|Email**")
            gr.Markdown("### 💪 模型定制 / Model customization")
            gr.Markdown("我们建议通过[数据标注+微调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document)的流程进一步增强模型在特定场景的效果")
            gr.Markdown("We recommend to further improve the extraction performance in specific domain through the process of [data annotation & fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document)")

            schema = gr.Textbox(
                label="Schema",
                placeholder="e.g. Name|Company;Name|Position;Email;Phone Number",
                lines=1,
                max_lines=1,
            )

            ocr_lang = gr.Radio(
                choices=["ch", "en"],
                value="en",
                label="OCR语言 / OCR Language (Please choose ch for Chinese images.)",
            )

            layout_analysis = gr.Radio(
                choices=["yes", "no"],
                value="no",
                label="版面分析 / Layout analysis (Better extraction for multi-line text)",
            )

            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output = gr.JSON(label="Output", visible=False)

    # Both clear buttons reset the same nine components; the lambda returns
    # one value per entry of `outputs`, in the same order.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                schema,
            ],
        )

    # Loading a document, from an upload or from the URL box.
    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, url_error],
    )
    submit.click(
        fn=process_path,
        inputs=[url],
        outputs=[document, image, img_clear_button, output, url_error],
    )

    # Running the extraction: Enter in the schema box or the Submit button.
    schema.submit(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )

    submit_button.click(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )

    # Picking an example loads the matching document and runs it immediately.
    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_schema, ocr_lang, layout_analysis],
        outputs=[document, schema, image, img_clear_button, output],
    )

    gr.Markdown("[![Stargazers repo roster for @PaddlePaddle/PaddleNLP](https://reporoster.com/stars/PaddlePaddle/PaddleNLP)](https://github.com/PaddlePaddle/PaddleNLP)")
    gr.HTML(read_content("footer.html"))
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script; request
    # queueing is explicitly turned off.
    demo.launch(enable_queue=False)
|
|