Addaci committed on
Commit
325d895
·
verified ·
1 Parent(s): f82692d

Update app.py (reverted to MT5Tokenizer; added further logging function)

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -1,15 +1,15 @@
1
  import os
2
  import gradio as gr
3
  import logging
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
  # Setup logging
7
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
8
 
9
  # Load your fine-tuned mT5 model
10
  model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
 
14
  def correct_htr(raw_htr_text):
15
  try:
@@ -17,12 +17,16 @@ def correct_htr(raw_htr_text):
17
  inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
18
  logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")
19
 
20
- # Set do_sample=True
21
- outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
22
  logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")
23
 
24
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
  logging.debug(f"Decoded Output for HTR Correction: {corrected_text}")
 
 
 
 
26
  return corrected_text
27
  except Exception as e:
28
  logging.error(f"Error in HTR Correction: {e}", exc_info=True)
@@ -34,12 +38,16 @@ def summarize_text(legal_text):
34
  inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
35
  logging.debug(f"Tokenized Inputs for Summarization: {inputs}")
36
 
37
- # Set do_sample=True
38
- outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.8, do_sample=True)
39
  logging.debug(f"Generated Summary (Tokens): {outputs}")
40
 
41
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
  logging.debug(f"Decoded Summary: {summary}")
 
 
 
 
43
  return summary
44
  except Exception as e:
45
  logging.error(f"Error in Summarization: {e}", exc_info=True)
@@ -52,12 +60,16 @@ def answer_question(legal_text, question):
52
  inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
53
  logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")
54
 
55
- # Set do_sample=True
56
- outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
57
  logging.debug(f"Generated Answer (Tokens): {outputs}")
58
 
59
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
  logging.debug(f"Decoded Answer: {answer}")
 
 
 
 
61
  return answer
62
  except Exception as e:
63
  logging.error(f"Error in Question Answering: {e}", exc_info=True)
@@ -68,6 +80,7 @@ with gr.Blocks() as demo:
68
  gr.Markdown("# mT5 Legal Assistant")
69
  gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")
70
 
 
71
  with gr.Row():
72
  gr.HTML('''
73
  <div style="display: flex; gap: 10px;">
@@ -116,4 +129,4 @@ with gr.Blocks() as demo:
116
  clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
117
 
118
  # Launch the Gradio interface
119
- demo.launch()
 
1
  import os
2
  import gradio as gr
3
  import logging
4
+ from transformers import MT5Tokenizer, MT5ForConditionalGeneration
5
 
6
  # Setup logging
7
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
8
 
9
  # Load your fine-tuned mT5 model
10
  model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
11
+ tokenizer = MT5Tokenizer.from_pretrained(model_name)
12
+ model = MT5ForConditionalGeneration.from_pretrained(model_name)
13
 
14
  def correct_htr(raw_htr_text):
15
  try:
 
17
  inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
18
  logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")
19
 
20
+ # Generate with beam search and sampling
21
+ outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
22
  logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")
23
 
24
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
  logging.debug(f"Decoded Output for HTR Correction: {corrected_text}")
26
+
27
+ # Re-tokenize the output for further inspection
28
+ logging.debug(f"Re-tokenized output for HTR Correction: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
29
+
30
  return corrected_text
31
  except Exception as e:
32
  logging.error(f"Error in HTR Correction: {e}", exc_info=True)
 
38
  inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
39
  logging.debug(f"Tokenized Inputs for Summarization: {inputs}")
40
 
41
+ # Generate with beam search and sampling
42
+ outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.8, do_sample=True)
43
  logging.debug(f"Generated Summary (Tokens): {outputs}")
44
 
45
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
  logging.debug(f"Decoded Summary: {summary}")
47
+
48
+ # Re-tokenize the output for further inspection
49
+ logging.debug(f"Re-tokenized output for Summarization: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
50
+
51
  return summary
52
  except Exception as e:
53
  logging.error(f"Error in Summarization: {e}", exc_info=True)
 
60
  inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
61
  logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")
62
 
63
+ # Generate with beam search and sampling
64
+ outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
65
  logging.debug(f"Generated Answer (Tokens): {outputs}")
66
 
67
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
  logging.debug(f"Decoded Answer: {answer}")
69
+
70
+ # Re-tokenize the output for further inspection
71
+ logging.debug(f"Re-tokenized output for Question Answering: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
72
+
73
  return answer
74
  except Exception as e:
75
  logging.error(f"Error in Question Answering: {e}", exc_info=True)
 
80
  gr.Markdown("# mT5 Legal Assistant")
81
  gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")
82
 
83
+ # Add the two clickable buttons with separate boxes and bold text
84
  with gr.Row():
85
  gr.HTML('''
86
  <div style="display: flex; gap: 10px;">
 
129
  clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])
130
 
131
  # Launch the Gradio interface
132
+ demo.launch()