Addaci's picture
Update app.py (#6)
74f4d51 verified
import gradio as gr
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
# Load the mT5-small checkpoint and its tokenizer once at import time, so that
# every Gradio request below reuses the same in-memory objects.
# NOTE(review): the page header says "Flan-T5 Legal Assistant", but this loads
# google/mt5-small — confirm which checkpoint is actually intended.
MODEL_CHECKPOINT = "google/mt5-small"
model = T5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(MODEL_CHECKPOINT)
# Define task-specific prompts


def _generate(prompt, max_new_tokens, temperature):
    """Run the module-level ``model`` on *prompt* and return the decoded text.

    Args:
        prompt: Full instruction plus input text fed to the encoder.
        max_new_tokens: Upper bound on generated tokens; cast to ``int`` in case
            the UI delivers slider values as floats.
        temperature: Sampling temperature. A value <= 0 falls back to greedy
            decoding instead of letting ``generate`` reject it.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    temperature = float(temperature)
    sampling = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        # Without do_sample=True, generate() decodes greedily and silently
        # ignores ``temperature``, which made the Temperature slider a no-op.
        "do_sample": sampling,
    }
    if sampling:
        gen_kwargs["temperature"] = temperature
    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs.attention_mask,
        **gen_kwargs,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def correct_htr_text(input_text, max_new_tokens, temperature):
    """Ask the model to correct a raw HTR transcription, keeping C17th spelling."""
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    return _generate(prompt, max_new_tokens, temperature)


def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Ask the model to summarize *input_text*."""
    prompt = f"Summarize this legal text: {input_text}"
    return _generate(prompt, max_new_tokens, temperature)


def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Ask the model to answer *question* using *input_text* as context."""
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    return _generate(prompt, max_new_tokens, temperature)
# Gradio callbacks: thin adapters that forward UI component values to the
# task functions, one adapter per tab.


def correct_htr_interface(text, max_new_tokens, temperature):
    """Forward the 'Correct Raw HTR' tab inputs to correct_htr_text."""
    return correct_htr_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )


def summarize_interface(text, max_new_tokens, temperature):
    """Forward the 'Summarize Legal Text' tab inputs to summarize_legal_text."""
    return summarize_legal_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )


def question_interface(text, question, max_new_tokens, temperature):
    """Forward the 'Answer Legal Question' tab inputs to answer_legal_question."""
    return answer_legal_question(
        input_text=text,
        question=question,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def clear_all():
    """Return one empty string for each of the two textboxes a Clear button resets."""
    blank = ""
    return (blank, blank)
# External clickable buttons
def clickable_buttons():
    """Return an HTML snippet with two side-by-side links to reference material.

    The links open in a new tab (``target="_blank"``). Without it, clicking a
    link navigates the frame the app is rendered in, and sites that forbid
    framing (e.g. GitHub) show a refused page instead. ``rel="noopener
    noreferrer"`` stops the opened page from getting a handle back to this one.
    """
    button_html = """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
           target="_blank" rel="noopener noreferrer"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
           target="_blank" rel="noopener noreferrer"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           HCA 13/70 Ground Truth</a>
    </div>
    """
    return button_html
# Interface layout
with gr.Blocks() as demo:
    # NOTE(review): this title says Flan-T5, but the model loaded at the top of
    # the file is google/mt5-small — confirm which is intended.
    gr.HTML("<h1>Flan-T5 Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Raw HTR Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        correct_button.click(fn=correct_htr_interface,
                             inputs=[input_text, max_new_tokens, temperature],
                             outputs=output_text)
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Legal Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")
        summarize_button.click(fn=summarize_interface,
                               inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
                               outputs=output_text_summarize)
        clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Legal Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")
        question_button.click(fn=question_interface,
                              inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
                              outputs=output_text_question)
        # This tab has three textboxes to reset, but clear_all returns only two
        # values; Gradio errors when a handler returns fewer values than outputs.
        clear_button_question.click(
            fn=lambda: ("", "", ""),
            outputs=[input_text_question, question, output_text_question],
        )

    # Global clear. It was previously wired with no outputs, so clicking it
    # changed nothing; now it resets every textbox on every tab.
    all_textboxes = [
        input_text, output_text,
        input_text_summarize, output_text_summarize,
        input_text_question, question, output_text_question,
    ]
    gr.Button("Clear All", elem_id="clear_button").click(
        fn=lambda: ("",) * len(all_textboxes),
        outputs=all_textboxes,
    )

demo.launch()