ankur-bohra's picture
Add basic structure
0d99179
raw
history blame
13.2 kB
import base64
import os
from io import BytesIO
from pathlib import Path
from langchain.schema.output_parser import OutputParserException
import gradio as gr
from PIL import Image
import categories
from categories import Category
from main import process_image, process_pdf
# Hugging Face access token read from the environment; None when the
# variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# HTML wrapper used to preview a PDF; "{0}" receives the base64-encoded file.
PDF_IFRAME = """
<div style="border-radius: 10px; width: 100%; overflow: hidden;">
<iframe
src="data:application/pdf;base64,{0}"
width="100%"
height="400"
type="application/pdf">
</iframe>
</div>"""

# Dataset saver used for results that are shared automatically.
hf_writer_normal = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="automatic-reimbursement-tool-demo",
    separate_dirs=False,
)

# Dataset saver used for results flagged manually via the flag buttons.
hf_writer_incorrect = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="automatic-reimbursement-tool-demo-incorrect",
    separate_dirs=False,
)
# with open("examples/example1.pdf", "rb") as pdf_file:
# base64_pdf = base64.b64encode(pdf_file.read())
# example_paths = []
# current_file_path = None
# def ignore_examples(function):
# def new_function(*args, **kwargs):
# global example_paths, current_file_path
# if current_file_path not in example_paths:
# return function(*args, **kwargs)
def display_file(input_file):
    """Preview the uploaded file in the matching preview component.

    A PDF is embedded as a base64 data URI inside ``PDF_IFRAME``; any other
    file is handed to the image preview by path. The path of the current file
    (or ``None``) is also stored in the module-level ``current_file_path``.

    Args:
        input_file: The uploaded file object (its ``name`` attribute is used
            as the path on disk), or a falsy value when the input is empty.

    Returns:
        A pair ``(html_update, image_update)`` for the PDF preview and the
        image preview; at most one of them is made visible.
    """
    global current_file_path
    current_file_path = input_file.name if input_file else None

    if not input_file:
        return gr.HTML.update(visible=False), gr.Image.update(visible=False)

    # Compare case-insensitively so that e.g. "RECEIPT.PDF" is not mistaken
    # for an image.
    if input_file.name.lower().endswith(".pdf"):
        # Use a separate name for the handle instead of rebinding the
        # ``input_file`` parameter.
        with open(input_file.name, "rb") as pdf_handle:
            pdf_base64 = base64.b64encode(pdf_handle.read()).decode()
        return (
            gr.HTML.update(PDF_IFRAME.format(pdf_base64), visible=True),
            gr.Image.update(visible=False),
        )

    return (
        gr.HTML.update(visible=False),
        gr.Image.update(input_file.name, visible=True),
    )
def show_intermediate_outputs(show_intermediate):
    """Make the intermediate-outputs accordion follow the checkbox state.

    Args:
        show_intermediate: Truthy to show the accordion, falsy to hide it.

    Returns:
        An accordion update carrying only the new visibility.
    """
    accordion_visible = bool(show_intermediate)
    return gr.Accordion.update(visible=accordion_visible)
def show_share_contact(share_result):
    """Return a textbox update whose visibility mirrors ``share_result``."""
    contact_update = gr.Textbox.update(visible=share_result)
    return contact_update
def clear_inputs():
    """Return a file-input update that resets its value to ``None``."""
    cleared_file = gr.File.update(value=None)
    return cleared_file
def submit(input_file, old_text):
    """Extract the raw text from the uploaded receipt.

    Args:
        input_file: The uploaded file object (its ``name`` attribute is used
            as the path on disk), or a falsy value when nothing is uploaded.
        old_text: Current content of the extracted-text box. Unused; kept
            because the click handler passes it as an input.

    Returns:
        The result of ``process_pdf`` (for PDFs) or ``process_image`` (for
        everything else), each called with ``extract_only=True``.

    Raises:
        gr.Error: If no file has been uploaded.
    """
    if not input_file:
        # gr.Error is an exception class: merely constructing it shows
        # nothing, it has to be raised for the message to reach the user.
        raise gr.Error("Please upload a file to continue!")

    receipt_path = Path(input_file.name)
    # Compare case-insensitively so that e.g. "RECEIPT.PDF" is routed to the
    # PDF pipeline rather than the image pipeline.
    if input_file.name.lower().endswith(".pdf"):
        return process_pdf(receipt_path, extract_only=True)
    return process_image(receipt_path, extract_only=True)
def categorize_extracted_text(extracted_text):
    """Return the category ``categories.categorize_text`` assigns to the text."""
    return categories.categorize_text(extracted_text)
def extract_from_category(category, extracted_text):
    """Run the category-specific extraction chain on the extracted text.

    Args:
        category: Name of a ``Category`` member, or a falsy value when no
            category is selected.
        extracted_text: Text passed to the chain as its ``text`` input.

    Returns:
        A 4-tuple of updates: chatbot, JSON view, and the two flag buttons.
        With a falsy ``category`` the chatbot and JSON view are cleared and
        both buttons are disabled. Otherwise the chatbot shows the formatted
        prompt and the raw model answer, the JSON view shows the parsed
        information (or an error description when parsing fails), and both
        buttons are enabled.
    """
    if not category:
        return (
            gr.Chatbot.update(None),
            gr.JSON.update(None),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
        )

    chain = categories.category_modules[Category[category]].chain
    parser = chain.output_parser
    # Fetch the format instructions once; the same value feeds both the
    # prompt shown in the chatbot and the actual chain call.
    format_instructions = parser.get_format_instructions()

    formatted_prompt = chain.prompt.format_prompt(
        text=extracted_text,
        format_instructions=format_instructions,
    )
    result = chain.generate(
        input_list=[
            {
                "text": extracted_text,
                "format_instructions": format_instructions,
            }
        ]
    )

    messages = formatted_prompt.messages
    question = ""
    # NOTE(review): messages[1] is labelled "System" and messages[0] "Human".
    # Chat prompts usually put the system message first, so these indices may
    # be swapped -- confirm against the prompt templates of the category
    # modules before changing.
    if len(messages) > 1:
        question += f"**System:**\n{messages[1].content}"
    question += f"\n\n**Human:**\n{messages[0].content}"

    answer = result.generations[0][0].text
    try:
        information = parser.parse_with_prompt(answer, formatted_prompt)
        information = information.json() if information else {}
    except OutputParserException as e:
        # Surface the parsing failure in the JSON view instead of crashing.
        information = {
            "error": "Unable to parse chatbot output",
            "details": str(e),
            "output": e.llm_output,
        }

    return (
        gr.Chatbot.update([[question, answer]]),
        gr.JSON.update(information),
        gr.Button.update(interactive=True),
        gr.Button.update(interactive=True),
    )
def dynamic_auto_flag(flag_method):
    """Wrap ``flag_method`` so that it only runs when sharing is enabled.

    Args:
        flag_method: Callable invoked with the remaining arguments.

    Returns:
        A callable taking ``share_result`` followed by the arguments for
        ``flag_method``. It always returns ``None``; the result of
        ``flag_method`` is discarded.
    """

    # NOTE(review): the wrapper forwards its arguments positionally; confirm
    # that ``flag_method`` does not rely on Gradio injecting extra arguments
    # based on its own signature.
    def modified_flag_method(share_result, *args, **kwargs):
        if not share_result:
            return None
        flag_method(*args, **kwargs)
        return None

    return modified_flag_method
# def save_example_and_submit(input_file):
# example_paths.append(input_file.name)
# submit(input_file, "")
# Build the demo page: a static description, the input and output panels,
# and the event wiring that connects them. The page is launched at the end.
with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
gr.Markdown("<center><h1>Automatic Reimbursement Tool Demo</h1></center>")
gr.Markdown("<h2>Description</h2>")
gr.Markdown(
"The reimbursement filing process can be time-consuming and cumbersome, causing "
"frustration for faculty members and finance departments. Our project aims to "
"automate the information extraction involved in the process by feeding "
"extracted text to language models such as ChatGPT. This demo showcases the "
"categorization and extraction parts of the pipeline. Categorization is done "
"to identify the relevant details associated with the text, after which "
"extraction is done for those details using a language model."
)
gr.Markdown("<h2>Try it out!</h2>")
with gr.Box() as demo:
with gr.Row():
# ---- Input panel: file upload, previews and options ----
with gr.Column(variant="panel"):
gr.HTML(
'<div><center style="color:rgb(200, 200, 200);">Input</center></div>'
)
# Both previews start hidden; display_file reveals the one that matches
# the uploaded file.
pdf_preview = gr.HTML(label="Preview", show_label=True, visible=False)
image_preview = gr.Image(
label="Preview", show_label=True, visible=False, height=350
)
input_file = gr.File(
label="Input receipt",
show_label=True,
type="file",
file_count="single",
file_types=["image", ".pdf"],
)
# Refresh the previews whenever the uploaded file changes.
input_file.change(
display_file, input_file, [pdf_preview, image_preview]
)
with gr.Row():
clear = gr.Button("Clear", variant="secondary")
submit_button = gr.Button("Submit", variant="primary")
show_intermediate = gr.Checkbox(
False,
label="Show intermediate outputs",
info="There are several intermediate steps in the process such as preprocessing, OCR, chatbot interaction. You can choose to show their results here.",
)
share_result = gr.Checkbox(
True,
label="Share results",
info="Sharing your result with us will help us immensely in improving this tool.",
interactive=True,
)
contact = gr.Textbox(
type="email",
label="Contact",
interactive=True,
placeholder="Enter your email address",
info="Optionally, enter your email address to allow us to contact you regarding your result.",
visible=True,
)
# The contact box is only shown while "Share results" is ticked.
share_result.change(show_share_contact, share_result, [contact])
# ---- Output panel: category, intermediate outputs, extracted info ----
with gr.Column(variant="panel"):
gr.HTML(
'<div><center style="color:rgb(200, 200, 200);">Output</center></div>'
)
# Read-only dropdown; its value is set by categorize_extracted_text.
category = gr.Dropdown(
value=None,
choices=Category.__members__.keys(),
label=f"Recognized category ({', '.join(Category.__members__.keys())})",
show_label=True,
interactive=False,
)
# Hidden by default; toggled by the "Show intermediate outputs" checkbox.
intermediate_outputs = gr.Accordion(
"Intermediate outputs", open=True, visible=False
)
with intermediate_outputs:
extracted_text = gr.Textbox(
label="Extracted text",
show_label=True,
max_lines=5,
show_copy_button=True,
lines=5,
interactive=False,
)
chatbot = gr.Chatbot(
None,
label="Chatbot interaction",
show_label=True,
interactive=False,
height=240,
)
information = gr.JSON(label="Extracted information")
with gr.Row():
flag_incorrect_button = gr.Button(
"Flag as incorrect", variant="stop", interactive=True
)
flag_irrelevant_button = gr.Button(
"Flag as irrelevant", variant="stop", interactive=True
)
# ---- Event wiring ----
show_intermediate.change(
show_intermediate_outputs, show_intermediate, [intermediate_outputs]
)
clear.click(clear_inputs, None, [input_file])
# Submit has two handlers: the first extracts text from the file, the
# second resets the category, chatbot and information outputs when a file
# is present (and passes the current values through otherwise).
submit_button.click(
submit,
[input_file, extracted_text],
[extracted_text],
)
# NOTE(review): the third update below is built with gr.Textbox.update but
# targets `information`, which is a gr.JSON component -- confirm intended.
submit_button.click(
lambda input_file, category, chatbot, information: (
gr.Dropdown.update(None),
gr.Chatbot.update(None),
gr.Textbox.update(None),
) if input_file else (category, chatbot, information),
[input_file, category, chatbot, information],
[category, chatbot, information],
)
# Pipeline chaining: new extracted text triggers categorization, and a new
# category triggers the extraction chain.
extracted_text.change(
categorize_extracted_text,
[extracted_text],
[category],
)
category.change(
extract_from_category,
[category, extracted_text],
[chatbot, information, flag_incorrect_button, flag_irrelevant_button],
)
# Automatic flagging: every change of `information` is saved through
# hf_writer_normal, but only while "Share results" is ticked (see
# dynamic_auto_flag).
hf_writer_normal.setup(
[input_file, extracted_text, category, chatbot, information, contact],
flagging_dir="flagged",
)
flag_method = gr.flagging.FlagMethod(
hf_writer_normal, "", "", visual_feedback=True
)
information.change(
dynamic_auto_flag(flag_method),
inputs=[
share_result,
input_file,
extracted_text,
category,
chatbot,
information,
contact,
],
outputs=None,
preprocess=False,
queue=False,
)
# Manual flagging: both flag buttons save through hf_writer_incorrect and
# differ only in the strings passed to FlagMethod (presumably the button
# label and the recorded flag value -- confirm against the Gradio API).
hf_writer_incorrect.setup(
[input_file, extracted_text, category, chatbot, information, contact],
flagging_dir="flagged_incorrect",
)
flag_incorrect_method = gr.flagging.FlagMethod(
hf_writer_incorrect,
"Flag as incorrect",
"Incorrect",
visual_feedback=True,
)
# Each flag button gets two click handlers: one that switches the button
# to a disabled "Saving..." state, and one that performs the flagging and
# writes its result back to the button.
flag_incorrect_button.click(
lambda: gr.Button.update(value="Saving...", interactive=False),
None,
flag_incorrect_button,
queue=False,
)
flag_incorrect_button.click(
flag_incorrect_method,
inputs=[
input_file,
extracted_text,
category,
chatbot,
information,
contact,
],
outputs=[flag_incorrect_button],
preprocess=False,
queue=False,
)
flag_irrelevant_method = gr.flagging.FlagMethod(
hf_writer_incorrect,
"Flag as irrelevant",
"Irrelevant",
visual_feedback=True,
)
flag_irrelevant_button.click(
lambda: gr.Button.update(value="Saving...", interactive=False),
None,
flag_irrelevant_button,
queue=False,
)
flag_irrelevant_button.click(
flag_irrelevant_method,
inputs=[
input_file,
extracted_text,
category,
chatbot,
information,
contact,
],
outputs=[flag_irrelevant_button],
preprocess=False,
queue=False,
)
# Start the app with the API page, error display and debug mode enabled.
page.launch(show_api=True, show_error=True, debug=True)