Spaces: Build error
import gradio as gr
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
# Summarize Legal Text function
def summarize_legal_text(input_text, max_new_tokens, temperature):
    # Flan-T5 is instruction-tuned, so prepend an explicit task prompt;
    # truncation=True keeps long legal texts within the model's input limit
    prompt = f"Summarize the following legal text:\n{input_text}"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    # do_sample=True is required for the temperature setting to have any effect
    summary_ids = model.generate(
        input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Correct HTR function
def correct_htr_text(input_text, max_new_tokens, temperature):
    prompt = f"Correct the errors in the following transcribed text:\n{input_text}"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    output_ids = model.generate(
        input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Answer Legal Question function
def answer_legal_question(context, question, max_new_tokens, temperature):
    input_text = f"Answer the following question based on the context: {question}\nContext: {context}"
    input_ids = tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
    output_ids = model.generate(
        input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Gradio Interface
with gr.Blocks() as demo:
    with gr.Tab("Summarize Legal Text"):
        summarize_input = gr.Textbox(label="Input Text", placeholder="Enter legal text here...", lines=10)
        summarize_output = gr.Textbox(label="Summarized Text", lines=10)
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, step=1, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1, value=0.5, step=0.1, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        summarize_button.click(
            summarize_legal_text,
            inputs=[summarize_input, max_new_tokens_summarize, temperature_summarize],
            outputs=summarize_output,
        )

    with gr.Tab("Correct Raw HTR Text"):
        htr_input = gr.Textbox(label="Input HTR Text", placeholder="Enter HTR text here...", lines=5)
        htr_output = gr.Textbox(label="Corrected HTR Text", lines=5)
        max_new_tokens_htr = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature_htr = gr.Slider(0.1, 1, value=0.7, step=0.1, label="Temperature")
        htr_button = gr.Button("Correct HTR")
        htr_button.click(
            correct_htr_text,
            inputs=[htr_input, max_new_tokens_htr, temperature_htr],
            outputs=htr_output,
        )

    with gr.Tab("Answer Legal Question"):
        question_input_context = gr.Textbox(label="Context Text", placeholder="Enter legal context...", lines=10)
        question_input = gr.Textbox(label="Enter your question", placeholder="Enter your question here...", lines=2)
        question_output = gr.Textbox(label="Answer", lines=5)
        max_new_tokens_question = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1, value=0.7, step=0.1, label="Temperature")
        question_button = gr.Button("Get Answer")
        question_button.click(
            answer_legal_question,
            inputs=[question_input_context, question_input, max_new_tokens_question, temperature_question],
            outputs=question_output,
        )

demo.launch()
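Since the post does not include the build log, the most common cause for a Space failing like this is a missing or incomplete requirements.txt. A minimal dependency list for the code above would look like the sketch below; torch is needed because the tokenizer returns "pt" tensors, and sentencepiece is required by the slow T5Tokenizer. The exact package set (and whether you want to pin versions) is an assumption based only on the imports shown here, so compare it against the error in your build logs.

gradio
transformers
torch
sentencepiece

If the log instead shows an error mentioning SentencePiece when loading T5Tokenizer, adding the sentencepiece line alone is usually enough.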