File size: 2,047 Bytes
d8ed92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pandas as pd
import torch
import config
import random
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from pretrained_models import load_esm2_model

class ProteinDataset(Dataset):
    """Masked-language-modeling dataset over protein sequences.

    Reads a CSV with a 'Sequence' column, masks ~15% of residues per item,
    and yields dicts of (input_ids, attention_mask, labels) where labels are
    -100 everywhere except masked positions, which carry the ORIGINAL
    residue's token id (the standard MLM loss convention).
    """

    def __init__(self, csv_file, tokenizer):
        self.tokenizer = tokenizer
        self.data = pd.read_csv(csv_file)
        # Pad/truncate every item to the longest raw sequence length.
        # NOTE(review): special tokens (<cls>/<eos>) consume part of this
        # budget, so the longest sequences may lose trailing residues —
        # confirm this truncation is acceptable for training.
        self.max_len = max(len(seq) for seq in self.data['Sequence'].tolist())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]['Sequence'].upper()

        # Randomly mask 15% of residues — at least one, so short sequences
        # (< 7 residues, where int(len*0.15) == 0) still contribute to the
        # MLM loss instead of producing an all--100 label row.
        num_masks = max(1, int(len(sequence) * 0.15))
        mask_indices = set(random.sample(range(len(sequence)), num_masks))
        # Assumes the tokenizer maps each residue character to one token and
        # "<mask>" to a single mask token, keeping positions aligned between
        # the masked and original sequences (true for ESM-2) — TODO confirm.
        masked_sequence = ''.join(
            "<mask>" if i in mask_indices else sequence[i]
            for i in range(len(sequence))
        )

        inputs = self.tokenizer(masked_sequence, padding="max_length",
                                truncation=True, max_length=self.max_len,
                                return_tensors='pt')
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()

        # BUG FIX: labels must be tokenized from the ORIGINAL (unmasked)
        # sequence. The previous code tokenized the masked sequence, so the
        # target at every masked position was the <mask> token id itself and
        # the model was never trained to predict the true residue.
        labels = self.tokenizer(sequence, return_tensors='pt',
                                padding='max_length', max_length=self.max_len,
                                truncation=True)['input_ids'].squeeze()
        # Loss is computed only at masked positions; -100 is ignored by
        # torch.nn.CrossEntropyLoss.
        labels = torch.where(input_ids == self.tokenizer.mask_token_id,
                             labels, -100)

        return {"input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels}



def get_dataloaders(config):
    """Build the train/val/test DataLoaders backed by ProteinDataset.

    ``config`` must provide ESM_MODEL_PATH, TRAIN_DATA, VAL_DATA, TEST_DATA
    and BATCH_SIZE. Only the tokenizer from the pretrained ESM-2 bundle is
    used here; the model object is discarded.

    Returns:
        (train_loader, val_loader, test_loader) tuple of DataLoaders.
    """
    tokenizer, _unused_model = load_esm2_model(config.ESM_MODEL_PATH)

    def build_loader(csv_path, shuffle):
        # Each split gets its own dataset (and thus its own max_len pad).
        dataset = ProteinDataset(csv_path, tokenizer)
        return DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=shuffle)

    # Only the training split is reshuffled each epoch.
    return (
        build_loader(config.TRAIN_DATA, True),
        build_loader(config.VAL_DATA, False),
        build_loader(config.TEST_DATA, False),
    )