import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer for mT5-small
model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")


# Task-specific generation functions
def correct_htr_text(input_text, max_new_tokens, temperature):
    prompt = (
        "Correct the following handwritten transcription for obvious errors "
        f"while preserving C17th spelling: {input_text}"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,  # sampling must be enabled for the temperature setting to have any effect
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def summarize_legal_text(input_text, max_new_tokens, temperature):
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def answer_legal_question(input_text, question, max_new_tokens, temperature):
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Gradio interface wrappers
def correct_htr_interface(text, max_new_tokens, temperature):
    return correct_htr_text(text, max_new_tokens, temperature)


def summarize_interface(text, max_new_tokens, temperature):
    return summarize_legal_text(text, max_new_tokens, temperature)


def question_interface(text, question, max_new_tokens, temperature):
    return answer_legal_question(text, question, max_new_tokens, temperature)


def clear_all():
    # Clears a two-field tab (input and output).
    return "", ""


def clear_question_tab():
    # The question tab has three fields, so it needs three empty values;
    # reusing clear_all (two values) there would raise a Gradio error.
    return "", "", ""


# External clickable buttons
def clickable_buttons():
    # The original link targets were not preserved in the source; the hrefs
    # below are placeholders and should be replaced with the real URLs.
    button_html = """
    <a href="#" target="_blank">Admiralty Court Legal Glossary</a>
    <a href="#" target="_blank">HCA 13/70 Ground Truth</a>
    """
    return button_html

# Interface layout
with gr.Blocks() as demo:
    gr.HTML("<h1>Flan-T5 Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Raw HTR Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        correct_button.click(
            fn=correct_htr_interface,
            inputs=[input_text, max_new_tokens, temperature],
            outputs=output_text,
        )
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Legal Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")
        summarize_button.click(
            fn=summarize_interface,
            inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
            outputs=output_text_summarize,
        )
        clear_button_summarize.click(
            fn=clear_all,
            outputs=[input_text_summarize, output_text_summarize],
        )

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Legal Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")
        question_button.click(
            fn=question_interface,
            inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
            outputs=output_text_question,
        )
        clear_button_question.click(
            fn=clear_question_tab,
            outputs=[input_text_question, question, output_text_question],
        )

    # Global clear button: resets every textbox across all three tabs.
    gr.Button("Clear", elem_id="clear_button").click(
        fn=lambda: ("",) * 7,
        outputs=[
            input_text, output_text,
            input_text_summarize, output_text_summarize,
            input_text_question, question, output_text_question,
        ],
    )

demo.launch()