Addaci committed on
Commit 40c679d
1 Parent(s): 251d490

Update app.py (major changes)


1.0 Added input and output string limitations.
2.0 Added skip_special_tokens when decoding.
3.0 Simplified the input format:
3.1 For summarization, we're using the "summarize: " prefix, since mT5 is generally pre-trained with a task prefix like this.
3.2 For question-answering, we're combining the question and context in a way that aligns with the model's pre-training structure.
4.0 Improving Output Quality:
4.1 Beam search (num_beams=4) explores several candidate sequences in parallel and often leads to better, more coherent results.
4.2 Early stopping prevents the model from generating overly long or repetitive sequences.
5.0 Model Performance Expectations:
5.1 Since mT5-small was not specifically fine-tuned for legal tasks (e.g., summarization of legal documents or answering legal questions), the pre-trained model might struggle with domain-specific terminology.
5.2 If performance is unsatisfactory, you might get better results by fine-tuning the model on a small subset of legal texts; a minimal sketch of such a run follows below.
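
Item 5.2 is only a recommendation; this commit adds no training code. Below is a minimal sketch of what such a fine-tuning run could look like, assuming mT5-small (as in app.py) and a handful of (input, summary) string pairs. The training pair, hyperparameters, and output path are placeholders, not anything taken from this repo.

# Minimal fine-tuning sketch for item 5.2 -- hypothetical, not part of this commit.
from torch.optim import AdamW
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "google/mt5-small"  # assumed from the commit notes
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Placeholder pair; replace with real legal documents and reference summaries.
train_pairs = [
    ("summarize: The lessee shall vacate the premises no later than 31 December 2024 ...",
     "The tenant must leave by 31 December 2024."),
]

optimizer = AdamW(model.parameters(), lr=3e-4)
model.train()
for epoch in range(3):  # epoch count chosen arbitrarily for illustration
    for source, target in train_pairs:
        inputs = tokenizer(source, return_tensors="pt", max_length=512, truncation=True)
        labels = tokenizer(target, return_tensors="pt", max_length=128, truncation=True).input_ids
        loss = model(**inputs, labels=labels).loss  # seq2seq cross-entropy against the reference
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

model.save_pretrained("./mt5-small-legal")  # hypothetical output path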

Files changed (1)
app.py +27 -39
app.py CHANGED
@@ -9,60 +9,48 @@ model = T5ForConditionalGeneration.from_pretrained(model_name)
 
 def correct_htr(raw_htr_text):
     # Tokenize the input text
-    inputs = tokenizer(raw_htr_text, return_tensors="pt")
+    inputs = tokenizer(raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
+    print("Tokenized Inputs for HTR Correction:", inputs)  # Debugging
 
-    # Generate corrected text
-    outputs = model.generate(**inputs)
+    # Generate corrected text with max_length and beam search
+    outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)
+    print("Generated Output (Tokens) for HTR Correction:", outputs)  # Debugging
+
+    # Decode the output, skipping special tokens
     corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print("Decoded Output for HTR Correction:", corrected_text)  # Debugging
 
     return corrected_text
 
 def summarize_text(legal_text):
-    # Tokenize the input text with summarization prompt
-    inputs = tokenizer("summarize: " + legal_text, return_tensors="pt")
+    # Tokenize the input text with the summarization prompt
+    inputs = tokenizer("summarize: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
+    print("Tokenized Inputs for Summarization:", inputs)  # Debugging
 
-    # Generate summary
-    outputs = model.generate(**inputs)
+    # Generate summary with beam search for better results
+    outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)
+    print("Generated Summary (Tokens):", outputs)  # Debugging
+
+    # Decode the output, skipping special tokens
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print("Decoded Summary:", summary)  # Debugging
 
     return summary
 
 def answer_question(legal_text, question):
-    # Combine context and question
-    inputs = tokenizer(f"question: {question} context: {legal_text}", return_tensors="pt")
+    # Format input for question-answering
+    formatted_input = f"question: {question} context: {legal_text}"
+    inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
+    print("Tokenized Inputs for Question Answering:", inputs)  # Debugging
 
-    # Generate answer
-    outputs = model.generate(**inputs)
+    # Generate answer using beam search
+    outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)
+    print("Generated Answer (Tokens):", outputs)  # Debugging
+
+    # Decode the output, skipping special tokens
     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print("Decoded Answer:", answer)  # Debugging
 
     return answer
-
-# Create the Gradio Blocks interface
-with gr.Blocks() as demo:
-    gr.Markdown("# mT5 Legal Assistant")
-    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")
-
-    with gr.Tab("Correct HTR"):
-        gr.Markdown("### Correct Raw HTR Text")
-        raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
-        corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
-        correct_button = gr.Button("Correct HTR")
-        correct_button.click(correct_htr, inputs=raw_htr_input, outputs=corrected_output)
-
-    with gr.Tab("Summarize Legal Text"):
-        gr.Markdown("### Summarize Legal Text")
-        legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
-        summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
-        summarize_button = gr.Button("Summarize Text")
-        summarize_button.click(summarize_text, inputs=legal_text_input, outputs=summary_output)
-
-    with gr.Tab("Answer Legal Question"):
-        gr.Markdown("### Answer a Question Based on Legal Text")
-        legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
-        question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
-        answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
-        answer_button = gr.Button("Get Answer")
-        answer_button.click(answer_question, inputs=[legal_text_input_q, question_input], outputs=answer_output)
-
-demo.launch()
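
For a quick manual check of these changes, the three updated functions can be called directly; the sample strings below are invented for illustration and are not part of the commit.

# Hypothetical smoke test for the updated functions (sample text invented).
if __name__ == "__main__":
    print(correct_htr("Thee courte fynds the defendant lyable for damagez."))
    print(summarize_text("The court held that the agreement was void for lack of consideration ..."))
    print(answer_question("The lease terminates on 31 December 2024.",
                          "When does the lease terminate?"))

Each call also emits the Debugging print output added in this commit, which makes it easy to inspect the tokenized inputs and raw generated tokens by hand.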