leemeng committed
Commit 311e2d3 · Parent(s): 4762d68

update version with fujiki-tuned model

Files changed (1):
  1. app.py +66 -23
app.py CHANGED
@@ -10,6 +10,7 @@ from dataclasses import dataclass
 
 import torch
 import sentencepiece as spm
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
 from transformers.generation.streamers import BaseStreamer
 from huggingface_hub import hf_hub_download, login
@@ -20,11 +21,12 @@ logger.setLevel("INFO")
 
 gr_interface = None
 
-VERSION = "0.0.0-a.1"
+VERSION = "0.1.0-a.1"
 
 @dataclass
 class DefaultArgs:
     hf_model_name_or_path: str = None
+    hf_tokenizer_name_or_path: str = None
     spm_model_path: str = None
     env: str = "dev"
     port: int = 7860
@@ -35,10 +37,13 @@ if os.getenv("RUNNING_ON_HF_SPACE"):
     hf_repo = os.getenv("HF_MODEL_REPO")
     args = DefaultArgs()
     args.hf_model_name_or_path = hf_repo
+    args.hf_tokenizer_name_or_path = os.path.join(hf_repo, "tokenizer")
     args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
+
 else:
     parser = argparse.ArgumentParser(description="")
     parser.add_argument("--hf_model_name_or_path", type=str, required=True)
+    parser.add_argument("--hf_tokenizer_name_or_path", type=str, required=False)
     parser.add_argument("--spm_model_path", type=str, required=True)
     parser.add_argument("--env", type=str, default="dev")
     parser.add_argument("--port", type=int, default=7860)
@@ -60,9 +65,15 @@ model = load_model(args.hf_model_name_or_path)
 sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
 logging.info("Finished loading model")
 
-class SentencePieceStreamer(BaseStreamer):
-    def __init__(self, sp: spm.SentencePieceProcessor):
-        self.sp = sp
+tokenizer = AutoTokenizer.from_pretrained(
+    args.hf_model_name_or_path,
+    subfolder="tokenizer",
+    use_fast=False
+)
+
+class TokenizerStreamer(BaseStreamer):
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
         self.num_invoked = 0
         self.prompt = ""
         self.generated_text = ""
@@ -80,7 +91,10 @@ class SentencePieceStreamer(BaseStreamer):
 
         t = [int(x) for x in t.numpy()]
 
-        text = self.sp.decode_ids(t)
+        text = tokenizer.decode(t)
+        if text in [tokenizer.bos_token, tokenizer.eos_token]:
+            text = ""
+
 
         if self.num_invoked == 0:
             self.prompt = text
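Editor's note: the two hunks above replace the SentencePiece-based streamer with one driven by the Hugging Face tokenizer. A minimal sketch of how the full class plausibly reads once the hunks are stitched together; the initialisation of the ended flag and the exact put() signature are assumptions, since the diff does not show them:

    class TokenizerStreamer(BaseStreamer):
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer
            self.num_invoked = 0       # how many times generate() has pushed tokens
            self.prompt = ""           # decoded prompt (first put() call)
            self.generated_text = ""   # accumulated completion text
            self.ended = False         # assumption: flipped by end() below

        def put(self, t):
            # generate() calls put() once with the prompt ids, then once per new token
            t = [int(x) for x in t.numpy()]
            text = self.tokenizer.decode(t)
            if text in [self.tokenizer.bos_token, self.tokenizer.eos_token]:
                text = ""  # hide special tokens from the visible stream
            if self.num_invoked == 0:
                self.prompt = text
            else:
                self.generated_text += text
            self.num_invoked += 1

        def end(self):
            self.ended = True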
@@ -93,6 +107,35 @@ class SentencePieceStreamer(BaseStreamer):
     def end(self):
         self.ended = True
 
+INPUT_PROMPT = """ไปฅไธ‹ใฏใ€ใ‚ฟใ‚นใ‚ฏใ‚’่ชฌๆ˜Žใ™ใ‚‹ๆŒ‡็คบใจใ€ๆ–‡่„ˆใฎใ‚ใ‚‹ๅ…ฅๅŠ›ใฎ็ต„ใฟๅˆใ‚ใ›ใงใ™ใ€‚่ฆๆฑ‚ใ‚’้ฉๅˆ‡ใซๆบ€ใŸใ™ๅฟœ็ญ”ใ‚’ๆ›ธใใชใ•ใ„ใ€‚
+
+### ๆŒ‡็คบ:
+{instruction}
+
+### ๅ…ฅๅŠ›:
+{input}
+
+### ๅฟœ็ญ”: """
+
+NO_INPUT_PROMPT = """ไปฅไธ‹ใฏใ€ใ‚ฟใ‚นใ‚ฏใ‚’่ชฌๆ˜Žใ™ใ‚‹ๆŒ‡็คบใจใ€ๆ–‡่„ˆใฎใ‚ใ‚‹ๅ…ฅๅŠ›ใฎ็ต„ใฟๅˆใ‚ใ›ใงใ™ใ€‚่ฆๆฑ‚ใ‚’้ฉๅˆ‡ใซๆบ€ใŸใ™ๅฟœ็ญ”ใ‚’ๆ›ธใใชใ•ใ„ใ€‚
+
+### ๆŒ‡็คบ:
+{instruction}
+
+### ๅฟœ็ญ”: """
+
+
+def postprocess_output(output):
+    output = output\
+        .split('### ๅฟœ็ญ”:')[1]\
+        .split('###')[0]\
+        .split('##')[0]\
+        .lstrip(tokenizer.bos_token)\
+        .rstrip(tokenizer.eos_token)\
+        .replace("###", "")\
+        .strip()
+    return output
+
 def generate(
     prompt,
     max_new_tokens,
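Editor's note: the new templates follow the familiar Alpaca-style instruction format, and generate() below uses NO_INPUT_PROMPT. A hypothetical round trip, assuming the module-level tokenizer is loaded; the model completion here is invented for illustration:

    raw = NO_INPUT_PROMPT.format(instruction="ๆ—ฅๆœฌใฎ้ฆ–้ƒฝใฏ๏ผŸ") + "ๆ—ฅๆœฌใฎ้ฆ–้ƒฝใฏๆฑไบฌใงใ™ใ€‚###"
    print(postprocess_output(raw))  # -> ๆ—ฅๆœฌใฎ้ฆ–้ƒฝใฏๆฑไบฌใงใ™ใ€‚

One caveat: str.lstrip()/str.rstrip() treat their argument as a set of characters rather than a literal prefix or suffix, so .rstrip(tokenizer.eos_token) with an EOS token such as "</s>" would also strip any trailing run of '<', '/', 's', or '>' characters.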
@@ -104,29 +147,29 @@ def generate(
 ):
     log = dict(locals())
     logging.debug(log)
-
-    tokens = sp.encode(prompt)
-    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
-
-    streamer = SentencePieceStreamer(sp=sp)
-
-    max_possilbe_new_tokens = model.config.max_position_embeddings - len(tokens)
+
+    input_text = NO_INPUT_PROMPT.format(instruction=prompt)
+    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
+
+    streamer = TokenizerStreamer(tokenizer=tokenizer)
+
+    max_possilbe_new_tokens = model.config.max_position_embeddings - input_ids.shape[0]
     max_possilbe_new_tokens = min(max_possilbe_new_tokens, max_new_tokens)
 
     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
-        input_ids=input_ids,
+        input_ids=input_ids.to(model.device),
         do_sample=do_sample,
         temperature=temperature,
         repetition_penalty=repetition_penalty,
         no_repeat_ngram_size=no_repeat_ngram_size,
         max_new_tokens=max_possilbe_new_tokens,
+        pad_token_id=tokenizer.pad_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        bad_words_ids=[[tokenizer.unk_token_id]],
         streamer=streamer,
-        # max_length=4096,
-        # top_k=100,
-        # top_p=0.9,
-        # num_return_sequences=2,
-        # num_beams=2,
     ))
+
     thr.start()
 
     while not streamer.ended:
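Editor's note: two details in this hunk deserve attention. The variable name is misspelled (max_possilbe_new_tokens) in both the old and new code, and, more substantively, input_ids from tokenizer.encode(..., return_tensors="pt") has shape (1, seq_len), so input_ids.shape[0] is the batch dimension (always 1 here) rather than the prompt length. A corrected token-budget computation would plausibly read:

    # shape[1] is the number of prompt tokens; shape[0] is the batch dimension
    max_possible_new_tokens = model.config.max_position_embeddings - input_ids.shape[1]
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)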
@@ -165,10 +208,10 @@ if gr_interface:
 
 with gr.Blocks() as gr_interface:
     with gr.Row():
-        gr.Markdown(f"# ๆ—ฅๆœฌ่ชž StableLM Pre-Alpha ({VERSION})")
+        gr.Markdown(f"# ๆ—ฅๆœฌ่ชž StableLM Tuned Pre-Alpha ({VERSION})")
         # gr.Markdown(f"ใƒใƒผใ‚ธใƒงใƒณ๏ผš{VERSION}")
     with gr.Row():
-        gr.Markdown("ใ“ใฎ่จ€่ชžใƒขใƒ‡ใƒซใฏ Stability AI Japan ใŒ้–‹็™บใ—ใŸๅˆๆœŸใƒใƒผใ‚ธใƒงใƒณใฎๆ—ฅๆœฌ่ชžใƒขใƒ‡ใƒซใงใ™ใ€‚ใ“ใฎใƒขใƒ‡ใƒซใฏ่ณชๅ•ๅฟœ็ญ”ใซ็‰นๅŒ–ใ—ใฆใ„ใพใ›ใ‚“ใ€‚ใใฎใŸใ‚ใ€ๆœŸๅพ…ใ™ใ‚‹ๅ›ž็ญ”ใŒใƒ—ใƒญใƒณใƒ—ใƒˆใฎ่‡ช็„ถใช็ถšใใจใชใ‚‹ใ‚ˆใ†ใซใƒ—ใƒญใƒณใƒ—ใƒˆใ‚’่จญๅฎšใ™ใ‚‹ๅฟ…่ฆใŒใ‚ใ‚Šใพใ™ใ€‚ไพ‹ใ‚’ๆŒ™ใ’ใ‚‹ใจใ€Œไบบ็”Ÿใฎๆ„ๅ‘ณใฏไฝ•ใงใ™ใ‹๏ผŸใ€ใงใฏใชใใ€ใ€Œ็งใŒๆ€ใฃใŸไบบ็”Ÿใฎๆ„ๅ‘ณใฏใ€ใฎใ‚ˆใ†ใซใƒ—ใƒญใƒณใƒ—ใƒˆใ‚’่จญๅฎšใ—ใฆใใ ใ•ใ„ใ€‚")
+        gr.Markdown("ใ“ใฎ่จ€่ชžใƒขใƒ‡ใƒซใฏ Stability AI Japan ใŒ้–‹็™บใ—ใŸๅˆๆœŸใƒใƒผใ‚ธใƒงใƒณใฎๆ—ฅๆœฌ่ชžใƒขใƒ‡ใƒซใงใ™ใ€‚ใƒขใƒ‡ใƒซใฏใ€Œใƒ—ใƒญใƒณใƒ—ใƒˆใ€ใซๅ…ฅๅŠ›ใ—ใŸ่žใใŸใ„ใ“ใจใซๅฏพใ—ใฆใ€ใใ‚Œใ‚‰ใ—ใ„ๅฟœ็ญ”ใ‚’ใ™ใ‚‹ใ“ใจใŒใงใใพใ™ใ€‚")
     with gr.Row():
 
     # left panel
@@ -180,7 +223,7 @@ with gr.Blocks() as gr_interface:
 
         # hidden default params
         do_sample = gr.Checkbox(True, label="Do Sample", info="ใ‚ตใƒณใƒ—ใƒชใƒณใ‚ฐ็”Ÿๆˆ", visible=True)
-        no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)
+        no_repeat_ngram_size = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size", visible=False)
 
         # visible params
         max_new_tokens = gr.Slider(
@@ -192,7 +235,7 @@ with gr.Blocks() as gr_interface:
             info="็”Ÿๆˆใ™ใ‚‹ใƒˆใƒผใ‚ฏใƒณใฎๆœ€ๅคงๆ•ฐใ‚’ๆŒ‡ๅฎšใ™ใ‚‹",
         )
         temperature = gr.Slider(
-            0, 1, value=0.7, step=0.05, label="temperature",
+            0, 1, value=0.1, step=0.05, label="temperature",
             info="ไฝŽใ„ๅ€คใฏๅ‡บๅŠ›ใ‚’ใ‚ˆใ‚Š้›†ไธญใ•ใ›ใฆๆฑบๅฎš่ซ–็š„ใซใ™ใ‚‹")
         repetition_penalty = gr.Slider(
             1, 1.5, value=1.2, step=0.05, label="frequency penalty",
@@ -214,7 +257,7 @@ with gr.Blocks() as gr_interface:
         with gr.Box():
             textbox_prompt = gr.Textbox(
                 label="ใƒ—ใƒญใƒณใƒ—ใƒˆ",
-                placeholder="็งใŒๆ€ใฃใŸไบบ็”Ÿใฎๆ„ๅ‘ณใฏ",
+                placeholder="ๆ—ฅๆœฌใฎ้ฆ–้ƒฝใฏ๏ผŸ",
                 interactive=True,
                 lines=5,
                 value=""