Update app.py
Browse files
app.py
CHANGED
@@ -28,7 +28,8 @@ if HUGGINGFACE_HUB_TOKEN:
|
|
28 |
login(token=HUGGINGFACE_HUB_TOKEN)
|
29 |
|
30 |
os.system("git config --global credential.helper store")
|
31 |
-
|
|
|
32 |
|
33 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
34 |
logger = logging.getLogger(__name__)
|
@@ -169,19 +170,15 @@ model_loader = GCSModelLoader(bucket)
|
|
169 |
|
170 |
async def generate_stream(model, tokenizer, input_text, generation_config):
|
171 |
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
token = tokenizer.decode(token_id, skip_special_tokens=True)
|
182 |
-
yield {"token": token}
|
183 |
-
|
184 |
-
return token_stream()
|
185 |
|
186 |
def generate_non_stream(model, tokenizer, input_text, generation_config):
|
187 |
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
@@ -216,8 +213,8 @@ async def generate(request: GenerateRequest):
|
|
216 |
generation_config_kwargs = generation_params.copy()
|
217 |
generation_config_kwargs['pad_token_id'] = tokenizer.pad_token_id
|
218 |
generation_config_kwargs['eos_token_id'] = tokenizer.eos_token_id
|
219 |
-
generation_config_kwargs['sep_token_id'] = tokenizer.sep_token_id
|
220 |
-
generation_config_kwargs['unk_token_id'] = tokenizer.unk_token_id
|
221 |
|
222 |
model = model_loader.load_model(model_name, config)
|
223 |
if not model:
|
|
|
28 |
login(token=HUGGINGFACE_HUB_TOKEN)
|
29 |
|
30 |
os.system("git config --global credential.helper store")
|
31 |
+
if HUGGINGFACE_HUB_TOKEN:
|
32 |
+
huggingface_hub.login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)
|
33 |
|
34 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
35 |
logger = logging.getLogger(__name__)
|
|
|
170 |
|
171 |
async def generate_stream(model, tokenizer, input_text, generation_config):
|
172 |
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
173 |
+
generation_stream = model.generate(
|
174 |
+
**inputs,
|
175 |
+
generation_config=generation_config,
|
176 |
+
stream=True,
|
177 |
+
)
|
178 |
+
async for output in generation_stream:
|
179 |
+
token_id = output[-1]
|
180 |
+
token = tokenizer.decode(token_id, skip_special_tokens=True)
|
181 |
+
yield {"token": token}
|
|
|
|
|
|
|
|
|
182 |
|
183 |
def generate_non_stream(model, tokenizer, input_text, generation_config):
|
184 |
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
|
|
213 |
generation_config_kwargs = generation_params.copy()
|
214 |
generation_config_kwargs['pad_token_id'] = tokenizer.pad_token_id
|
215 |
generation_config_kwargs['eos_token_id'] = tokenizer.eos_token_id
|
216 |
+
generation_config_kwargs['sep_token_id'] = tokenizer.sep_token_id if tokenizer.sep_token_id else tokenizer.eos_token_id
|
217 |
+
generation_config_kwargs['unk_token_id'] = tokenizer.unk_token_id if tokenizer.unk_token_id else tokenizer.eos_token_id
|
218 |
|
219 |
model = model_loader.load_model(model_name, config)
|
220 |
if not model:
|