import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import AutoTokenizer
from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None
VERSION = "0.1.0"

@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    hf_tokenizer_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False
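
# When running on a Hugging Face Space, configuration comes from environment
# variables; when run locally, it comes from the CLI flags parsed below.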
if os.getenv("RUNNING_ON_HF_SPACE"):
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")
    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.hf_tokenizer_name_or_path = os.path.join(hf_repo, "tokenizer")
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--hf_tokenizer_name_or_path", type=str, required=False)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action="store_true")
    args = parser.parse_args()


def load_model(model_dir):
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    model = GPTNeoXForCausalLM.from_pretrained(
        model_dir, config=config, torch_dtype=torch.bfloat16
    )
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model
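
# Note: weights are loaded in bfloat16 to roughly halve memory versus float32.
# On a CPU-only machine bfloat16 inference may be slow or unsupported depending
# on the PyTorch build, so a float32 fallback there could be worth considering.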

logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")

tokenizer = AutoTokenizer.from_pretrained(
    args.hf_model_name_or_path,
    subfolder="tokenizer",
    use_fast=False,
)
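
# `sp` is not referenced again below; presumably the slow (use_fast=False)
# tokenizer wraps the same SentencePiece model, so this is just a direct handle.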


class TokenizerStreamer(BaseStreamer):
    """Collects tokens from model.generate() and exposes the decoded text."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError
        # Move to CPU before converting; .numpy() fails on CUDA tensors.
        t = [int(x) for x in t.cpu().numpy()]
        text = self.tokenizer.decode(t)
        if text in [self.tokenizer.bos_token, self.tokenizer.eos_token]:
            text = ""
        # The first call receives the prompt ids; record them but do not
        # include them in the streamed output.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return
        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True
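
# BaseStreamer contract: model.generate() first calls put() with the full
# prompt ids, then calls put() for each newly generated token, and finally
# calls end(). The Gradio handler polls `generated_text` / `ended` from the
# main thread while generation runs in the background.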

# Alpaca-style Japanese prompt templates. The header line reads, roughly:
# "Below is a combination of an instruction describing a task and contextual
# input. Write a response that appropriately satisfies the request."
# INPUT_PROMPT (instruction + input) is defined but unused below; the demo
# only uses NO_INPUT_PROMPT.
INPUT_PROMPT = """ไปฅไธใฏใใฟในใฏใ่ชฌๆใใๆ็คบใจใๆ่ใฎใใๅ ฅๅใฎ็ตใฟๅใใใงใใ่ฆๆฑใ้ฉๅใซๆบใใๅฟ็ญใๆธใใชใใใ
### ๆ็คบ:
{instruction}
### ๅ ฅๅ:
{input}
### ๅฟ็ญ: """

NO_INPUT_PROMPT = """ไปฅไธใฏใใฟในใฏใ่ชฌๆใใๆ็คบใจใๆ่ใฎใใๅ ฅๅใฎ็ตใฟๅใใใงใใ่ฆๆฑใ้ฉๅใซๆบใใๅฟ็ญใๆธใใชใใใ
### ๆ็คบ:
{instruction}
### ๅฟ็ญ: """
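
# For example, with the placeholder prompt from the UI ("ๆฅๆฌใฎ้ฆ้ฝใฏ๏ผ",
# "What is the capital of Japan?"), NO_INPUT_PROMPT.format(instruction=...)
# yields the template ending in "### ๅฟ็ญ: ", which the model then completes.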


def postprocess_output(output):
    # Keep only the text between "### ๅฟ็ญ:" (response) and any following
    # section marker, then strip special tokens. str.lstrip/rstrip treat their
    # argument as a character set, so replace() is used for the special tokens.
    output = (
        output
        .split("### ๅฟ็ญ:")[1]
        .split("###")[0]
        .split("##")[0]
        .replace(tokenizer.bos_token, "")
        .replace(tokenizer.eos_token, "")
        .replace("###", "")
        .strip()
    )
    return output
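
# Hypothetical example of the helper above:
#   postprocess_output("...### ๅฟ็ญ: ๆฑไบฌใงใใ ### ๆ็คบ: ...") -> "ๆฑไบฌใงใใ"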


def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    input_text = NO_INPUT_PROMPT.format(instruction=prompt)
    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
    streamer = TokenizerStreamer(tokenizer=tokenizer)

    # input_ids has shape (1, seq_len), so the sequence length is shape[1];
    # cap new tokens so prompt + generation stays within the context window.
    max_possible_new_tokens = model.config.max_position_embeddings - input_ids.shape[1]
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)
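
    # Run generation on a background thread so this handler can poll the
    # streamer and yield partial text to Gradio while tokens are produced.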
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids.to(model.device),
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bad_words_ids=[[tokenizer.unk_token_id]],
        streamer=streamer,
    ))
    thr.start()

    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text
    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(
        generation=gen,
        version=VERSION,
        time=str(arrow.now("+09:00")),
    ))
    logging.info(log)
    yield gen
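
# Because `generate` is a generator, Gradio (with queueing enabled) streams
# each yielded string into the output textbox, so users see text as it arrives.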


def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    # Feedback is only logged, together with a timestamp and the app version.
    log = dict(locals())
    log.update(dict(
        time=str(arrow.now("+09:00")),
        version=VERSION,
    ))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)

with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown(f"# ๆฅๆฌ่ช StableLM Tuned Pre-Alpha ({VERSION})")
        # gr.Markdown(f"ใใผใธใงใณ๏ผ{VERSION}")
    with gr.Row():
        # "This language model is an early-version Japanese model developed by
        # Stability AI Japan. It can respond to whatever you type into the prompt."
        gr.Markdown("ใใฎ่จ่ชใขใใซใฏ Stability AI Japan ใ้็บใใๅๆใใผใธใงใณใฎๆฅๆฌ่ชใขใใซใงใใใขใใซใฏใใใญใณใใใใซๅ ฅๅใใ่ใใใใใจใซๅฏพใใฆใใใใใใๅฟ็ญใใใใใจใใงใใพใใ")

    with gr.Row():
        # left panel: generation parameters
        with gr.Column(scale=1):
            with gr.Box():
                gr.Markdown("ใใฉใกใผใฟ")  # "Parameters"
                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", info="ใตใณใใชใณใฐ็ๆ", visible=True)
                no_repeat_ngram_size = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size", visible=False)
                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                    info="็ๆใใใใผใฏใณใฎๆๅคงๆฐใๆๅฎใใ",  # "Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    0, 1, value=0.1, step=0.05, label="temperature",
                    info="ไฝใๅคใฏๅบๅใใใ้ไธญใใใฆๆฑบๅฎ่ซ็ใซใใ",  # "Lower values make output more focused and deterministic"
                )
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05, label="frequency penalty",
                    info="้ซใๅคใฏAIใ็นฐใ่ฟใๅฏ่ฝๆงใๆธๅฐใใใ",  # "Higher values reduce the chance of repetition"
                )

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel: prompt input, output, and rating
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="ใใญใณใใ",  # "Prompt"
                    placeholder="ๆฅๆฌใฎ้ฆ้ฝใฏ๏ผ",  # "What is the capital of Japan?"
                    interactive=True,
                    lines=5,
                    value="",
                )
            with gr.Box():
                with gr.Row():
                    btn_stop = gr.Button(value="ใญใฃใณใปใซ", variant="secondary")  # "Cancel"
                    btn_submit = gr.Button(value="ๅฎ่ก", variant="primary")  # "Run"
            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="็ๆ็ตๆ",  # "Generation result"
                    lines=5,
                    value="",
                )
            # rating block
            with gr.Row():
                # "To help us provide a better language model, please share
                # your feedback on the generation quality."
                gr.Markdown("ใใ่ฏใ่จ่ชใขใใซใ็ๆงใซๆไพใงใใใใใ็ๆๅ่ณชใซใคใใฆใฎใๆ่ฆใใ่ใใใใ ใใใ")
            with gr.Box():
                with gr.Row():
                    rating_options = [
                        "ๆๆช",    # "Terrible"
                        "ไธๅๆ ผ",  # "Poor"
                        "ไธญ็ซ",    # "Neutral"
                        "ๅๆ ผ",    # "Good"
                        "ๆ้ซ",    # "Excellent"
                    ]
                    btn_ratings = [gr.Button(value=v) for v in rating_options]
            # TODO: we might not need this for sharing with close groups
            # with gr.Box():
            #     gr.Markdown("TODO: link to a Google Form for more detailed feedback")

    # event handling
    inputs = [textbox_prompt] + gr_params
    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    # Cancelling the click event stops the streaming generate handler.
    btn_stop.click(None, None, None, cancels=click_event, queue=False)
    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)
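
# To run locally (paths are illustrative, assuming this file is saved as
# app.py per the Hugging Face Spaces convention):
#   python app.py \
#     --hf_model_name_or_path /path/to/model \
#     --spm_model_path /path/to/sentencepiece.model \
#     --port 7860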