Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
•
5dd37ab
1
Parent(s):
bb74d9f
Ensure `input_ids` are on the correct device when predicting
Browse files- src/predict.py +1 -1
src/predict.py
CHANGED
@@ -171,7 +171,7 @@ DEFAULT_TOKEN_PREFIX = 'summarize: '
|
|
171 |
def predict_sponsor_text(text, model, tokenizer):
|
172 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
173 |
input_ids = tokenizer(
|
174 |
-
f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids
|
175 |
|
176 |
# Can't be longer than input length + SAFETY_TOKENS or model input dim
|
177 |
max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
|
|
|
171 |
def predict_sponsor_text(text, model, tokenizer):
|
172 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
173 |
input_ids = tokenizer(
|
174 |
+
f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids.to(device())
|
175 |
|
176 |
# Can't be longer than input length + SAFETY_TOKENS or model input dim
|
177 |
max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
|