lhallee committed
Commit 4bfc769 · verified · 1 parent: ad3c068

Upload modeling_fastesm.py with huggingface_hub

Files changed (1): modeling_fastesm.py (+2, -2)
modeling_fastesm.py CHANGED
@@ -575,7 +575,7 @@ class ProteinDataset(TorchDataset):
 def build_collator(tokenizer) -> Callable[[list[str]], tuple[torch.Tensor, torch.Tensor]]:
     def _collate_fn(sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
         """Collate function for batching sequences."""
-        return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
+        return tokenizer(sequences, return_tensors="pt", padding='longest')
     return _collate_fn
 
 
@@ -690,7 +690,7 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
-            embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
+            embeddings = get_embeddings(residue_embeddings, attention_mask)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
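
Note on the collator hunk: dropping pad_to_multiple_of=8 means each batch is now padded only to its longest sequence, rather than having the padded length rounded up to a multiple of 8 (an alignment often used for fp16/bf16 tensor-core kernels). Below is a minimal usage sketch of a collator built this way, assuming a Hugging Face ESM-style tokenizer; the checkpoint name and toy sequences are placeholders, not taken from this repository.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Placeholder ESM checkpoint; swap in the tokenizer that ships with this repository.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

def build_collator(tokenizer):
    def _collate_fn(sequences):
        """Collate function for batching sequences."""
        # Post-commit behavior: pad to the longest sequence in the batch only.
        # Pre-commit, pad_to_multiple_of=8 rounded that length up to a multiple of 8.
        return tokenizer(sequences, return_tensors="pt", padding="longest")
    return _collate_fn

sequences = ["MKTAYIAKQR", "MSEQNNTEMTFQIQRIYTKDIS"]  # toy protein sequences
loader = DataLoader(sequences, batch_size=2, collate_fn=build_collator(tokenizer))
batch = next(iter(loader))
print(batch["input_ids"].shape)  # (2, longest sequence length + special tokens)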
 
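Note on the embedding hunk: removing the eager .cpu() call leaves the output of get_embeddings on the model's device entering the per-sequence loop; any transfer to CPU presumably happens later, in code not shown in this diff. get_embeddings itself is also not shown; the sketch below illustrates the mask-aware mean pooling such a helper commonly performs (the name mean_pool, the shapes, and the toy tensors are assumptions, not the repository's code).

import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average residue embeddings over non-padding positions.

    residue_embeddings: (batch, seq_len, hidden_size)
    attention_mask:     (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)                   # (batch, hidden_size)
    counts = mask.sum(dim=1).clamp(min=1)                             # avoid division by zero
    return summed / counts

# Toy check: pooling keeps one vector per sequence.
emb = torch.randn(2, 7, 16)
mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1]])
print(mean_pool(emb, mask).shape)  # torch.Size([2, 16])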