Update app.py (reverted to MT5Tokenizer; added further logging)
app.py
CHANGED
@@ -1,15 +1,15 @@
 import os
 import gradio as gr
 import logging
-from transformers import
+from transformers import MT5Tokenizer, MT5ForConditionalGeneration

 # Setup logging
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

 # Load your fine-tuned mT5 model
 model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
-tokenizer =
-model =
+tokenizer = MT5Tokenizer.from_pretrained(model_name)
+model = MT5ForConditionalGeneration.from_pretrained(model_name)

 def correct_htr(raw_htr_text):
     try:
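Per the commit title, this hunk reverts the loading code to the mT5-specific classes. A minimal smoke test would confirm the classes resolve and round-trip a prompt; the test is illustrative only and not part of the commit (only the model name is taken from the diff):

# Illustrative smoke test, not part of the commit.
from transformers import MT5Tokenizer, MT5ForConditionalGeneration

model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)

inputs = tokenizer("correct this text: teh ship sayled", return_tensors="pt")
outputs = model.generate(**inputs, max_length=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))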
@@ -17,12 +17,16 @@ def correct_htr(raw_htr_text):
         inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")

-        #
-        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
+        # Generate with beam search and sampling
+        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
         logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")

         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Decoded Output for HTR Correction: {corrected_text}")
+
+        # Re-tokenize the output for further inspection
+        logging.debug(f"Re-tokenized output for HTR Correction: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
+
         return corrected_text
     except Exception as e:
         logging.error(f"Error in HTR Correction: {e}", exc_info=True)
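The added debug line is the "further logging" from the commit title: despite the "re-tokenize" comment, it performs a second decode with skip_special_tokens=False rather than a re-tokenization. Keeping the control tokens visible helps diagnose empty or truncated generations, since mT5 decoding starts from the pad token and ends at </s>. A sketch of the two decodes side by side (output strings are illustrative):

# Both lines decode the same token IDs; only special-token handling differs.
clean = tokenizer.decode(outputs[0], skip_special_tokens=True)
raw = tokenizer.decode(outputs[0], skip_special_tokens=False)
# clean -> "the ship sailed"            (illustrative)
# raw   -> "<pad> the ship sailed</s>"  (pad is the decoder start token for mT5)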
@@ -34,12 +38,16 @@ def summarize_text(legal_text):
         inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for Summarization: {inputs}")

-        #
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.8, do_sample=True)
+        # Generate with beam search and sampling
+        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.8, do_sample=True)
         logging.debug(f"Generated Summary (Tokens): {outputs}")

         summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Decoded Summary: {summary}")
+
+        # Re-tokenize the output for further inspection
+        logging.debug(f"Re-tokenized output for Summarization: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
+
         return summary
     except Exception as e:
         logging.error(f"Error in Summarization: {e}", exc_info=True)
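All three handlers share the same decoding setup and differ only in prompt prefix, max_length, and temperature (0.6 for correction, 0.8 for summarization, 0.7 for question answering). Because do_sample=True, passing num_beams=4 makes transformers use beam-search multinomial sampling, so the temperature actually takes effect, and early_stopping applies to the beam search. A hypothetical consolidation, not part of the commit, would make the shared settings explicit:

# Hypothetical helper, not part of the commit: one code path for all three tasks.
def run_task(prompt: str, max_length: int, temperature: float) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(**inputs, max_length=max_length, num_beams=4,
                             early_stopping=True, temperature=temperature,
                             do_sample=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# e.g. run_task("correct this text: " + raw_htr_text, 128, 0.6)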
@@ -52,12 +60,16 @@ def answer_question(legal_text, question):
         inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")

-        #
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
+        # Generate with beam search and sampling
+        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
         logging.debug(f"Generated Answer (Tokens): {outputs}")

         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
         logging.debug(f"Decoded Answer: {answer}")
+
+        # Re-tokenize the output for further inspection
+        logging.debug(f"Re-tokenized output for Question Answering: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
+
         return answer
     except Exception as e:
         logging.error(f"Error in Question Answering: {e}", exc_info=True)
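Each handler also logs failures with exc_info=True, which appends the full traceback to the record. With the basicConfig format set at the top of the file, every added record renders as timestamp, level, and message, for example (timestamp illustrative):

# 2024-11-05 14:21:07,413 - DEBUG - Re-tokenized output for Question Answering: <pad> ...</s>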
@@ -68,6 +80,7 @@ with gr.Blocks() as demo:
     gr.Markdown("# mT5 Legal Assistant")
     gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")

+    # Add the two clickable buttons with separate boxes and bold text
     with gr.Row():
         gr.HTML('''
             <div style="display: flex; gap: 10px;">
@@ -116,4 +129,4 @@ with gr.Blocks() as demo:
     clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])

 # Launch the Gradio interface
-demo.launch()
+demo.launch()
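The final hunk leaves the bare demo.launch() in place, which is all a Hugging Face Space needs. For local debugging one might bind the server explicitly and enable Gradio's request queue; a hypothetical variant, not part of the commit:

# Hypothetical local-run variant, not part of the commit.
demo.queue()                                          # serialize generate() calls
demo.launch(server_name="0.0.0.0", server_port=7860)  # reachable on the local network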