import gradio as gr from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration # Load model and tokenizer for mT5-small model = T5ForConditionalGeneration.from_pretrained("google/mt5-small") tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") # Define task-specific prompts def correct_htr_text(input_text, max_new_tokens, temperature): prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}" inputs = tokenizer(prompt, return_tensors="pt") outputs = model.generate( inputs.input_ids, max_new_tokens=max_new_tokens, temperature=temperature ) return tokenizer.decode(outputs[0], skip_special_tokens=True) def summarize_legal_text(input_text, max_new_tokens, temperature): prompt = f"Summarize this legal text: {input_text}" inputs = tokenizer(prompt, return_tensors="pt") outputs = model.generate( inputs.input_ids, max_new_tokens=max_new_tokens, temperature=temperature ) return tokenizer.decode(outputs[0], skip_special_tokens=True) def answer_legal_question(input_text, question, max_new_tokens, temperature): prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}" inputs = tokenizer(prompt, return_tensors="pt") outputs = model.generate( inputs.input_ids, max_new_tokens=max_new_tokens, temperature=temperature ) return tokenizer.decode(outputs[0], skip_special_tokens=True) # Define Gradio interface functions def correct_htr_interface(text, max_new_tokens, temperature): return correct_htr_text(text, max_new_tokens, temperature) def summarize_interface(text, max_new_tokens, temperature): return summarize_legal_text(text, max_new_tokens, temperature) def question_interface(text, question, max_new_tokens, temperature): return answer_legal_question(text, question, max_new_tokens, temperature) def clear_all(): return "", "" # External clickable buttons def clickable_buttons(): button_html = """
""" return button_html # Interface layout with gr.Blocks() as demo: gr.HTML("