Addaci's picture
Debugging app.py and making interface changes (repositioning sliders and changing their colour)
cdfc9b6 verified
raw
history blame
7.8 kB
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
# Verbose root logging (optional): DEBUG level so that logging.debug() calls,
# such as the generated-text traces, are actually emitted.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# One Flan-T5 Small checkpoint (tokenizer + seq2seq model) shared by every tool.
model_id = "google/flan-t5-small"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForSeq2SeqLM.from_pretrained(model_id),
)
def correct_htr(raw_htr_text, max_new_tokens, temperature):
    """Ask Flan-T5 Small to correct raw HTR (handwritten text recognition) output.

    Args:
        raw_htr_text: The raw HTR transcription to correct.
        max_new_tokens: Upper bound on the number of generated tokens. Cast to
            int because a Gradio slider may deliver it as a float.
        temperature: Sampling temperature; takes effect because sampling is
            enabled (``do_sample=True``).

    Returns:
        The corrected text, or a user-facing message string on failure. Errors
        are returned rather than raised so they show up in the output textbox.
    """
    try:
        # Reject None, "" and whitespace-only input before touching the model.
        if not raw_htr_text or not raw_htr_text.strip():
            raise ValueError("Input text cannot be empty.")
        logging.info("Processing HTR correction with Flan-T5 Small...")
        prompt = f"Correct this text: {raw_htr_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the generated output. (The previous
        # min(max_new_tokens, prompt_len + max_new_tokens) always reduced to
        # max_new_tokens anyway.) Temperature is ignored under greedy decoding,
        # so do_sample=True is required for the Temperature slider to matter;
        # note this makes the output non-deterministic.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug("Generated output for HTR correction: %s", corrected_text)
        return corrected_text
    except ValueError as ve:
        logging.warning("Validation error: %s", ve)
        return str(ve)
    except Exception as e:
        logging.exception("Error in HTR correction: %s", e)
        return "An error occurred while processing the text."
def summarize_text(legal_text, max_new_tokens, temperature):
    """Summarize a legal text with Flan-T5 Small.

    Args:
        legal_text: The legal text to summarize.
        max_new_tokens: Upper bound on the number of generated tokens. Cast to
            int because a Gradio slider may deliver it as a float.
        temperature: Sampling temperature; takes effect because sampling is
            enabled (``do_sample=True``).

    Returns:
        The summary, or a user-facing message string on failure. Errors are
        returned rather than raised so they show up in the output textbox.
    """
    try:
        # Reject None, "" and whitespace-only input before touching the model.
        if not legal_text or not legal_text.strip():
            raise ValueError("Input text cannot be empty.")
        logging.info("Processing summarization with Flan-T5 Small...")
        prompt = f"Summarize the following legal text: {legal_text}"
        # NOTE(review): no truncation is applied; very long texts may exceed
        # the tokenizer's configured maximum length — confirm this is acceptable.
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the generated output (the previous
        # min(...) expression always reduced to max_new_tokens). do_sample=True
        # is needed for temperature to have any effect.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug("Generated summary: %s", summary)
        return summary
    except ValueError as ve:
        logging.warning("Validation error: %s", ve)
        return str(ve)
    except Exception as e:
        logging.exception("Error in summarization: %s", e)
        return "An error occurred while summarizing the text."
def answer_question(legal_text, question, max_new_tokens, temperature):
    """Answer a question about a legal text with Flan-T5 Small.

    Args:
        legal_text: The context passage the answer should be based on.
        question: The question to answer.
        max_new_tokens: Upper bound on the number of generated tokens. Cast to
            int because a Gradio slider may deliver it as a float.
        temperature: Sampling temperature; takes effect because sampling is
            enabled (``do_sample=True``).

    Returns:
        The answer, or a user-facing message string on failure. Errors are
        returned rather than raised so they show up in the output textbox.
    """
    try:
        # Both fields must contain something other than whitespace.
        if (not legal_text or not legal_text.strip()
                or not question or not question.strip()):
            raise ValueError("Both legal text and question must be provided.")
        logging.info("Processing question-answering with Flan-T5 Small...")
        prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the generated output (the previous
        # min(...) expression always reduced to max_new_tokens). do_sample=True
        # is needed for temperature to have any effect.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug("Generated answer: %s", answer)
        return answer
    except ValueError as ve:
        logging.warning("Validation error: %s", ve)
        return str(ve)
    except Exception as e:
        logging.exception("Error in question-answering: %s", e)
        return "An error occurred while answering the question."
def clear_fields():
    """Return a 3-tuple of empty strings.

    Gradio assigns one returned value per output component, so this suits a
    handler wired to exactly three outputs.
    """
    return ("",) * 3
# Create the Gradio Blocks interface.
# Every slider carries the "param-slider" class so this stylesheet can recolour
# it. NOTE(review): assumes Gradio renders sliders as a native
# <input type="range"> that honours accent-color — confirm on the deployed
# Gradio version.
with gr.Blocks(css=".param-slider input[type='range'] { accent-color: blue !important; }") as demo:
    gr.Markdown("# Flan-T5 Small Legal Assistant")
    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")

    with gr.Row():
        gr.HTML('''
        <div style="display: flex; gap: 10px;">
            <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" target="_blank">
                    <button style="font-weight:bold;">Admiralty Court Legal Glossary</button>
                </a>
            </div>
            <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                <a href="https://raw.githubusercontent.com/Addaci/HCA/refs/heads/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" target="_blank">
                    <button style="font-weight:bold;">HCA 13/70 Ground Truth (1654-55)</button>
                </a>
            </div>
        </div>
        ''')

    # Layout: a component created inside a Blocks context is rendered where it
    # is created, and calling .render() on it a second time is rejected by
    # Gradio as a duplicate block. So each tab creates its sliders directly
    # under its "Set Parameters" heading and wires the events afterwards.

    with gr.Tab("Correct HTR"):
        gr.Markdown("### Correct Raw HTR Text")
        raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
        corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
        correct_button = gr.Button("Correct HTR")
        correct_clear_button = gr.Button("Clear")
        gr.Markdown("### Set Parameters")
        correct_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens", elem_classes=["param-slider"])
        correct_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature", elem_classes=["param-slider"])
        correct_button.click(correct_htr, inputs=[raw_htr_input, correct_max_new_tokens, correct_temperature], outputs=corrected_output)
        # Two outputs, so the clear handler must return exactly two values.
        correct_clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])

    with gr.Tab("Summarize Legal Text"):
        gr.Markdown("### Summarize Legal Text")
        legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
        summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
        summarize_button = gr.Button("Summarize Text")
        summarize_clear_button = gr.Button("Clear")
        gr.Markdown("### Set Parameters")
        summarize_max_new_tokens = gr.Slider(minimum=10, maximum=1024, value=256, step=1, label="Max New Tokens", elem_classes=["param-slider"])
        summarize_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Temperature", elem_classes=["param-slider"])
        summarize_button.click(summarize_text, inputs=[legal_text_input, summarize_max_new_tokens, summarize_temperature], outputs=summary_output)
        # Two outputs, so the clear handler must return exactly two values.
        summarize_clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])

    with gr.Tab("Answer Legal Question"):
        gr.Markdown("### Answer a Question Based on Legal Text")
        legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
        question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
        answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
        answer_button = gr.Button("Get Answer")
        answer_clear_button = gr.Button("Clear")
        gr.Markdown("### Set Parameters")
        answer_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max New Tokens", elem_classes=["param-slider"])
        answer_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature", elem_classes=["param-slider"])
        answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, answer_max_new_tokens, answer_temperature], outputs=answer_output)
        # Three outputs: clear_fields returns exactly three empty strings.
        answer_clear_button.click(clear_fields, outputs=[legal_text_input_q, question_input, answer_output])
# Model warm-up: run one tiny generation at import time so one-off first-call
# overhead is paid at startup rather than on the first user request.
_warmup_batch = tokenizer("Warm-up", return_tensors="pt")
model.generate(**_warmup_batch, max_length=10)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()