sgoel30 committed on
Commit 42149de · verified · 1 Parent(s): 5f9a93d

Upload 2 files

Files changed (2)
  1. utils/data_loader.py +24 -11
  2. utils/esm_utils.py +1 -0
utils/data_loader.py CHANGED
@@ -1,6 +1,9 @@
 import pandas as pd
+import torch
 from torch.utils.data import Dataset, DataLoader
-from utils.esm_utils import get_latents, load_esm2_model
+from torch.nn.utils.rnn import pad_sequence
+from esm_utils import get_latents, load_esm2_model
+import config

 class ProteinDataset(Dataset):
     def __init__(self, csv_file, tokenizer, model):
@@ -12,19 +15,29 @@ class ProteinDataset(Dataset):
         return len(self.data)

     def __getitem__(self, idx):
-        sequence = self.data.iloc[idx]['sequence']
+        sequence = self.data.iloc[idx]['Sequence']
         latents = get_latents(self.model, self.tokenizer, sequence)
-        return latents
+
+        attention_mask = torch.ones_like(latents)
+        attention_mask = torch.mean(attention_mask, dim=-1)
+
+        return latents, attention_mask
+
+def collate_fn(batch):
+    latents, attention_mask = zip(*batch)
+    latents_padded = pad_sequence([torch.tensor(latent) for latent in latents], batch_first=True, padding_value=0)
+    attention_mask_padded = pad_sequence([torch.tensor(mask) for mask in attention_mask], batch_first=True, padding_value=0)
+    return latents_padded, attention_mask_padded

 def get_dataloaders(config):
-    tokenizer, model = load_esm2_model(config.model_name)
+    tokenizer, model = load_esm2_model(config.MODEL_NAME)

-    train_dataset = ProteinDataset(config.data_path + "train.csv", tokenizer, model)
-    val_dataset = ProteinDataset(config.data_path + "val.csv", tokenizer, model)
-    test_dataset = ProteinDataset(config.data_path + "test.csv", tokenizer, model)
+    train_dataset = ProteinDataset(config.Loader.DATA_PATH + "/train.csv", tokenizer, model)
+    val_dataset = ProteinDataset(config.Loader.DATA_PATH + "/val.csv", tokenizer, model)
+    test_dataset = ProteinDataset(config.Loader.DATA_PATH + "/test.csv", tokenizer, model)

-    train_loader = DataLoader(train_dataset, batch_size=config.training["batch_size"], shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=config.training["batch_size"], shuffle=False)
-    test_loader = DataLoader(test_dataset, batch_size=config.training["batch_size"], shuffle=False)
+    train_loader = DataLoader(train_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=True, collate_fn=collate_fn)
+    val_loader = DataLoader(val_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=False, collate_fn=collate_fn)
+    test_loader = DataLoader(test_dataset, batch_size=config.Loader.BATCH_SIZE, num_workers=0, shuffle=False, collate_fn=collate_fn)

-    return train_loader, val_loader, test_loader
+    return train_loader, val_loader, test_loader
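
Note (not part of the commit): the new collate_fn relies on torch.nn.utils.rnn.pad_sequence to batch variable-length per-residue latents, with the per-position mask distinguishing real from padded positions. A minimal, self-contained sketch of that behavior, using an arbitrary latent dimension of 8 and dummy sequence lengths purely for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two dummy per-residue latent tensors of shape (seq_len, dim) with different lengths.
latents = [torch.randn(5, 8), torch.randn(3, 8)]
# Per-position masks of ones, mirroring torch.mean(torch.ones_like(latents), dim=-1) in __getitem__.
masks = [torch.ones(5), torch.ones(3)]

# pad_sequence right-pads every item to the longest length in the batch.
latents_padded = pad_sequence(latents, batch_first=True, padding_value=0)  # shape (2, 5, 8)
masks_padded = pad_sequence(masks, batch_first=True, padding_value=0)      # shape (2, 5)

print(latents_padded.shape, masks_padded.shape)
print(masks_padded)  # second row ends in zeros, marking the padded positions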
utils/esm_utils.py CHANGED
@@ -11,3 +11,4 @@ def get_latents(model, tokenizer, sequence):
     with torch.no_grad():
         outputs = model(**inputs)
     return outputs.last_hidden_state.squeeze(0)
+
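
Note (not part of the commit): for readers without the rest of utils/esm_utils.py in view, a minimal sketch of how a load_esm2_model/get_latents pair can be written against Hugging Face transformers. The checkpoint name and the tokenizer call are assumptions for illustration; only the torch.no_grad block and the last_hidden_state return mirror the lines shown above.

import torch
from transformers import AutoTokenizer, EsmModel

def load_esm2_model(model_name="facebook/esm2_t6_8M_UR50D"):
    # The checkpoint name is an assumed example; any ESM-2 checkpoint on the Hub follows the same API.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model

def get_latents(model, tokenizer, sequence):
    # Tokenize a single protein sequence and return per-token hidden states of shape (seq_len, hidden_dim).
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.squeeze(0)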