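"""Gradio demo for the 日本語 StableLM Pre-Alpha model.

Streams text generated by a GPT-NeoX causal LM (tokenized with SentencePiece)
and logs user feedback. Runs either on a Hugging Face Space (configured via
environment variables) or locally via command-line arguments.
"""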
import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None
# Defaults used when running on a Hugging Face Space (configured via env vars).
@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False
if os.getenv("RUNNING_ON_HF_SPACE"):
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")
    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action='store_true')
    args = parser.parse_args()
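# Example local invocation (script name and paths are illustrative, not from the repo):
#   python app.py --hf_model_name_or_path <model-dir-or-hub-id> \
#       --spm_model_path ./sentencepiece.model --port 7860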
def load_model(
    model_dir,
):
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")
class SentencePieceStreamer(BaseStreamer):
    def __init__(self, sp: spm.SentencePieceProcessor):
        self.sp = sp
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        # Make sure the ids are on the CPU before converting to Python ints.
        t = [int(x) for x in t.cpu().numpy()]
        text = self.sp.decode_ids(t)

        # The first call carries the prompt ids; remember them but do not
        # echo the prompt back into the generated text.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True
def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    tokens = sp.encode(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)

    streamer = SentencePieceStreamer(sp=sp)

    # Cap the request so that prompt + generation fits in the context window.
    max_possible_new_tokens = model.config.max_position_embeddings - len(tokens)
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    # Run generation in a background thread and stream partial results.
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        streamer=streamer,
        # max_length=4096,
        # top_k=100,
        # top_p=0.9,
        # num_return_sequences=2,
        # num_beams=2,
    ))
    thr.start()

    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(generation=gen, time=str(arrow.now("+09:00"))))
    logging.info(log)
    yield gen
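# Quick sanity check outside the UI (illustrative only: the prompt is a made-up
# example and the arguments mirror the Gradio defaults defined below):
#   for partial in generate("吾輩は猫である。", 128, 0.7, 1.2, True, 5):
#       print(partial)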
def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    log["time"] = str(arrow.now("+09:00"))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)
with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown("# 日本語 StableLM Pre-Alpha")  # "Japanese StableLM Pre-Alpha"
    with gr.Row():
        gr.Markdown("Description about this page. ホゲホゲ")  # placeholder description text
    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("パラメータ")  # "Parameters"
                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
                no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)
                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                    info="生成するトークンの最大数を指定する",  # "Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.05, label="temperature",
                    info="低い値は出力をより集中させて決定論的にする")  # "Lower values make the output more focused and deterministic"
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05, label="frequency penalty",
                    info="高い値はAIが繰り返す可能性を減少させる")  # "Higher values make the AI less likely to repeat itself"

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="Human",
                    placeholder="AIに続きを書いて欲しいプロンプト",  # "Prompt for the AI to continue"
                    interactive=True,
                    lines=5,
                    value="",
                )
            with gr.Box():
                with gr.Row():
                    btn_submit = gr.Button(value="実行", variant="primary")  # "Run"
                    btn_stop = gr.Button(value="中止", variant="stop")  # "Stop"
            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="AI",
                    lines=5,
                    value="",
                )
            with gr.Box():
                with gr.Row():
                    rating_options = [
                        "😫すごく悪い",  # "Very bad"
                        "😞微妙",  # "Poor"
                        "😐アリ",  # "Acceptable"
                        "🙂合格",  # "Good"
                        "😄すごく良い",  # "Very good"
                    ]
                    btn_ratings = [gr.Button(value=v) for v in rating_options]
            with gr.Box():
                gr.Markdown("TODO: Add a Google Form link for more detailed feedback")

    # event handling
    inputs = [textbox_prompt] + gr_params

    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)

    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)