Addaci committed (verified)
Commit f82692d · 1 Parent(s): 328a504

Update app.py (added do_sample=True)


The Problem: do_sample is False

The log file shows this warning:

/home/user/.pyenv/versions/3.10.15/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
warnings.warn(  

This warning indicates that a temperature is being passed to model.generate() while do_sample is left at its default of False, so the temperature value is silently ignored.

do_sample=False: generation is deterministic. With num_beams=4, the calls in app.py run plain beam search, which always keeps the most likely continuations at each step, so the output never varies, even with a non-zero temperature.
temperature: this parameter only takes effect when do_sample=True, because it is used to rescale the token probabilities before sampling.
The Solution: Set do_sample=True
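
For illustration, a minimal, self-contained sketch of the corrected call. The checkpoint name and input text below are placeholders for this example only; the commit does not show which model the app actually loads.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint, for illustration only.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

raw_htr_text = "Thiss is a noissy HTR transcriptoin."
inputs = tokenizer("correct this text: " + raw_htr_text,
                   return_tensors="pt", max_length=512, truncation=True)

# do_sample=True makes generation stochastic, so temperature=0.6 is actually
# applied; with the default do_sample=False the call would run deterministic
# beam search and transformers would emit the UserWarning shown above.
outputs = model.generate(**inputs, max_length=128, num_beams=4,
                         early_stopping=True, temperature=0.6, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note that keeping num_beams=4 alongside do_sample=True should put transformers into beam-search multinomial sampling rather than plain sampling, so outputs vary between runs while still benefiting from beam search.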

Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -17,7 +17,8 @@ def correct_htr(raw_htr_text):
         inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")
 
-        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6)
+        # Set do_sample=True
+        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
         logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")
 
         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -27,7 +28,7 @@ def correct_htr(raw_htr_text):
         logging.error(f"Error in HTR Correction: {e}", exc_info=True)
         return str(e)
 
-def summarize_text(legal_text):
+def summarize_text(legal_text):
     try:
         logging.info("Processing summarization...")
         inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
@@ -51,7 +52,8 @@ def answer_question(legal_text, question):
         inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
         logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")
 
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7)
+        # Set do_sample=True
+        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
         logging.debug(f"Generated Answer (Tokens): {outputs}")
 
         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
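
As the warning itself notes, the alternative fix would be to keep deterministic generation and drop temperature instead of enabling sampling. A sketch of that variant, under the same placeholder assumptions as the example above:

# Alternative fix: leave do_sample at its default of False and remove temperature,
# so plain beam search stays fully deterministic and no warning is emitted.
outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)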