Hjgugugjhuhjggg committed on
Commit
cdfd15f
·
verified ·
1 Parent(s): c7434cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -28,7 +28,8 @@ if HUGGINGFACE_HUB_TOKEN:
28
  login(token=HUGGINGFACE_HUB_TOKEN)
29
 
30
  os.system("git config --global credential.helper store")
31
- huggingface_hub.login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)
 
32
 
33
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
34
  logger = logging.getLogger(__name__)
@@ -169,19 +170,15 @@ model_loader = GCSModelLoader(bucket)
169
 
170
  async def generate_stream(model, tokenizer, input_text, generation_config):
171
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
172
-
173
- async def token_stream():
174
- generation_stream = model.generate(
175
- **inputs,
176
- generation_config=generation_config,
177
- stream=True,
178
- )
179
- async for output in generation_stream:
180
- token_id = output[-1]
181
- token = tokenizer.decode(token_id, skip_special_tokens=True)
182
- yield {"token": token}
183
-
184
- return token_stream()
185
 
186
  def generate_non_stream(model, tokenizer, input_text, generation_config):
187
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
@@ -216,8 +213,8 @@ async def generate(request: GenerateRequest):
216
  generation_config_kwargs = generation_params.copy()
217
  generation_config_kwargs['pad_token_id'] = tokenizer.pad_token_id
218
  generation_config_kwargs['eos_token_id'] = tokenizer.eos_token_id
219
- generation_config_kwargs['sep_token_id'] = tokenizer.sep_token_id
220
- generation_config_kwargs['unk_token_id'] = tokenizer.unk_token_id
221
 
222
  model = model_loader.load_model(model_name, config)
223
  if not model:
 
28
  login(token=HUGGINGFACE_HUB_TOKEN)
29
 
30
  os.system("git config --global credential.helper store")
31
+ if HUGGINGFACE_HUB_TOKEN:
32
+ huggingface_hub.login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)
33
 
34
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
35
  logger = logging.getLogger(__name__)
 
170
 
171
  async def generate_stream(model, tokenizer, input_text, generation_config):
172
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
173
+ generation_stream = model.generate(
174
+ **inputs,
175
+ generation_config=generation_config,
176
+ stream=True,
177
+ )
178
+ async for output in generation_stream:
179
+ token_id = output[-1]
180
+ token = tokenizer.decode(token_id, skip_special_tokens=True)
181
+ yield {"token": token}
 
 
 
 
182
 
183
  def generate_non_stream(model, tokenizer, input_text, generation_config):
184
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
 
213
  generation_config_kwargs = generation_params.copy()
214
  generation_config_kwargs['pad_token_id'] = tokenizer.pad_token_id
215
  generation_config_kwargs['eos_token_id'] = tokenizer.eos_token_id
216
+ generation_config_kwargs['sep_token_id'] = tokenizer.sep_token_id if tokenizer.sep_token_id else tokenizer.eos_token_id
217
+ generation_config_kwargs['unk_token_id'] = tokenizer.unk_token_id if tokenizer.unk_token_id else tokenizer.eos_token_id
218
 
219
  model = model_loader.load_model(model_name, config)
220
  if not model: