Update app.py (added do_sample=True)
The Problem: do_sample is False
The log file shows this warning:
```
/home/user/.pyenv/versions/3.10.15/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
  warnings.warn(
```
This warning indicates that you're using the temperature parameter in model.generate(), but the do_sample argument is set to False (which is the default).
- `do_sample=False`: the model decodes deterministically — greedy search by default, or beam search when `num_beams > 1`, as in this app. It always keeps the highest-scoring continuation, so the output is identical on every run and `temperature` is ignored entirely.
- `temperature`: only takes effect when `do_sample=True`; it rescales the token probability distribution before each token is drawn, so lower values produce more conservative text and higher values more varied text. A minimal sketch of the difference follows this list.
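Here is a minimal, self-contained sketch of the two modes. It uses `t5-small` as a stand-in, since the diff does not show which model app.py actually loads; the prompt string is borrowed from the code below.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# "t5-small" is a hypothetical stand-in for the app's real model.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("correct this text: teh quick brwn fox", return_tensors="pt")

# Deterministic decoding: temperature is silently ignored, and this call
# reproduces the UserWarning from the log.
greedy = model.generate(**inputs, max_length=32, temperature=0.8)

# Sampling: temperature now actually rescales the distribution each token
# is drawn from, so repeated calls can produce different text.
sampled = model.generate(**inputs, max_length=32, temperature=0.8, do_sample=True)

print(tokenizer.decode(greedy[0], skip_special_tokens=True))
print(tokenizer.decode(sampled[0], skip_special_tokens=True))
```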
The Solution: Set do_sample=True
```diff
@@ -17,7 +17,8 @@ def correct_htr(raw_htr_text):
         inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")
 
-        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6)
+        # Set do_sample=True
+        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
         logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")
 
         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -27,7 +28,7 @@ def correct_htr(raw_htr_text):
         logging.error(f"Error in HTR Correction: {e}", exc_info=True)
         return str(e)
 
-
+def summarize_text(legal_text):
     try:
         logging.info("Processing summarization...")
         inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
@@ -51,7 +52,8 @@ def answer_question(legal_text, question):
         inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")
 
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7)
+        # Set do_sample=True
+        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
         logging.debug(f"Generated Answer (Tokens): {outputs}")
 
         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
```
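One side effect of this fix: with `do_sample=True`, all three functions now return stochastic output, so repeated calls on the same input can differ. If reproducible output matters (for example, in tests), Transformers' `set_seed` helper can pin the RNGs first. A sketch, reusing `model` and `inputs` from the snippet above:

```python
from transformers import set_seed

# Seeds Python's random, NumPy, and torch RNGs so sampling repeats exactly.
set_seed(42)
outputs = model.generate(**inputs, max_length=128, num_beams=4,
                         early_stopping=True, temperature=0.6, do_sample=True)
```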