# Page header captured along with this file:
#   author: Addaci
#   commit: d885415 (verified) - "Updated app.py to include transformers"
#   size:   1.97 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the fine-tuned mT5 checkpoint once, at import time.
# `tokenizer` and `model` are module-level globals shared by correct_htr,
# summarize_text and answer_question below.
model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
# Tokenizer and seq2seq model are both resolved from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def correct_htr(raw_htr_text):
    """Correct raw HTR text with the fine-tuned seq2seq model.

    Args:
        raw_htr_text: The raw HTR text to correct.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when the input is empty, whitespace-only or ``None``.
    """
    # Nothing to correct: skip the model call rather than decoding whatever
    # the model generates from an effectively empty sequence.
    if not raw_htr_text or not raw_htr_text.strip():
        return ""

    # Tokenize the input text.
    inputs = tokenizer(raw_htr_text, return_tensors="pt")

    # Without an explicit limit, generate() falls back to the generation
    # config's default length (20 tokens unless the checkpoint overrides it),
    # which cuts the corrected text short.
    outputs = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def summarize_text(legal_text):
    """Summarize a legal text with the fine-tuned seq2seq model.

    The text is prefixed with the ``"summarize: "`` task prompt before being
    passed to the model.

    Args:
        legal_text: The text to summarize.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when the input is empty, whitespace-only or ``None``.
    """
    # Nothing to summarize: avoid running the model on the bare prompt.
    if not legal_text or not legal_text.strip():
        return ""

    # Tokenize the input text with the summarization prompt.
    inputs = tokenizer("summarize: " + legal_text, return_tensors="pt")

    # Set an explicit output budget; the library default (20 tokens unless
    # the checkpoint overrides it) truncates most summaries mid-sentence.
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def answer_question(legal_text, question):
    """Answer a question about a legal text with the fine-tuned seq2seq model.

    The model input is built as ``"question: <question> context: <legal_text>"``.

    Args:
        legal_text: The context the answer should be drawn from.
        question: The question to answer.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when either the context or the question is empty,
        whitespace-only or ``None``.
    """
    # Both parts are required; without them the prompt is only scaffolding.
    if not legal_text or not legal_text.strip():
        return ""
    if not question or not question.strip():
        return ""

    # Combine context and question into a single prompt.
    prompt = f"question: {question} context: {legal_text}"
    inputs = tokenizer(prompt, return_tensors="pt")

    # Set an explicit output budget; the library default (20 tokens unless
    # the checkpoint overrides it) can cut answers short.
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Create the Gradio interface.
# gr.Interface wraps a single function with a flat list of input components,
# so it cannot host three functions with different signatures. Each tool gets
# its own tab inside a Blocks app instead, with its own inputs, output and
# trigger button.
with gr.Blocks(title="mT5 Legal Assistant") as iface:
    gr.Markdown("# mT5 Legal Assistant")
    gr.Markdown(
        "Use this tool to correct raw HTR, summarize legal texts, "
        "or answer questions about legal cases."
    )

    with gr.Tab("Correct HTR"):
        htr_input = gr.Textbox(
            lines=5,
            label="Raw HTR text",
            placeholder="Enter raw HTR text here...",
        )
        htr_output = gr.Textbox(
            lines=5,
            label="Corrected text",
            placeholder="Corrected HTR text",
        )
        gr.Button("Correct").click(
            fn=correct_htr, inputs=htr_input, outputs=htr_output
        )

    with gr.Tab("Summarize"):
        summary_input = gr.Textbox(
            lines=10,
            label="Legal text",
            placeholder="Enter legal text to summarize...",
        )
        summary_output = gr.Textbox(
            lines=5,
            label="Summary",
            placeholder="Summary of legal text",
        )
        gr.Button("Summarize").click(
            fn=summarize_text, inputs=summary_input, outputs=summary_output
        )

    with gr.Tab("Answer question"):
        qa_context = gr.Textbox(
            lines=10,
            label="Legal text",
            placeholder="Enter legal text...",
        )
        qa_question = gr.Textbox(
            lines=2,
            label="Question",
            placeholder="Enter your question...",
        )
        qa_output = gr.Textbox(
            lines=5,
            label="Answer",
            placeholder="Answer to your question",
        )
        # Input order matches answer_question(legal_text, question).
        gr.Button("Answer").click(
            fn=answer_question,
            inputs=[qa_context, qa_question],
            outputs=qa_output,
        )

iface.launch()