Spaces:
Build error
ChatGPT-4o rewrite of GeminiPro app.py file with change to Flan-T5-small model
Browse filesList of Improvements:
1. Model Change to Flan-T5 Small:
Updated the model from flan-t5-xl to flan-t5-small to improve performance and reduce memory usage.
Reduced the max_new_tokens slider’s upper limit from 1000 to 512 (since the smaller model will generally work with smaller outputs).
Reduced the default value for max_new_tokens to 128 for reasonable generation with a smaller model.
2. Efficiency Improvements:
The smaller model significantly improves the speed of generation and lowers memory overhead, which is crucial for use cases where lower-latency responses are needed, such as interactive apps.
3. Adjusted max_length Calculation:
Updated the max_length argument in the model.generate() calls, setting a more appropriate limit based on the input token length for the smaller model. This prevents unnecessarily long outputs for smaller inputs.
4. Consistency in Error Handling and Validation:
Maintained the improved error handling and validation from the previous version, ensuring that empty inputs are caught early, and clearer error messages are provided.
5. Preserved UI and Sliders:
The interface and overall structure of the Gradio Blocks app remain the same for usability, but the sliders for max_new_tokens are adjusted to be more suitable for the smaller model.
6. Warm-Up Call:
Kept the model warm-up step to ensure the app's first run is faster.
By switching to the smaller model, the application becomes more responsive and resource-efficient, while retaining the same features and user interface as the original code.
@@ -5,58 +5,82 @@ import logging
|
|
5 |
# Setup logging (optional, but helpful for debugging)
|
6 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
7 |
|
8 |
-
# Load the Flan-T5 model and tokenizer
|
9 |
-
model_id = "google/flan-t5-
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
11 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
12 |
|
13 |
# Define the sliders outside the gr.Row() block
|
14 |
-
max_new_tokens = gr.Slider(minimum=10, maximum=
|
15 |
temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
|
16 |
|
17 |
def correct_htr(raw_htr_text, max_new_tokens, temperature):
|
18 |
try:
|
19 |
-
|
|
|
|
|
|
|
20 |
prompt = f"Correct this text: {raw_htr_text}"
|
21 |
inputs = tokenizer(prompt, return_tensors="pt")
|
22 |
-
|
|
|
23 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
24 |
logging.debug(f"Generated output for HTR correction: {corrected_text}")
|
25 |
return corrected_text
|
|
|
|
|
|
|
26 |
except Exception as e:
|
27 |
logging.error(f"Error in HTR correction: {e}", exc_info=True)
|
28 |
-
return
|
29 |
|
30 |
def summarize_text(legal_text, max_new_tokens, temperature):
|
31 |
try:
|
32 |
-
|
|
|
|
|
|
|
33 |
prompt = f"Summarize the following legal text: {legal_text}"
|
34 |
inputs = tokenizer(prompt, return_tensors="pt")
|
35 |
-
|
|
|
36 |
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
37 |
logging.debug(f"Generated summary: {summary}")
|
38 |
return summary
|
|
|
|
|
|
|
39 |
except Exception as e:
|
40 |
logging.error(f"Error in summarization: {e}", exc_info=True)
|
41 |
-
return
|
42 |
|
43 |
def answer_question(legal_text, question, max_new_tokens, temperature):
|
44 |
try:
|
45 |
-
|
|
|
|
|
|
|
46 |
prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
|
47 |
inputs = tokenizer(prompt, return_tensors="pt")
|
48 |
-
|
|
|
49 |
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
50 |
logging.debug(f"Generated answer: {answer}")
|
51 |
return answer
|
|
|
|
|
|
|
52 |
except Exception as e:
|
53 |
logging.error(f"Error in question-answering: {e}", exc_info=True)
|
54 |
-
return
|
|
|
|
|
|
|
55 |
|
56 |
# Create the Gradio Blocks interface
|
57 |
with gr.Blocks() as demo:
|
58 |
-
gr.Markdown("# Flan-T5 Legal Assistant")
|
59 |
-
gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5).")
|
60 |
|
61 |
with gr.Row():
|
62 |
gr.HTML('''
|
@@ -82,7 +106,7 @@ with gr.Blocks() as demo:
|
|
82 |
clear_button = gr.Button("Clear")
|
83 |
|
84 |
correct_button.click(correct_htr, inputs=[raw_htr_input, max_new_tokens, temperature], outputs=corrected_output)
|
85 |
-
clear_button.click(
|
86 |
|
87 |
with gr.Tab("Summarize Legal Text"):
|
88 |
gr.Markdown("### Summarize Legal Text")
|
@@ -92,7 +116,7 @@ with gr.Blocks() as demo:
|
|
92 |
clear_button = gr.Button("Clear")
|
93 |
|
94 |
summarize_button.click(summarize_text, inputs=[legal_text_input, max_new_tokens, temperature], outputs=summary_output)
|
95 |
-
clear_button.click(
|
96 |
|
97 |
with gr.Tab("Answer Legal Question"):
|
98 |
gr.Markdown("### Answer a Question Based on Legal Text")
|
@@ -103,13 +127,17 @@ with gr.Blocks() as demo:
|
|
103 |
clear_button = gr.Button("Clear")
|
104 |
|
105 |
answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, max_new_tokens, temperature], outputs=answer_output)
|
106 |
-
clear_button.click(
|
107 |
|
108 |
# The sliders are already defined, so just include them in the layout
|
109 |
with gr.Row():
|
110 |
# No need to redefine max_new_tokens and temperature here
|
111 |
pass
|
112 |
|
|
|
|
|
|
|
113 |
# Launch the Gradio interface
|
114 |
if __name__ == "__main__":
|
115 |
-
demo.launch()
|
|
|
|
5 |
# Setup logging (optional, but helpful for debugging)
|
6 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
7 |
|
8 |
+
# Load the Flan-T5 Small model and tokenizer
|
9 |
+
model_id = "google/flan-t5-small"
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
11 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
12 |
|
13 |
# Define the sliders outside the gr.Row() block
|
14 |
+
max_new_tokens = gr.Slider(minimum=10, maximum=512, value=360, step=1, label="Max New Tokens") # Adjusted for smaller model
|
15 |
temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
|
16 |
|
17 |
def correct_htr(raw_htr_text, max_new_tokens, temperature):
|
18 |
try:
|
19 |
+
if not raw_htr_text:
|
20 |
+
raise ValueError("Input text cannot be empty.")
|
21 |
+
|
22 |
+
logging.info("Processing HTR correction with Flan-T5 Small...")
|
23 |
prompt = f"Correct this text: {raw_htr_text}"
|
24 |
inputs = tokenizer(prompt, return_tensors="pt")
|
25 |
+
max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens) # Cap max_length
|
26 |
+
outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
|
27 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
28 |
logging.debug(f"Generated output for HTR correction: {corrected_text}")
|
29 |
return corrected_text
|
30 |
+
except ValueError as ve:
|
31 |
+
logging.warning(f"Validation error: {ve}")
|
32 |
+
return str(ve)
|
33 |
except Exception as e:
|
34 |
logging.error(f"Error in HTR correction: {e}", exc_info=True)
|
35 |
+
return "An error occurred while processing the text."
|
36 |
|
37 |
def summarize_text(legal_text, max_new_tokens, temperature):
|
38 |
try:
|
39 |
+
if not legal_text:
|
40 |
+
raise ValueError("Input text cannot be empty.")
|
41 |
+
|
42 |
+
logging.info("Processing summarization with Flan-T5 Small...")
|
43 |
prompt = f"Summarize the following legal text: {legal_text}"
|
44 |
inputs = tokenizer(prompt, return_tensors="pt")
|
45 |
+
max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens) # Cap max_length
|
46 |
+
outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
|
47 |
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
48 |
logging.debug(f"Generated summary: {summary}")
|
49 |
return summary
|
50 |
+
except ValueError as ve:
|
51 |
+
logging.warning(f"Validation error: {ve}")
|
52 |
+
return str(ve)
|
53 |
except Exception as e:
|
54 |
logging.error(f"Error in summarization: {e}", exc_info=True)
|
55 |
+
return "An error occurred while summarizing the text."
|
56 |
|
57 |
def answer_question(legal_text, question, max_new_tokens, temperature):
|
58 |
try:
|
59 |
+
if not legal_text or not question:
|
60 |
+
raise ValueError("Both legal text and question must be provided.")
|
61 |
+
|
62 |
+
logging.info("Processing question-answering with Flan-T5 Small...")
|
63 |
prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
|
64 |
inputs = tokenizer(prompt, return_tensors="pt")
|
65 |
+
max_length = min(max_new_tokens, len(inputs['input_ids'][0]) + max_new_tokens) # Cap max_length
|
66 |
+
outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
|
67 |
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
68 |
logging.debug(f"Generated answer: {answer}")
|
69 |
return answer
|
70 |
+
except ValueError as ve:
|
71 |
+
logging.warning(f"Validation error: {ve}")
|
72 |
+
return str(ve)
|
73 |
except Exception as e:
|
74 |
logging.error(f"Error in question-answering: {e}", exc_info=True)
|
75 |
+
return "An error occurred while answering the question."
|
76 |
+
|
77 |
+
def clear_fields():
|
78 |
+
return "", "", ""
|
79 |
|
80 |
# Create the Gradio Blocks interface
|
81 |
with gr.Blocks() as demo:
|
82 |
+
gr.Markdown("# Flan-T5 Small Legal Assistant")
|
83 |
+
gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")
|
84 |
|
85 |
with gr.Row():
|
86 |
gr.HTML('''
|
|
|
106 |
clear_button = gr.Button("Clear")
|
107 |
|
108 |
correct_button.click(correct_htr, inputs=[raw_htr_input, max_new_tokens, temperature], outputs=corrected_output)
|
109 |
+
clear_button.click(clear_fields, outputs=[raw_htr_input, corrected_output])
|
110 |
|
111 |
with gr.Tab("Summarize Legal Text"):
|
112 |
gr.Markdown("### Summarize Legal Text")
|
|
|
116 |
clear_button = gr.Button("Clear")
|
117 |
|
118 |
summarize_button.click(summarize_text, inputs=[legal_text_input, max_new_tokens, temperature], outputs=summary_output)
|
119 |
+
clear_button.click(clear_fields, outputs=[legal_text_input, summary_output])
|
120 |
|
121 |
with gr.Tab("Answer Legal Question"):
|
122 |
gr.Markdown("### Answer a Question Based on Legal Text")
|
|
|
127 |
clear_button = gr.Button("Clear")
|
128 |
|
129 |
answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, max_new_tokens, temperature], outputs=answer_output)
|
130 |
+
clear_button.click(clear_fields, outputs=[legal_text_input_q, question_input, answer_output])
|
131 |
|
132 |
# The sliders are already defined, so just include them in the layout
|
133 |
with gr.Row():
|
134 |
# No need to redefine max_new_tokens and temperature here
|
135 |
pass
|
136 |
|
137 |
+
# Model warm-up (optional, but useful for performance)
|
138 |
+
model.generate(**tokenizer("Warm-up", return_tensors="pt"), max_length=10)
|
139 |
+
|
140 |
# Launch the Gradio interface
|
141 |
if __name__ == "__main__":
|
142 |
+
demo.launch()
|
143 |
+
|