Spaces:
Runtime error
Runtime error
import gc | |
from threading import Thread | |
import torch | |
import transformers | |
from transformers import ( | |
GenerationConfig, | |
StoppingCriteria, | |
StoppingCriteriaList, | |
TextIteratorStreamer, | |
) | |
def generate_stream_codet5p(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream text generated by ``model.generate`` for ``params["prompt"]``.

    Generation runs in a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and yields
    progress dictionaries.

    Args:
        model: Object exposing ``generate(**kwargs)``; called with the
            tokenizer encoding, ``decoder_input_ids``, ``streamer``,
            ``generation_config`` and ``stopping_criteria``.
        tokenizer: Tokenizer used to encode the prompt and (through the
            streamer) decode the output. Must expose ``eos_token_id``.
        params: Request mapping. ``"prompt"`` is required. Optional keys and
            their defaults: ``temperature`` (1.0), ``repetition_penalty``
            (1.0), ``top_p`` (1.0), ``top_k`` (50), ``max_new_tokens`` (1024),
            ``stop_token_ids`` (none). The mapping is not modified.
        device: Device the encoded prompt is moved to. Also compared against
            ``"xpu"`` / ``"npu"`` to pick extra cache-clearing calls.
        context_len: Accepted for interface compatibility; unused here.
        stream_interval: A partial result is yielded every
            ``stream_interval`` chunks received from the streamer.
        judge_sent_end: Accepted for interface compatibility; unused here.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``. ``"text"`` is the
        cumulative streamed output. ``"finish_reason"`` is ``None`` for
        partial results; the last item carries ``"length"`` when the chunk
        count reached ``max_new_tokens`` and ``"stop"`` otherwise.
        ``usage["completion_tokens"]`` is the number of chunks received from
        the streamer.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    # Original note: "-1 means disable".
    # NOTE(review): the value is forwarded unchanged to GenerationConfig;
    # confirm the installed transformers version accepts -1 for top_k.
    top_k = int(params.get("top_k", 50))
    max_new_tokens = int(params.get("max_new_tokens", 1024))

    # Copy before extending: appending to the caller's own list would make it
    # grow by one EOS id on every call that reuses the same ``params``.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, **decode_config)

    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # The prompt is also supplied as the initial decoder input.
    encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    # ``len(input_ids)`` on a batched tensor is the batch size, not the
    # prompt length; the sequence length is the last dimension.
    input_echo_len = int(input_ids.shape[-1])

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature >= 1e-5,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
        eos_token_id=stop_token_ids,
    )

    class CodeBlockStopper(StoppingCriteria):
        """Stop once the first sequence ends with the token pair 628, 198."""

        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            # Code-completion is open-end generation.
            # We check \n\n to stop at end of a code block.
            # NOTE(review): 628 / 198 are tokenizer-specific ids; confirm they
            # match the tokenizer in use.
            return input_ids[0][-2:].tolist() == [628, 198]

    gen_kwargs = dict(
        **encoding,
        streamer=streamer,
        generation_config=generation_config,
        stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    )

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    def _usage(completion_tokens):
        """Build the usage record for the given number of streamed chunks."""
        return {
            "prompt_tokens": input_echo_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_echo_len + completion_tokens,
        }

    i = 0
    output = ""
    for new_text in streamer:
        i += 1
        output += new_text
        if i % stream_interval == 0 or i == max_new_tokens - 1:
            yield {
                "text": output,
                "usage": _usage(i),
                "finish_reason": None,
            }
        if i >= max_new_tokens:
            break

    finish_reason = "length" if i >= max_new_tokens else "stop"

    yield {
        "text": output,
        "usage": _usage(i),
        "finish_reason": finish_reason,
    }

    # Wait for the worker before releasing cached memory.
    thread.join()

    # clean
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()