Addaci committed on
Commit
acc8063
1 Parent(s): 9519c42

Major rewrite of app.py

Browse files

Rewrote the app.py file to apply a three tab, two button template from a different Hugging Face space
Added sliders at the bottom for text output size and temperature, with default, minimum, and maximum values
Changed model to google/flan-t5-xl
This model, according to Gemini Pro: "is known for its strong performance in a wide range of tasks, including summarization and question answering"

Gemini also suggested looking at:

BART (facebook/bart-large-cnn): BART is another excellent choice, especially for summarization tasks.
LongT5 (google/long-t5-tglobal-xl): If you're dealing with longer legal texts, LongT5 might be a good option due to its ability to handle longer input sequences.

Files changed (1) hide show
  1. app.py +43 -33
app.py CHANGED
@@ -1,53 +1,58 @@
1
- # Cell 1B: Inference Client
2
-
3
  import gradio as gr
4
- from huggingface_hub import InferenceClient
5
  import logging
6
 
7
- # Setup logging
8
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
9
 
10
- # Initialize Inference Client
11
- client = InferenceClient(model="Addaci/mT5-small-experiment-13-checkpoint-2790")
 
 
12
 
13
- def correct_htr(raw_htr_text):
14
  try:
15
- logging.info("Processing HTR correction with InferenceClient...")
16
- # Sending the input to the hosted model
17
- result = client.text_generation(f"correct this text: {raw_htr_text}")
18
- logging.debug(f"Generated output for HTR correction: {result}")
19
- return result['generated_text'] # Extracting the generated text from the response
 
 
20
  except Exception as e:
21
  logging.error(f"Error in HTR correction: {e}", exc_info=True)
22
  return str(e)
23
 
24
- def summarize_text(legal_text):
25
  try:
26
- logging.info("Processing summarization with InferenceClient...")
27
- # Sending the input to the hosted model
28
- result = client.text_generation(f"summarize the following legal text: {legal_text}")
29
- logging.debug(f"Generated summary: {result}")
30
- return result['generated_text'] # Extracting the generated text from the response
 
 
31
  except Exception as e:
32
  logging.error(f"Error in summarization: {e}", exc_info=True)
33
  return str(e)
34
 
35
- def answer_question(legal_text, question):
36
  try:
37
- logging.info("Processing question-answering with InferenceClient...")
38
- # Sending the input to the hosted model
39
- formatted_input = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
40
- result = client.text_generation(formatted_input)
41
- logging.debug(f"Generated answer: {result}")
42
- return result['generated_text'] # Extracting the generated text from the response
 
43
  except Exception as e:
44
  logging.error(f"Error in question-answering: {e}", exc_info=True)
45
  return str(e)
46
 
47
  # Create the Gradio Blocks interface
48
  with gr.Blocks() as demo:
49
- gr.Markdown("# mT5 Legal Assistant")
50
- gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")
51
 
52
  with gr.Row():
53
  gr.HTML('''
@@ -71,8 +76,8 @@ with gr.Blocks() as demo:
71
  corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
72
  correct_button = gr.Button("Correct HTR")
73
  clear_button = gr.Button("Clear")
74
-
75
- correct_button.click(correct_htr, inputs=raw_htr_input, outputs=corrected_output)
76
  clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])
77
 
78
  with gr.Tab("Summarize Legal Text"):
@@ -81,8 +86,8 @@ with gr.Blocks() as demo:
81
  summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
82
  summarize_button = gr.Button("Summarize Text")
83
  clear_button = gr.Button("Clear")
84
-
85
- summarize_button.click(summarize_text, inputs=legal_text_input, outputs=summary_output)
86
  clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])
87
 
88
  with gr.Tab("Answer Legal Question"):
@@ -92,9 +97,14 @@ with gr.Blocks() as demo:
92
  answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
93
  answer_button = gr.Button("Get Answer")
94
  clear_button = gr.Button("Clear")
95
-
96
- answer_button.click(answer_question, inputs=[legal_text_input_q, question_input], outputs=answer_output)
97
  clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
98
 
 
 
 
 
 
99
  # Launch the Gradio interface
100
  demo.launch()
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import logging
4
 
5
+ # Setup logging (optional, but helpful for debugging)
6
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
7
 
8
+ # Load the Flan-T5 model and tokenizer
9
+ model_id = "google/flan-t5-xl"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
12
 
13
def correct_htr(raw_htr_text, max_new_tokens, temperature):
    """Correct raw HTR (handwritten text recognition) output with Flan-T5.

    Args:
        raw_htr_text: The raw HTR string to clean up.
        max_new_tokens: Upper bound on the number of *generated* tokens
            (driven by the "Max New Tokens" slider in the UI).
        temperature: Sampling temperature from the "Temperature" slider.

    Returns:
        The corrected text, or the exception message as a string on failure
        (errors are surfaced in the output textbox rather than raised).
    """
    try:
        logging.info("Processing HTR correction with Flan-T5...")
        prompt = f"Correct this text: {raw_htr_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Bug fix: the slider is labelled "Max New Tokens", so pass it as
        # max_new_tokens (max_length caps total decoder length instead).
        # Temperature only takes effect when sampling is enabled, so set
        # do_sample=True; otherwise greedy decoding silently ignores it.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated output for HTR correction: {corrected_text}")
        return corrected_text
    except Exception as e:
        logging.error(f"Error in HTR correction: {e}", exc_info=True)
        return str(e)
25
 
26
def summarize_text(legal_text, max_new_tokens, temperature):
    """Summarize a legal text with Flan-T5.

    Args:
        legal_text: The legal text to summarize.
        max_new_tokens: Upper bound on the number of *generated* tokens
            (driven by the "Max New Tokens" slider in the UI).
        temperature: Sampling temperature from the "Temperature" slider.

    Returns:
        The generated summary, or the exception message as a string on
        failure (errors are surfaced in the output textbox rather than raised).
    """
    try:
        logging.info("Processing summarization with Flan-T5...")
        prompt = f"Summarize the following legal text: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Bug fix: pass the slider value as max_new_tokens (max_length caps
        # total decoder length instead), and enable sampling so the
        # temperature slider actually affects generation.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated summary: {summary}")
        return summary
    except Exception as e:
        logging.error(f"Error in summarization: {e}", exc_info=True)
        return str(e)
38
 
39
def answer_question(legal_text, question, max_new_tokens, temperature):
    """Answer a question about a legal text with Flan-T5.

    Args:
        legal_text: The context document the answer should be grounded in.
        question: The user's question about that context.
        max_new_tokens: Upper bound on the number of *generated* tokens
            (driven by the "Max New Tokens" slider in the UI).
        temperature: Sampling temperature from the "Temperature" slider.

    Returns:
        The generated answer, or the exception message as a string on
        failure (errors are surfaced in the output textbox rather than raised).
    """
    try:
        logging.info("Processing question-answering with Flan-T5...")
        prompt = (
            "Answer the following question based on the provided context:"
            f"\n\nQuestion: {question}\n\nContext: {legal_text}"
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        # Bug fix: pass the slider value as max_new_tokens (max_length caps
        # total decoder length instead), and enable sampling so the
        # temperature slider actually affects generation.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated answer: {answer}")
        return answer
    except Exception as e:
        logging.error(f"Error in question-answering: {e}", exc_info=True)
        return str(e)
51
 
52
  # Create the Gradio Blocks interface
53
  with gr.Blocks() as demo:
54
+ gr.Markdown("# Flan-T5 Legal Assistant")
55
+ gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5).")
56
 
57
  with gr.Row():
58
  gr.HTML('''
 
76
  corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
77
  correct_button = gr.Button("Correct HTR")
78
  clear_button = gr.Button("Clear")
79
+
80
+ correct_button.click(correct_htr, inputs=[raw_htr_input, max_new_tokens, temperature], outputs=corrected_output)
81
  clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])
82
 
83
  with gr.Tab("Summarize Legal Text"):
 
86
  summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
87
  summarize_button = gr.Button("Summarize Text")
88
  clear_button = gr.Button("Clear")
89
+
90
+ summarize_button.click(summarize_text, inputs=[legal_text_input, max_new_tokens, temperature], outputs=summary_output)
91
  clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])
92
 
93
  with gr.Tab("Answer Legal Question"):
 
97
  answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
98
  answer_button = gr.Button("Get Answer")
99
  clear_button = gr.Button("Clear")
100
+
101
+ answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, max_new_tokens, temperature], outputs=answer_output)
102
  clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
103
 
104
+ # Add sliders for hyperparameters
105
+ # NOTE(review): these sliders are referenced by the button .click() handlers
+ # wired up earlier inside the same `with gr.Blocks()` body. Those handler
+ # lines execute first, before `max_new_tokens` / `temperature` are bound, so
+ # this looks like a NameError at app startup — confirm, or move these Slider
+ # definitions above the tab sections that reference them.
+ with gr.Row():
106
+ max_new_tokens = gr.Slider(minimum=10, maximum=1000, value=500, step=1, label="Max New Tokens")
107
+ temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
108
+
109
  # Launch the Gradio interface
110
  demo.launch()