# Spaces status: Build error
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoint used by every task in this app (HTR correction, summarization,
# question answering). Replace with another Hub repo id or a local path to
# serve a different fine-tuned mT5 model.
model_name = "MarineLives/mT5-small-experiment-13"

# Load the tokenizer first, then the seq2seq weights, once at module import;
# the task functions below all reuse these two module-level objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
def correct_htr(raw_htr_text, max_new_tokens=512):
    """Correct raw handwritten-text-recognition (HTR) output with the mT5 model.

    The text is passed to the model verbatim (no task prefix).

    Args:
        raw_htr_text: Raw HTR transcription to correct.
        max_new_tokens: Upper bound on the number of generated tokens. Without
            an explicit bound, ``generate`` falls back to the generation
            config's ``max_length`` (20 tokens unless the checkpoint overrides
            it), which silently truncates longer corrections.

    Returns:
        The decoded correction with special tokens removed, or ``""`` when the
        input is empty, ``None`` or whitespace-only.
    """
    # An empty textbox should not trigger a model call on an empty prompt.
    if not raw_htr_text or not raw_htr_text.strip():
        return ""
    inputs = tokenizer(raw_htr_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def summarize_text(legal_text, max_new_tokens=256):
    """Summarize a legal text using the T5-style ``"summarize: "`` task prefix.

    Args:
        legal_text: The text to summarize.
        max_new_tokens: Upper bound on the number of generated tokens. Without
            an explicit bound, ``generate`` falls back to the generation
            config's ``max_length`` (20 tokens unless the checkpoint overrides
            it), which cuts summaries short.

    Returns:
        The decoded summary with special tokens removed, or ``""`` when the
        input is empty, ``None`` or whitespace-only.
    """
    # Avoid summarizing an empty prompt (would yield arbitrary model output).
    if not legal_text or not legal_text.strip():
        return ""
    inputs = tokenizer("summarize: " + legal_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def answer_question(legal_text, question, max_new_tokens=128):
    """Answer a question about a legal text.

    The prompt uses the T5-style ``"question: ... context: ..."`` format.

    Args:
        legal_text: The context the answer should be drawn from.
        question: The question to answer.
        max_new_tokens: Upper bound on the number of generated tokens. Without
            an explicit bound, ``generate`` falls back to the generation
            config's ``max_length`` (20 tokens unless the checkpoint overrides
            it), which can truncate answers.

    Returns:
        The decoded answer with special tokens removed, or ``""`` when either
        the context or the question is empty, ``None`` or whitespace-only.
    """
    # Both parts of the prompt are required for a meaningful answer.
    if not legal_text or not legal_text.strip():
        return ""
    if not question or not question.strip():
        return ""
    inputs = tokenizer(
        f"question: {question} context: {legal_text}", return_tensors="pt"
    )
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Create the Gradio interface.
# gr.Interface wraps exactly one callable whose parameters map onto a flat
# list of input components, so each task gets its own Interface and the three
# are shown as tabs. (Passing a list of functions and a nested inputs list to a
# single gr.Interface cannot be constructed and breaks the Space build.)
htr_tab = gr.Interface(
    fn=correct_htr,
    inputs=gr.Textbox(
        lines=5, label="Raw HTR text", placeholder="Enter raw HTR text here..."
    ),
    outputs=gr.Textbox(lines=5, label="Corrected HTR text"),
    description="Correct raw handwritten-text-recognition (HTR) output.",
)
summary_tab = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(
        lines=10, label="Legal text", placeholder="Enter legal text to summarize..."
    ),
    outputs=gr.Textbox(lines=5, label="Summary of legal text"),
    description="Summarize a legal text.",
)
qa_tab = gr.Interface(
    fn=answer_question,
    # Order matches answer_question(legal_text, question).
    inputs=[
        gr.Textbox(lines=10, label="Legal text", placeholder="Enter legal text..."),
        gr.Textbox(lines=2, label="Question", placeholder="Enter your question..."),
    ],
    outputs=gr.Textbox(lines=5, label="Answer to your question"),
    description="Answer questions about legal cases.",
)

iface = gr.TabbedInterface(
    [htr_tab, summary_tab, qa_tab],
    tab_names=["Correct HTR", "Summarize", "Ask a question"],
    title="mT5 Legal Assistant",
)

# Launch only when run as a script (how Spaces starts app.py), not on import.
if __name__ == "__main__":
    iface.launch()