Spaces:
Runtime error
Runtime error
fix batching
Browse files
utils.py
CHANGED
@@ -312,14 +312,14 @@ def batch_embed(
|
|
312 |
pin_memory=True,
|
313 |
drop_last=False,
|
314 |
):
|
315 |
-
ids = torch.tensor([
|
316 |
-
mask = torch.tensor([
|
317 |
t_ids = torch.zeros_like(ids)
|
318 |
|
319 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
320 |
|
321 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
322 |
-
texts.extend([
|
323 |
|
324 |
current_count += ids.shape[0]
|
325 |
|
|
|
312 |
pin_memory=True,
|
313 |
drop_last=False,
|
314 |
):
|
315 |
+
ids = torch.tensor(batch["input_ids"], device=device)
|
316 |
+
mask = torch.tensor(batch["attention_mask"], device=device)
|
317 |
t_ids = torch.zeros_like(ids)
|
318 |
|
319 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
320 |
|
321 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
322 |
+
texts.extend(batch[column_name])
|
323 |
|
324 |
current_count += ids.shape[0]
|
325 |
|