import argparse
import logging
import os
import threading
import time
from dataclasses import dataclass

import arrow
import gradio as gr
import sentencepiece as spm
import torch
from huggingface_hub import hf_hub_download, login
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.generation.streamers import BaseStreamer

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None


@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False


if os.getenv("RUNNING_ON_HF_SPACE"):
    # Running inside a Hugging Face Space: fetch the model repo and tokenizer from the Hub.
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")
    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    # Running locally: take model and tokenizer paths from the command line.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action="store_true")
    args = parser.parse_args()


def load_model(model_dir):
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)
    if torch.cuda.is_available():
        model = model.to("cuda:0")
    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")


class SentencePieceStreamer(BaseStreamer):
    """Streamer that decodes generated token ids with SentencePiece as they arrive."""

    def __init__(self, sp: spm.SentencePieceProcessor):
        self.sp = sp
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        t = [int(x) for x in t.cpu().numpy()]
        text = self.sp.decode_ids(t)

        if self.num_invoked == 0:
            # The first call receives the prompt tokens; keep them out of the generated text.
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True


def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    tokens = sp.encode(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)

    streamer = SentencePieceStreamer(sp=sp)

    # Cap the request so prompt + generation stays within the model's context window.
    max_possible_new_tokens = model.config.max_position_embeddings - len(tokens)
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    # Run generation in a background thread and poll the streamer for partial output.
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        streamer=streamer,
        # max_length=4096,
        # top_k=100,
        # top_p=0.9,
        # num_return_sequences=2,
        # num_beams=2,
    ))
    thr.start()

    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(generation=gen, time=str(arrow.now("+09:00"))))
    logging.info(log)
    yield gen


def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    log["time"] = str(arrow.now("+09:00"))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)

with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown("# 日本語 StableLM Pre-Alpha")
    with gr.Row():
        gr.Markdown("Description about this page. ホゲホゲ")

    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("パラメータ")

                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
                no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)

                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                    info="生成するトークンの最大数を指定する",
                )
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.05,
                    label="temperature",
                    info="低い値は出力をより集中させて決定論的にする")
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05,
                    label="frequency penalty",
                    info="高い値はAIが繰り返す可能性を減少させる")

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="Human",
                    placeholder="AIに続きを書いて欲しいプロンプト",
                    interactive=True,
                    lines=5,
                    value="",
                )

            with gr.Box():
                with gr.Row():
                    btn_submit = gr.Button(value="実行", variant="primary")
                    btn_stop = gr.Button(value="中止", variant="stop")

            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="AI",
                    lines=5,
                    value="",
                )

            with gr.Box():
                with gr.Row():
                    rating_options = [
                        "😫すごく悪い",
                        "😞微妙",
                        "😐アリ",
                        "🙂合格",
                        "😄すごく良い",
                    ]
                    btn_ratings = [gr.Button(value=v) for v in rating_options]

            with gr.Box():
                gr.Markdown("TODO:For more feedback link for google form")

    # event handling
    inputs = [textbox_prompt] + gr_params

    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)

    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)
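
# Example local launch outside the Hugging Face Space (a sketch only; "app.py" and the
# paths below are illustrative placeholders, not the repository's actual file names):
#
#   python app.py \
#       --hf_model_name_or_path /path/to/gpt-neox-checkpoint \
#       --spm_model_path /path/to/sentencepiece.model \
#       --port 7860 \
#       --make_public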