Addaci committed on
Commit
340628c
1 Parent(s): 85abbe0

ChatGPT-4o rewrite of the GeminiPro app.py file, switching to the Flan-T5-small model


List of Improvements:

1. Model Change to Flan-T5 Small:

Switched the model from flan-t5-xl to flan-t5-small to cut generation latency and memory usage.
Lowered the max_new_tokens slider's upper limit from 1000 to 512, since the smaller model generally works best with shorter outputs.
Lowered the default max_new_tokens value from 500 to 360, a more reasonable generation budget for the smaller model. The concrete configuration changes are sketched below.
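Taken together, the configuration changes amount to the following (values copied from the diff below; a sketch, not the full file):

import gradio as gr

model_id = "google/flan-t5-small"  # was "google/flan-t5-xl"

max_new_tokens = gr.Slider(
    minimum=10,
    maximum=512,   # was 1000
    value=360,     # was 500
    step=1,
    label="Max New Tokens",
)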

2. Efficiency Improvements:

The smaller model generates significantly faster and with far less memory overhead (flan-t5-small has roughly 80M parameters, versus roughly 3B for flan-t5-xl), which matters for latency-sensitive, interactive apps like this one. A quick size check is sketched below.
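One way to verify the size gap is to instantiate each architecture from its config alone, which avoids downloading the XL checkpoint's weights (a sketch; building the XL model still needs several GB of RAM for its randomly initialized parameters):

from transformers import AutoConfig, AutoModelForSeq2SeqLM

# Build each model from its config only (random weights, no checkpoint download).
for model_id in ("google/flan-t5-small", "google/flan-t5-xl"):
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_config(config)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{model_id}: {n_params / 1e6:.0f}M parameters")  # roughly 77M vs. 2850M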

3. Adjusted max_length Calculation:

Updated the max_length argument in the model.generate() calls, deriving the limit from the input token length so that short inputs don't produce disproportionately long outputs (see the caveat below).
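One caveat: because the input length is never negative, min(max_new_tokens, input_length + max_new_tokens) always evaluates to max_new_tokens, so the cap in the committed code is effectively a pass-through. A sketch of a more direct equivalent using the max_new_tokens argument of generate(), which bounds only the newly generated tokens; note that temperature only takes effect when sampling is enabled:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer("Correct this text: teh quik brwn fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=128,  # bounds the generated tokens directly
    do_sample=True,      # without this, temperature is silently ignored
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))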

4. Consistency in Error Handling and Validation:

Maintained the improved error handling and validation from the previous version: empty inputs are caught early, and clear error messages are returned to the user instead of raw exceptions.
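In practice, an empty submission now comes back as a readable message in the output textbox rather than a traceback; a quick sanity check against the correct_htr defined in the diff below:

# The validation message is returned as a plain string, so Gradio shows it inline.
assert correct_htr("", max_new_tokens=128, temperature=0.7) == "Input text cannot be empty."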

5. Preserved UI and Sliders:

The interface and overall structure of the Gradio Blocks app are unchanged for usability; only the max_new_tokens slider's bounds are adjusted to suit the smaller model.

6. Warm-Up Call:

Kept the model warm-up step so the app's first request doesn't pay the cold-start cost.
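The warm-up itself is a single throwaway generate() call made at startup; a minimal sketch, with timing added here purely for illustration:

import time

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# The first generate() call pays one-off setup costs; run it before serving users.
start = time.perf_counter()
model.generate(**tokenizer("Warm-up", return_tensors="pt"), max_length=10)
print(f"Warm-up took {time.perf_counter() - start:.2f}s")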

By switching to the smaller model, the application becomes more responsive and resource-efficient, while retaining the same features and user interface as the original code.

Files changed (1)
app.py  +46 -18
app.py CHANGED

@@ -5,58 +5,82 @@ import logging
 # Setup logging (optional, but helpful for debugging)
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Load the Flan-T5 model and tokenizer
-model_id = "google/flan-t5-xl"
+# Load the Flan-T5 Small model and tokenizer
+model_id = "google/flan-t5-small"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
 # Define the sliders outside the gr.Row() block
-max_new_tokens = gr.Slider(minimum=10, maximum=1000, value=500, step=1, label="Max New Tokens")
+max_new_tokens = gr.Slider(minimum=10, maximum=512, value=360, step=1, label="Max New Tokens")  # Adjusted for smaller model
 temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
 
 def correct_htr(raw_htr_text, max_new_tokens, temperature):
     try:
-        logging.info("Processing HTR correction with Flan-T5...")
+        if not raw_htr_text:
+            raise ValueError("Input text cannot be empty.")
+
+        logging.info("Processing HTR correction with Flan-T5 Small...")
         prompt = f"Correct this text: {raw_htr_text}"
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(**inputs, max_length=max_new_tokens, temperature=temperature)
+        max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens)  # Cap max_length
+        outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Generated output for HTR correction: {corrected_text}")
         return corrected_text
+    except ValueError as ve:
+        logging.warning(f"Validation error: {ve}")
+        return str(ve)
     except Exception as e:
         logging.error(f"Error in HTR correction: {e}", exc_info=True)
-        return str(e)
+        return "An error occurred while processing the text."
 
 def summarize_text(legal_text, max_new_tokens, temperature):
     try:
-        logging.info("Processing summarization with Flan-T5...")
+        if not legal_text:
+            raise ValueError("Input text cannot be empty.")
+
+        logging.info("Processing summarization with Flan-T5 Small...")
         prompt = f"Summarize the following legal text: {legal_text}"
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(**inputs, max_length=max_new_tokens, temperature=temperature)
+        max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens)  # Cap max_length
+        outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
         summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Generated summary: {summary}")
         return summary
+    except ValueError as ve:
+        logging.warning(f"Validation error: {ve}")
+        return str(ve)
     except Exception as e:
         logging.error(f"Error in summarization: {e}", exc_info=True)
-        return str(e)
+        return "An error occurred while summarizing the text."
 
 def answer_question(legal_text, question, max_new_tokens, temperature):
     try:
-        logging.info("Processing question-answering with Flan-T5...")
+        if not legal_text or not question:
+            raise ValueError("Both legal text and question must be provided.")
+
+        logging.info("Processing question-answering with Flan-T5 Small...")
         prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(**inputs, max_length=max_new_tokens, temperature=temperature)
+        max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens)  # Cap max_length
+        outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Generated answer: {answer}")
         return answer
+    except ValueError as ve:
+        logging.warning(f"Validation error: {ve}")
+        return str(ve)
     except Exception as e:
         logging.error(f"Error in question-answering: {e}", exc_info=True)
-        return str(e)
+        return "An error occurred while answering the question."
+
+def clear_fields():
+    return "", "", ""
 
 # Create the Gradio Blocks interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Flan-T5 Legal Assistant")
-    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5).")
+    gr.Markdown("# Flan-T5 Small Legal Assistant")
+    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")
 
     with gr.Row():
         gr.HTML('''
@@ -82,7 +106,7 @@ with gr.Blocks() as demo:
         clear_button = gr.Button("Clear")
 
         correct_button.click(correct_htr, inputs=[raw_htr_input, max_new_tokens, temperature], outputs=corrected_output)
-        clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])
+        clear_button.click(clear_fields, outputs=[raw_htr_input, corrected_output])
 
     with gr.Tab("Summarize Legal Text"):
         gr.Markdown("### Summarize Legal Text")
@@ -92,7 +116,7 @@ with gr.Blocks() as demo:
         clear_button = gr.Button("Clear")
 
         summarize_button.click(summarize_text, inputs=[legal_text_input, max_new_tokens, temperature], outputs=summary_output)
-        clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])
+        clear_button.click(clear_fields, outputs=[legal_text_input, summary_output])
 
     with gr.Tab("Answer Legal Question"):
         gr.Markdown("### Answer a Question Based on Legal Text")
@@ -103,13 +127,17 @@ with gr.Blocks() as demo:
         clear_button = gr.Button("Clear")
 
         answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, max_new_tokens, temperature], outputs=answer_output)
-        clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
+        clear_button.click(clear_fields, outputs=[legal_text_input_q, question_input, answer_output])
 
     # The sliders are already defined, so just include them in the layout
     with gr.Row():
         # No need to redefine max_new_tokens and temperature here
         pass
 
+    # Model warm-up (optional, but useful for performance)
+    model.generate(**tokenizer("Warm-up", return_tensors="pt"), max_length=10)
+
 # Launch the Gradio interface
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
+