Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py  CHANGED  +2 -2
@@ -575,7 +575,7 @@ class ProteinDataset(TorchDataset):
 def build_collator(tokenizer) -> Callable[[list[str]], tuple[torch.Tensor, torch.Tensor]]:
     def _collate_fn(sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
         """Collate function for batching sequences."""
-        return tokenizer(sequences, return_tensors="pt", padding='longest'
+        return tokenizer(sequences, return_tensors="pt", padding='longest')
     return _collate_fn
 
 
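For context on the first hunk: the edit closes a parenthesis that was missing from the tokenizer(...) call, which previously made the file fail to import with a SyntaxError. Below is a minimal usage sketch, assuming build_collator is imported from the uploaded modeling_fastesm.py and using an illustrative ESM2 checkpoint name; neither is confirmed by this commit.

# Minimal sketch (assumptions: build_collator comes from modeling_fastesm.py,
# and the checkpoint name is only an illustrative choice).
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
collate_fn = build_collator(tokenizer)

sequences = ["MKTAYIAKQR", "MSILVTRPSPA"]
loader = DataLoader(sequences, batch_size=2, collate_fn=collate_fn)
for batch in loader:
    # The collator pads to the longest sequence in the batch and returns
    # the tokenizer's BatchEncoding, a dict-like mapping of tensors.
    print(batch["input_ids"].shape, batch["attention_mask"].shape)

Note that despite the tuple[torch.Tensor, torch.Tensor] annotation, tokenizer(...) with return_tensors="pt" returns a BatchEncoding mapping, which matches how the embedding loop below indexes batch['input_ids'] and batch['attention_mask'].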
@@ -690,7 +690,7 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask).float()  # sql requires float32
-            embeddings = get_embeddings(residue_embeddings, attention_mask)
+            embeddings = get_embeddings(residue_embeddings, attention_mask)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
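In the second hunk, get_embeddings pools per-residue hidden states into one vector per sequence, while the full_embeddings branch instead keeps the per-residue matrix and uses the boolean attention mask to drop padding rows before the reshape. The body of get_embeddings is not shown in this diff; the sketch below is a common mask-aware mean-pooling implementation and only an assumption about what it computes.

import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for get_embeddings; the pooling in the actual file may differ.
    # residue_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len).
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)   # zero out padding, then sum over residues
    counts = mask.sum(dim=1).clamp(min=1.0)           # real residues per sequence; avoids divide-by-zero
    return summed / counts                            # (batch, hidden) mean over unpadded positions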