Addaci's picture
Further debugging of app.py
edba1cc verified
raw
history blame
3.51 kB
import gradio as gr
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
# Load the FLAN-T5 (small) checkpoint once, at import time. Every request
# handler below shares this single model/tokenizer pair.
_CHECKPOINT = "google/flan-t5-small"
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
def _generate(prompt, max_new_tokens, temperature):
    """Run the shared FLAN-T5 model on *prompt* and return the decoded text.

    Args:
        prompt: Complete instruction text fed to the encoder.
        max_new_tokens: Upper bound on the number of generated tokens. UI
            sliders may deliver this as a float, so it is coerced to ``int``.
        temperature: Sampling temperature. A value > 0 enables sampling
            (``do_sample=True``). Under the default greedy decoding,
            ``generate()`` ignores ``temperature``, which previously made the
            UI's temperature slider a no-op. A value <= 0 uses greedy decoding.

    Returns:
        The first generated sequence, with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    temperature = float(temperature)
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    # Unpacking `inputs` forwards attention_mask along with input_ids.
    output_ids = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Summarize *input_text* with FLAN-T5.

    FLAN-T5 is instruction-tuned, so the text is wrapped in an explicit
    summarization instruction. Passing only the raw text gave the model no
    task to perform. Blank input returns "" without invoking the model.
    """
    if not (input_text or "").strip():
        return ""
    prompt = f"Summarize the following legal text:\n\n{input_text}"
    return _generate(prompt, max_new_tokens, temperature)


def correct_htr_text(input_text, max_new_tokens, temperature):
    """Ask FLAN-T5 to correct raw handwritten-text-recognition (HTR) output.

    As with summarization, the model receives an explicit correction
    instruction rather than the bare text. Blank input returns "" without
    invoking the model.
    """
    if not (input_text or "").strip():
        return ""
    prompt = (
        "Correct the spelling and transcription errors in the following "
        f"text:\n\n{input_text}"
    )
    return _generate(prompt, max_new_tokens, temperature)


def answer_legal_question(context, question, max_new_tokens, temperature):
    """Answer *question* using *context* as the supporting passage.

    A blank question returns "" without invoking the model.
    """
    if not (question or "").strip():
        return ""
    input_text = f"Answer the following question based on the context: {question}\nContext: {context}"
    return _generate(input_text, max_new_tokens, temperature)
# Gradio UI: three tabs. Each tab wires its textboxes and generation sliders
# (max new tokens, temperature) to one of the handler functions defined above.
# Components render in the order they are created, so that order is kept.
with gr.Blocks() as demo:
    # --- Tab 1: summarization -------------------------------------------
    with gr.Tab("Summarize Legal Text"):
        summarize_input = gr.Textbox(
            lines=10,
            label="Input Text",
            placeholder="Enter legal text here...",
        )
        summarize_output = gr.Textbox(lines=10, label="Summarized Text")
        max_new_tokens_summarize = gr.Slider(
            minimum=10, maximum=512, step=1, value=256, label="Max New Tokens"
        )
        temperature_summarize = gr.Slider(
            minimum=0.1, maximum=1, step=0.1, value=0.5, label="Temperature"
        )
        summarize_button = gr.Button(value="Summarize Text")
        summarize_button.click(
            fn=summarize_legal_text,
            inputs=[
                summarize_input,
                max_new_tokens_summarize,
                temperature_summarize,
            ],
            outputs=summarize_output,
        )

    # --- Tab 2: HTR correction ------------------------------------------
    with gr.Tab("Correct Raw HTR Text"):
        htr_input = gr.Textbox(
            lines=5,
            label="Input HTR Text",
            placeholder="Enter HTR text here...",
        )
        htr_output = gr.Textbox(lines=5, label="Corrected HTR Text")
        max_new_tokens_htr = gr.Slider(
            minimum=10, maximum=512, step=1, value=128, label="Max New Tokens"
        )
        temperature_htr = gr.Slider(
            minimum=0.1, maximum=1, step=0.1, value=0.7, label="Temperature"
        )
        htr_button = gr.Button(value="Correct HTR")
        htr_button.click(
            fn=correct_htr_text,
            inputs=[htr_input, max_new_tokens_htr, temperature_htr],
            outputs=htr_output,
        )

    # --- Tab 3: question answering --------------------------------------
    with gr.Tab("Answer Legal Question"):
        question_input_context = gr.Textbox(
            lines=10,
            label="Context Text",
            placeholder="Enter legal context...",
        )
        question_input = gr.Textbox(
            lines=2,
            label="Enter your question",
            placeholder="Enter your question here...",
        )
        question_output = gr.Textbox(lines=5, label="Answer")
        max_new_tokens_question = gr.Slider(
            minimum=10, maximum=512, step=1, value=128, label="Max New Tokens"
        )
        temperature_question = gr.Slider(
            minimum=0.1, maximum=1, step=0.1, value=0.7, label="Temperature"
        )
        question_button = gr.Button(value="Get Answer")
        question_button.click(
            fn=answer_legal_question,
            inputs=[
                question_input_context,
                question_input,
                max_new_tokens_question,
                temperature_question,
            ],
            outputs=question_output,
        )

# Start the web server (runs whenever this module is executed or imported).
demo.launch()