nbroad committed on
Commit
af7a07a
·
1 Parent(s): f827190

fix batching

Browse files
Files changed (1) hide show
  1. utils.py +3 -3
utils.py CHANGED
@@ -312,14 +312,14 @@ def batch_embed(
312
  pin_memory=True,
313
  drop_last=False,
314
  ):
315
- ids = torch.tensor([b["input_ids"] for b in batch], device=device)
316
- mask = torch.tensor([b["attention_mask"] for b in batch], device=device)
317
  t_ids = torch.zeros_like(ids)
318
 
319
  outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
320
 
321
  embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
322
- texts.extend([b[column_name] for b in batch])
323
 
324
  current_count += ids.shape[0]
325
 
 
312
  pin_memory=True,
313
  drop_last=False,
314
  ):
315
+ ids = torch.tensor(batch["input_ids"], device=device)
316
+ mask = torch.tensor(batch["attention_mask"], device=device)
317
  t_ids = torch.zeros_like(ids)
318
 
319
  outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
320
 
321
  embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
322
+ texts.extend(batch[column_name])
323
 
324
  current_count += ids.shape[0]
325