# Captured from a Hugging Face Space; the Space's status at capture time
# was "Build error".
import gradio as gr | |
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration | |
# Load the seq2seq checkpoint and its tokenizer once, at import time; the
# generation helpers in this file all share these module-level objects.
# NOTE(review): the UI heading advertises "Flan-T5", but this checkpoint is
# google/mt5-small, which (as far as I know) is not instruction-tuned, so
# instruction-style prompts may yield poor output — confirm which model is
# intended.
MODEL_NAME = "google/mt5-small"
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
# Define task-specific prompts | |
def correct_htr_text(input_text, max_new_tokens, temperature): | |
prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}" | |
inputs = tokenizer(prompt, return_tensors="pt") | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature | |
) | |
return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Ask the model to summarize a passage of legal text.

    Args:
        input_text: The legal text to summarize.
        max_new_tokens: Upper bound on the number of generated tokens; coerced
            to ``int`` because UI sliders may deliver a float.
        temperature: Sampling temperature. It only has an effect because
            sampling is enabled (``do_sample=True``); greedy decoding ignores it.

    Returns:
        The decoded model output with special tokens removed, or ``""`` when
        the input is empty or whitespace-only.
    """
    if not (input_text or "").strip():
        # Nothing to summarize; skip an expensive, meaningless generate call.
        return ""
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,  # forwards attention_mask as well as input_ids
        max_new_tokens=int(max_new_tokens),
        do_sample=True,  # without this, `temperature` is silently ignored
        temperature=float(temperature),
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Ask the model to answer ``question`` using ``input_text`` as context.

    Args:
        input_text: The legal text the answer should be based on.
        question: The question to answer.
        max_new_tokens: Upper bound on the number of generated tokens; coerced
            to ``int`` because UI sliders may deliver a float.
        temperature: Sampling temperature. It only has an effect because
            sampling is enabled (``do_sample=True``); greedy decoding ignores it.

    Returns:
        The decoded model output with special tokens removed, or ``""`` when
        either the text or the question is empty or whitespace-only.
    """
    if not (input_text or "").strip() or not (question or "").strip():
        # A question needs both a context and a question to be meaningful.
        return ""
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,  # forwards attention_mask as well as input_ids
        max_new_tokens=int(max_new_tokens),
        do_sample=True,  # without this, `temperature` is silently ignored
        temperature=float(temperature),
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio event handlers: thin adapters from UI component values to the
# task-specific helpers.
def correct_htr_interface(text, max_new_tokens, temperature):
    """Forward the HTR tab's values to ``correct_htr_text`` and return its result."""
    return correct_htr_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def summarize_interface(text, max_new_tokens, temperature):
    """Forward the summarize tab's values to ``summarize_legal_text`` and return its result."""
    return summarize_legal_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def question_interface(text, question, max_new_tokens, temperature):
    """Forward the question tab's values to ``answer_legal_question`` and return its result."""
    return answer_legal_question(
        input_text=text,
        question=question,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def clear_all(count: int = 2) -> tuple:
    """Return ``count`` empty strings, used to reset text components.

    Gradio assigns the elements of the returned tuple to the event's output
    components in order, so ``count`` must equal the number of outputs wired
    to the event. The default of 2 matches the (input, output) pairs that
    this is wired to; pass a larger count for handlers with more outputs.
    """
    return ("",) * count
# External reference links rendered as button-like anchors.
def clickable_buttons():
    """Return HTML for two side-by-side links to external reference material.

    The links open in a new browser tab (``target="_blank"``) so that following
    them does not navigate away from the app or try to load the external site
    inside the app's frame. ``rel="noopener noreferrer"`` keeps the opened page
    from getting a handle on this one.
    """
    button_html = """
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
<a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
target="_blank" rel="noopener noreferrer"
style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
Admiralty Court Legal Glossary</a>
<a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
target="_blank" rel="noopener noreferrer"
style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
HCA 13/70 Ground Truth</a>
</div>
"""
    return button_html
# Interface layout: one tab per task, each with its own inputs, sliders,
# run button and clear button, plus a global "Clear All" button.
with gr.Blocks() as demo:
    # NOTE(review): this heading says Flan-T5, but the checkpoint loaded at
    # the top of the file is google/mt5-small — confirm which is intended.
    gr.HTML("<h1>Flan-T5 Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Raw HTR Text")
        output_text = gr.Textbox(label="Corrected Text")
        # step=1 keeps the token budget an integer
        max_new_tokens = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        correct_button.click(
            fn=correct_htr_interface,
            inputs=[input_text, max_new_tokens, temperature],
            outputs=output_text,
        )
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Legal Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, step=1, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")
        summarize_button.click(
            fn=summarize_interface,
            inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
            outputs=output_text_summarize,
        )
        clear_button_summarize.click(
            fn=clear_all, outputs=[input_text_summarize, output_text_summarize]
        )

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Legal Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")
        question_button.click(
            fn=question_interface,
            inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
            outputs=output_text_question,
        )
        # Three outputs need three values; the original handler returned
        # only two, which made this event fail.
        clear_button_question.click(
            fn=lambda: ("", "", ""),
            outputs=[input_text_question, question, output_text_question],
        )

    # Global reset: previously wired with no outputs, so it cleared nothing.
    all_text_boxes = [
        input_text,
        output_text,
        input_text_summarize,
        output_text_summarize,
        input_text_question,
        question,
        output_text_question,
    ]
    empty_values = ("",) * len(all_text_boxes)
    gr.Button("Clear All", elem_id="clear_button").click(
        fn=lambda: empty_values,
        outputs=all_text_boxes,
    )

demo.launch()