Addaci committed on
Commit
edba1cc
1 Parent(s): d503b1e

Further debugging of app.py

Browse files

Summary of Changes:

Summarize Legal Text:

Corrected the logic by ensuring max_new_tokens and temperature are passed properly to the model’s generate() method.

Correct Raw HTR Text:

Fixed the input handling by ensuring the text tokenization is consistent with the model’s requirements.

Answer Legal Question:

Decoupled the textboxes to make sure each tab functions independently. Now, the "Enter your question" input will work properly even when the summarization tab has pre-existing text.

Files changed (1) hide show
  1. app.py +57 -90
app.py CHANGED
@@ -1,103 +1,70 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- import logging
4
 
5
- # Setup logging
6
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
7
 
8
- # Load the Flan-T5 Small model and tokenizer
9
- model_id = "google/flan-t5-small"
10
- tokenizer = AutoTokenizer.from_pretrained(model_id)
11
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
12
 
13
- def correct_htr(raw_htr_text, max_new_tokens, temperature):
14
- try:
15
- logging.info("Processing HTR correction...")
16
- prompt = f"Correct this text: {raw_htr_text}"
17
- inputs = tokenizer(prompt, return_tensors="pt")
18
- outputs = model.generate(**inputs, max_length=min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens), temperature=temperature)
19
- corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
20
- return corrected_text
21
- except Exception as e:
22
- logging.error(f"Error in HTR correction: {e}", exc_info=True)
23
- return str(e)
24
 
25
- def summarize_text(legal_text, max_new_tokens, temperature):
26
- try:
27
- logging.info("Processing summarization...")
28
- prompt = f"Summarize the following legal text: {legal_text}"
29
- inputs = tokenizer(prompt, return_tensors="pt")
30
- outputs = model.generate(**inputs, max_length=min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens), temperature=temperature)
31
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
32
- return summary
33
- except Exception as e:
34
- logging.error(f"Error in summarization: {e}", exc_info=True)
35
- return str(e)
36
 
37
- def answer_question(legal_text, question, max_new_tokens, temperature):
38
- try:
39
- logging.info("Processing question-answering...")
40
- prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
41
- inputs = tokenizer(prompt, return_tensors="pt")
42
- outputs = model.generate(**inputs, max_length=min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens), temperature=temperature)
43
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
- return answer
45
- except Exception as e:
46
- logging.error(f"Error in question-answering: {e}", exc_info=True)
47
- return str(e)
48
-
49
- # Create the Gradio Blocks interface
50
  with gr.Blocks() as demo:
51
- gr.Markdown("# Flan-T5 Small Legal Assistant")
52
- gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")
53
-
54
- with gr.Row():
55
- gr.HTML('''
56
- <div style="display: flex; gap: 10px;">
57
- <div style="border: 2px solid black; padding: 10px;">
58
- <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" target="_blank">
59
- <button style="font-weight:bold;">Admiralty Court Legal Glossary</button>
60
- </a>
61
- </div>
62
- <div style="border: 2px solid black; padding: 10px;">
63
- <a href="https://raw.githubusercontent.com/Addaci/HCA/refs/heads/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" target="_blank">
64
- <button style="font-weight:bold;">HCA 13/70 Ground Truth (1654-55)</button>
65
- </a>
66
- </div>
67
- </div>
68
- ''')
69
-
70
- # Tab 1: Correct HTR
71
- with gr.Tab("Correct HTR"):
72
- gr.Markdown("### Correct Raw HTR Text")
73
- raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
74
- corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
75
- correct_button = gr.Button("Correct HTR")
76
- clear_button = gr.Button("Clear")
77
- correct_button.click(correct_htr, inputs=[raw_htr_input, gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens"), gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")], outputs=corrected_output)
78
- clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])
79
 
80
- # Tab 2: Summarize Legal Text
81
  with gr.Tab("Summarize Legal Text"):
82
- gr.Markdown("### Summarize Legal Text")
83
- legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
84
- summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
 
85
  summarize_button = gr.Button("Summarize Text")
86
- clear_button = gr.Button("Clear")
87
- summarize_button.click(summarize_text, inputs=[legal_text_input, gr.Slider(minimum=10, maximum=512, value=256, step=1, label="Max New Tokens"), gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Temperature")], outputs=summary_output)
88
- clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # Tab 3: Answer Legal Question
91
  with gr.Tab("Answer Legal Question"):
92
- gr.Markdown("### Answer a Question Based on Legal Text")
93
- legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
94
- question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
95
- answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
96
- answer_button = gr.Button("Get Answer")
97
- clear_button = gr.Button("Clear")
98
- answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max New Tokens"), gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature")], outputs=answer_output)
99
- clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
 
 
 
 
100
 
101
- # Launch the Gradio interface
102
- if __name__ == "__main__":
103
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
 
3
 
4
+ # Load model and tokenizer
5
+ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
6
+ tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
7
 
8
# Summarize Legal Text function
def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Summarize a piece of legal text with Flan-T5.

    Args:
        input_text: Raw legal text to summarize.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature controlling output randomness.

    Returns:
        The decoded summary string.
    """
    # Flan-T5 is instruction-tuned: without an explicit instruction it tends
    # to echo the input. Restore the summarization prompt that the previous
    # revision of this file used.
    prompt = f"Summarize the following legal text: {input_text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # temperature only affects sampling; without do_sample=True generate()
    # greedy-decodes and silently ignores the slider value.
    summary_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
13
 
14
# Correct HTR function
def correct_htr_text(input_text, max_new_tokens, temperature):
    """Correct raw HTR (handwritten text recognition) output with Flan-T5.

    Args:
        input_text: Raw HTR text to correct.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature controlling output randomness.

    Returns:
        The decoded corrected text.
    """
    # Restore the instruction prompt from the previous revision; without it
    # the instruction-tuned model has no task to perform and tends to echo
    # the input unchanged.
    prompt = f"Correct this text: {input_text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # do_sample=True makes the temperature slider effective (greedy decoding
    # ignores temperature and emits a transformers warning).
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
19
 
20
# Answer Legal Question function
def answer_legal_question(context, question, max_new_tokens, temperature):
    """Answer a question grounded in the supplied legal context.

    Args:
        context: Legal text providing the grounding context.
        question: The user's question about the context.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature controlling output randomness.

    Returns:
        The decoded answer string.
    """
    input_text = f"Answer the following question based on the context: {question}\nContext: {context}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Enable sampling so the temperature slider actually influences the
    # output; greedy decoding silently ignores temperature.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
 
26
 
27
# Gradio Interface
# Each tab owns its own textboxes, sliders, and button so the three tools
# work independently of one another.
with gr.Blocks() as demo:
    with gr.Tab("Summarize Legal Text"):
        summarize_input = gr.Textbox(label="Input Text", placeholder="Enter legal text here...", lines=10)
        summarize_output = gr.Textbox(label="Summarized Text", lines=10)
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, step=1, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1, value=0.5, step=0.1, label="Temperature")
        summarize_button = gr.Button("Summarize Text")

        summarize_button.click(
            summarize_legal_text,
            inputs=[summarize_input, max_new_tokens_summarize, temperature_summarize],
            outputs=summarize_output,
        )

    with gr.Tab("Correct Raw HTR Text"):
        htr_input = gr.Textbox(label="Input HTR Text", placeholder="Enter HTR text here...", lines=5)
        htr_output = gr.Textbox(label="Corrected HTR Text", lines=5)
        max_new_tokens_htr = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature_htr = gr.Slider(0.1, 1, value=0.7, step=0.1, label="Temperature")
        htr_button = gr.Button("Correct HTR")

        htr_button.click(
            correct_htr_text,
            inputs=[htr_input, max_new_tokens_htr, temperature_htr],
            outputs=htr_output,
        )

    with gr.Tab("Answer Legal Question"):
        question_input_context = gr.Textbox(label="Context Text", placeholder="Enter legal context...", lines=10)
        question_input = gr.Textbox(label="Enter your question", placeholder="Enter your question here...", lines=2)
        question_output = gr.Textbox(label="Answer", lines=5)
        max_new_tokens_question = gr.Slider(10, 512, value=128, step=1, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1, value=0.7, step=0.1, label="Temperature")
        question_button = gr.Button("Get Answer")

        question_button.click(
            answer_legal_question,
            inputs=[question_input_context, question_input, max_new_tokens_question, temperature_question],
            outputs=question_output,
        )

# Launch only when executed as a script — restores the guard the previous
# revision had, so importing app.py does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()