Spaces:
Runtime error
Runtime error
import re | |
import copy | |
import global_vars | |
from threading import Thread | |
from transformers import TextIteratorStreamer | |
from transformers import GenerationConfig | |
def contains_image_markdown(string):
    """Find the first Markdown image reference (``![alt](target)``) in *string*.

    Returns:
        The first ``re.Match`` for the pattern, or ``None`` when *string*
        contains no image reference.  Group 1 is the text between the
        brackets and group 2 is the text between the parentheses; both are
        matched non-greedily, so group 2 ends at the first ``)``.
    """
    return re.search(r'!\[(.*?)\]\((.*?)\)', string)
def build_model_inputs(prompt, return_token_type_ids):
    """Tokenize *prompt* as a batch of one and move the result to the device.

    Args:
        prompt: the text to encode; it is wrapped in a single-element list
            before being handed to ``global_vars.tokenizer``.
        return_token_type_ids: forwarded unchanged to the tokenizer.

    Returns:
        The tokenizer output after ``.to(global_vars.device)`` (presumably a
        transformers ``BatchEncoding`` of PyTorch tensors, since
        ``return_tensors="pt"`` is requested — confirm against the tokenizer).
    """
    encoded = global_vars.tokenizer(
        [prompt],
        return_tensors="pt",
        return_token_type_ids=return_token_type_ids,
    )
    on_device = encoded.to(global_vars.device)
    return on_device
def build_streamer(
    timeout=20.,
    skip_prompt=True,
    skip_special_tokens=True
):
    """Create a ``TextIteratorStreamer`` bound to ``global_vars.tokenizer``.

    Args:
        timeout: value passed to the streamer's ``timeout``.  When
            ``global_vars.device`` equals ``"cpu"`` or ``"mps"`` it is
            replaced by ``100000.`` even if the caller supplied a value
            (presumably because generation there is slow — confirm).
        skip_prompt: forwarded to the streamer unchanged.
        skip_special_tokens: forwarded to the streamer unchanged.

    Returns:
        A newly constructed ``TextIteratorStreamer``.

    Side effect: prints the effective timeout on every call.
    """
    if global_vars.device in ("cpu", "mps"):
        timeout = 100000.
    print(f"timeout set to {timeout}")
    return TextIteratorStreamer(
        global_vars.tokenizer,
        timeout=timeout,
        skip_prompt=skip_prompt,
        skip_special_tokens=skip_special_tokens,
    )
def build_gen_config(
    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
):
    """Collect decoding settings into a plain dict and a ``GenerationConfig``.

    Every argument is stored under a key of the same name, in signature
    order, without any conversion or validation in this function.

    Returns:
        A ``(raw, config)`` tuple where ``raw`` is the settings dict and
        ``config`` is ``GenerationConfig(**raw)``.
    """
    raw = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        use_cache=use_cache,
        do_sample=do_sample,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )
    config = GenerationConfig(**raw)
    return raw, config
def build_gen_kwargs(
    gen_config,
    model_inputs,
    streamer,
    stopping_criteria
):
    """Merge inputs, streamer and generation settings into one kwargs dict.

    Keys are layered in this order, later layers winning on collisions:
    the entries of *model_inputs*, then ``streamer`` and
    ``stopping_criteria``, then the entries of *gen_config*.

    Returns:
        A new plain ``dict``; neither *model_inputs* nor *gen_config* is
        modified.
    """
    merged = dict(model_inputs)
    merged["streamer"] = streamer
    merged["stopping_criteria"] = stopping_criteria
    merged.update(gen_config)
    return merged
def start_gen(gen_kwargs):
    """Call ``global_vars.stream_model.generate(**gen_kwargs)`` on a new thread.

    The call runs in the background, presumably so the caller can iterate
    the streamer contained in *gen_kwargs* while tokens are produced.

    Args:
        gen_kwargs: keyword arguments forwarded to ``generate``.

    Returns:
        The already-started ``threading.Thread``.  Callers may ``join()`` it
        or poll ``is_alive()``; callers that ignore the return value are
        unaffected.
    """
    worker = Thread(
        target=global_vars.stream_model.generate,
        kwargs=gen_kwargs,
    )
    worker.start()
    # Hand the thread back instead of discarding it, so callers can wait
    # for generation to finish or detect that it has stopped.
    return worker
def build(
    prompt,
    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
    num_beams, use_cache, do_sample, eos_token_id, pad_token_id,
    stopping_criteria=None, return_token_type_ids=True
):
    """Prepare the kwargs and streamer for a streamed generation of *prompt*.

    Steps, in order: build the decoding settings, tokenize the prompt,
    create a fresh streamer (default arguments), then merge all of them into
    one kwargs dict.  Nothing is launched here; the result is presumably
    handed to ``start_gen`` by the caller.

    Args:
        prompt: text to tokenize.
        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
        num_beams, use_cache, do_sample, eos_token_id, pad_token_id:
            decoding settings forwarded to ``build_gen_config``.
        stopping_criteria: placed under ``"stopping_criteria"`` in the kwargs.
        return_token_type_ids: forwarded to ``build_model_inputs``.

    Returns:
        A ``(gen_kwargs, streamer)`` tuple.
    """
    # Only the raw settings dict is merged into the kwargs; the
    # GenerationConfig object is built and then dropped.  NOTE(review):
    # its construction may still validate the values — confirm before
    # removing it.
    raw_settings, _ = build_gen_config(
        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
        num_beams, use_cache, do_sample, eos_token_id, pad_token_id,
    )
    inputs = build_model_inputs(
        prompt, return_token_type_ids=return_token_type_ids,
    )
    text_streamer = build_streamer()
    kwargs = build_gen_kwargs(
        raw_settings, inputs, text_streamer, stopping_criteria,
    )
    return kwargs, text_streamer