# muPPIt / muppit/data.py
# Repository page residue kept for provenance: "AlienChen's picture",
# "Upload 139 files", commit 65bd8af (verified).
import sys
import glob
import lightning.pytorch as pl
import torch
import pickle
from datasets import Dataset, DatasetDict, load_from_disk, concatenate_datasets
from functools import partial
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
import sentencepiece as spm
import os
from multiprocessing import Pool
from tqdm import tqdm
from datetime import datetime
import random
import gc
import h5py
import numpy as np
from multiprocessing import Pool
torch.cuda.empty_cache()
from transformers import AutoTokenizer, EsmModel
# Tokenizer visible to every function of this module inside one process.
# It stays ``None`` until ``init_pool`` installs a real tokenizer.
global_tokenizer = None


def init_pool(tokenizer):
    """Store ``tokenizer`` in the module-level ``global_tokenizer`` slot.

    Passed as the ``initializer`` of a ``multiprocessing.Pool`` so that each
    worker process receives the tokenizer once at start-up.

    Args:
        tokenizer: the object to publish as ``global_tokenizer``.
    """
    globals()["global_tokenizer"] = tokenizer
def standalone_tokenize_function(s1, extra_toks_per_seq=1):
    """Tokenize one item with the process-wide ``global_tokenizer``.

    The previous body called ``encode()`` without an argument, so every call
    failed.  The pool consumer in this file feeds ``(sequence_1, sequence_2)``
    tuples and unpacks two token lists per result, so a pair is encoded
    element-wise; a plain string is encoded as a single sequence.

    Args:
        s1: either one sequence string or a ``(sequence_1, sequence_2)`` pair.
        extra_toks_per_seq: unused; kept so existing call sites keep working.

    Returns:
        The encoded tokens for a string input, or a ``(tokens_1, tokens_2)``
        tuple for a pair input.

    Raises:
        ValueError: if tokenization fails for any reason (including the
            tokenizer not having been installed yet); the original exception
            is attached as ``__cause__``.
    """
    try:
        if isinstance(s1, str):
            return global_tokenizer.encode(s1)
        first, second = s1
        return global_tokenizer.encode(first), global_tokenizer.encode(second)
    except Exception as e:
        raise ValueError(f"Error during tokenization of string {s1} : {e}") from e
class BatchedPPIDataset(object):
    """Tokenize protein pairs and pack them into length-sorted batches.

    Inspired by ESM-2 batching, but the sort key is the length of the
    *tokenized* sequence rather than the length of the raw string.

    Each pair is joined as ``<bos> A <sep> B <eos>``.  ``process_all`` builds
    the batches for the forward order, then for the backward order, combines
    and shuffles them, and saves the result to disk.

    The tokenizer must offer ``encode``, ``bos_id()``, ``eos_id()`` and
    ``piece_to_id()``, plus whatever ``TokenizeBatch`` reads from it.
    """

    # Number of worker processes used for tokenization.
    tokenizer_processes = 16

    def __init__(self, sequence_strs, tokenizer, max_sequence_length):
        """Store the inputs; nothing is tokenized until ``process_all`` runs.

        Args:
            sequence_strs: mapping with ``'sequence_1'`` and ``'sequence_2'``
                entries holding the two sides of every pair.
            tokenizer: tokenizer used for encoding and special-token ids.
            max_sequence_length: upper bound on the length of one joined
                sequence, also used as the token budget of one batch.
        """
        self.batch_indices = None
        self.sequence_str_1 = sequence_strs['sequence_1']
        self.sequence_str_2 = sequence_strs['sequence_2']
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.tokenized_sequences = []
        self.accumulated_length = 0
        # Cache of (tokens_1, tokens_2) so the backward pass reuses the
        # forward pass' tokenization instead of starting a second pool.
        self._tokenized_pairs = None

    @staticmethod
    def _encode_pair(pair):
        """Pool worker: encode both sides of one ``(seq_1, seq_2)`` pair.

        Reads the module-level ``global_tokenizer`` that ``init_pool``
        installs in each worker process.  The class owns this worker because
        the module-level name ``standalone_tokenize_function`` is defined
        twice in this file with different signatures.
        """
        seq_1, seq_2 = pair
        return global_tokenizer.encode(seq_1), global_tokenizer.encode(seq_2)

    def _tokenize_pairs(self):
        """Return ``(tokens_1, tokens_2)`` for every pair, tokenizing once."""
        cached = getattr(self, '_tokenized_pairs', None)
        if cached is None:
            prot_tuples = list(zip(self.sequence_str_1, self.sequence_str_2))
            with Pool(processes=self.tokenizer_processes,
                      initializer=init_pool,
                      initargs=(self.tokenizer,)) as pool:
                cached = list(
                    tqdm(pool.imap(BatchedPPIDataset._encode_pair, prot_tuples),
                         total=len(prot_tuples)))
            self._tokenized_pairs = cached
        return cached

    def _append_joined(self, swap):
        """Append every pair that fits as ``<bos> A <sep> B <eos>``.

        Args:
            swap: when true, the second partner is written first.
        """
        bos = self.tokenizer.bos_id()
        sep = self.tokenizer.piece_to_id('<sep>')
        eos = self.tokenizer.eos_id()
        for tokens_1, tokens_2 in self._tokenize_pairs():
            # +3 accounts for the bos, sep and eos tokens.
            if len(tokens_1) + len(tokens_2) + 3 > self.max_sequence_length:
                continue
            head, tail = (tokens_2, tokens_1) if swap else (tokens_1, tokens_2)
            self.tokenized_sequences.append([bos] + head + [sep] + tail + [eos])

    def tokenize_sequences_forward(self):
        """Append ``<bos> seq_1 <sep> seq_2 <eos>`` for every pair that fits."""
        self._append_joined(swap=False)

    def tokenize_sequences_backward(self):
        """Append ``<bos> seq_2 <sep> seq_1 <eos>`` for every pair that fits.

        ``process_all`` called this method but it was missing from the class.
        NOTE(review): "backward" is assumed to mean the two partners swapped
        around the separator -- confirm against the intended augmentation.
        """
        self._append_joined(swap=True)

    def process_all(self, base_path, split_name):
        """Run the whole pipeline and return the shuffled dataset.

        Args:
            base_path: directory the processed dataset is saved under.
            split_name: prefix of the saved dataset's file name.

        Returns:
            The shuffled concatenation of forward and backward batches.
        """
        self.tokenize_sequences_forward()
        forward_batches = self.process_chunk(self.tokenized_sequences, self.get_batch_indices())
        # Backward sequences are appended after the forward ones, so their
        # batch indices start at the current length.
        offset = len(self.tokenized_sequences)
        self.tokenize_sequences_backward()
        backward_batches = self.process_chunk(self.tokenized_sequences, self.get_batch_indices(offset))
        self.tokenized_sequences = []
        self._tokenized_pairs = None
        combined_dataset = concatenate_datasets([forward_batches, backward_batches])
        # Shuffle again so forward and backward batches are interleaved.
        shuffled_dataset = combined_dataset.shuffle()
        self.save_checkpoint(shuffled_dataset, base_path,
                             split_name=split_name)
        return shuffled_dataset

    def process_chunk(self, tokenized_sequences, batch_indices):
        """Pad the given batches and wrap them in a ``Dataset``.

        Args:
            tokenized_sequences: list of token-id lists indexed by
                ``batch_indices``.
            batch_indices: list of batches, each a list of indices.

        Returns:
            A ``Dataset`` with one row per batch (in random order) and the
            columns ``attention_mask``, ``input_ids`` and ``labels``.
        """
        print(f'Start padding and masking for sequences {len(batch_indices)} batches')
        token_batch_fn = TokenizeBatch(self.tokenizer)
        processed_batches = [
            token_batch_fn([tokenized_sequences[i] for i in batch])
            for batch in batch_indices
        ]
        # Shuffle the batches with one permutation of plain Python ints.
        order = torch.randperm(len(processed_batches)).tolist()
        processed_batches = [processed_batches[i] for i in order]
        combined_dataset = Dataset.from_dict({
            'attention_mask': [batch['attention_mask'] for batch in processed_batches],
            'input_ids': [batch['input_ids'] for batch in processed_batches],
            'labels': [batch['labels'] for batch in processed_batches],
        })
        # Drop the padded tensors before the next chunk is processed.
        del processed_batches
        gc.collect()
        return combined_dataset

    def save_checkpoint(self, shuffled_dataset, base_path, split_name=None):
        """Save ``shuffled_dataset`` under ``base_path`` and free memory.

        Also empties ``self.tokenized_sequences``.
        """
        print(f'Start generating tokens for shuffled_dataset sequences')
        output_file = f'{base_path}/{split_name}_combined_reversed_ppi_tokenized_sequences.hf'
        shuffled_dataset.save_to_disk(output_file)
        print(f'successfully written {split_name} processed datasets into disc!')
        self.tokenized_sequences.clear()
        gc.collect()

    def get_batch_indices(self, offset=0, end=None):
        """Group sequence indices into batches under the token budget.

        Sequences in ``self.tokenized_sequences[offset:end]`` are sorted by
        length and packed greedily: a batch is closed as soon as adding the
        next sequence would push its summed length above
        ``self.max_sequence_length``.  The slice lets forward and backward
        sequences, which share one list, be batched separately.

        Args:
            offset: first index to consider.
            end: one past the last index; defaults to the end of the list.

        Returns:
            A list of batches, each a list of indices into
            ``self.tokenized_sequences``.
        """
        if end is None:
            end = len(self.tokenized_sequences)
        # Lengths already include the special tokens added at tokenization.
        sizes = sorted(
            (len(tokens), idx)
            for idx, tokens in enumerate(self.tokenized_sequences[offset:end], start=offset)
        )
        batches = []
        buf = []
        current_buf_len = 0
        for sz, idx in sizes:
            if buf and current_buf_len + sz > self.max_sequence_length:
                batches.append(buf)
                buf = []
                current_buf_len = 0
            buf.append(idx)
            current_buf_len += sz
        if buf:
            batches.append(buf)
        return batches
class DynamicBatchingDataset(Dataset):
    """
    Serve rows that each already hold one complete, padded batch.

    Every row of the incoming data is a whole batch prepared in an earlier
    step, so the DataLoader must use ``batch_size=1`` together with
    ``collate_fn`` below, which simply unwraps that single row.

    The base-class initializer is not called; all state lives in
    ``self.dataset_dict``.

    Usage:
        train_dataset = DynamicBatchingDataset(dataset_dict['train'])
        train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False,
                                      collate_fn=DynamicBatchingDataset.collate_fn)
    """

    # Columns copied from the input and returned for every item.
    _COLUMNS = ('attention_mask', 'input_ids', 'labels')

    def __init__(self, dataset_dict):
        """Convert every row of the three columns to a tensor up front.

        Args:
            dataset_dict: object indexable by the column names in
                ``_COLUMNS``; each column is iterated row by row.
        """
        print('Initializing dataset...')
        self.dataset_dict = {
            column: [torch.tensor(item) for item in dataset_dict[column]]
            for column in self._COLUMNS
        }

    def __len__(self):
        # Each stored entry represents one pre-built batch.
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        """Return one pre-built batch, or several when given a list.

        Args:
            idx: an integer-like index (``int``, numpy integer, ...) or a
                list/tuple of such indices.

        Returns:
            A dict keyed by ``_COLUMNS``.  Values are tensors for a single
            index and lists of tensors for a list of indices.

        Raises:
            ValueError: if ``idx`` is neither integer-like nor a list/tuple.
        """
        if isinstance(idx, (list, tuple)):
            return {
                column: [self.dataset_dict[column][i] for i in idx]
                for column in self._COLUMNS
            }
        try:
            # Accept anything usable as an index, not only built-in ints.
            position = idx.__index__()
        except (AttributeError, TypeError):
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}") from None
        return {column: self.dataset_dict[column][position] for column in self._COLUMNS}

    @staticmethod
    def collate_fn(batch, verbose=False):
        """Unwrap the single pre-built batch handed over by the DataLoader.

        Only ``batch[0]`` is used, so the DataLoader is expected to run with
        ``batch_size=1``; the tensors are already padded and are returned
        as they are.

        Args:
            batch: list of items produced by ``__getitem__``.
            verbose: unused; kept for backward compatibility.
        """
        item = batch[0]
        return {column: item[column] for column in DynamicBatchingDataset._COLUMNS}
import sys
import lightning.pytorch as pl
import torch
from datasets import Dataset, DatasetDict, load_from_disk
from torch.utils.data import DataLoader
import sentencepiece as spm
import os
from multiprocessing import Pool
from tqdm import tqdm
import gc
import pandas as pd
def standalone_tokenize_function(sequence, tokenizer):
    """Encode ``sequence`` with ``tokenizer`` and return the result.

    NOTE(review): a function with this same name but a different signature
    is defined earlier in this file; this definition rebinds the name, so
    later lookups of ``standalone_tokenize_function`` resolve to this one.

    Args:
        sequence: the input passed to ``tokenizer.encode``.
        tokenizer: object providing an ``encode`` method.

    Returns:
        Whatever ``tokenizer.encode(sequence)`` returns.

    Raises:
        ValueError: if encoding fails; the original exception is attached
            as ``__cause__``.
    """
    try:
        return tokenizer.encode(sequence)
    except Exception as e:
        raise ValueError(f"Error during tokenization: {e}") from e
class TokenizeBatch:
    """Callable that pads token-id lists into one batch of tensors.

    The pad id is read once from ``tokenizer.pad_token_id``.
    """

    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batches):
        """Pad ``batches`` to a common length and derive mask and labels.

        Args:
            batches: list of token-id lists belonging to one batch.

        Returns:
            dict with ``input_ids`` (right-padded with the pad id),
            ``attention_mask`` (1 where the id differs from the pad id,
            0 elsewhere) and ``labels`` (a copy of ``input_ids`` with pad
            positions set to -100).
        """
        print(f"Processing batch of size {len(batches)}")
        pad_id = self.pad_token_id
        padded = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(ids) for ids in batches],
            batch_first=True,
            padding_value=pad_id,
        )
        is_pad = padded == pad_id
        mask = (~is_pad).long()
        # masked_fill returns a new tensor, leaving ``padded`` untouched.
        targets = padded.masked_fill(is_pad, -100)
        print(f"Batch processed. Shape: {padded.shape}")
        return {
            'input_ids': padded,
            'attention_mask': mask,
            'labels': targets
        }
class SequenceDataset:
    """Tokenize single sequences and pack them into length-sorted batches."""

    def __init__(self, sequences, tokenizer, max_sequence_length):
        """Store the inputs; tokenization happens in ``process_sequences``.

        Args:
            sequences: sequence strings to tokenize.
            tokenizer: object providing ``encode`` (and whatever
                ``TokenizeBatch`` reads from it).
            max_sequence_length: longest tokenized sequence that is kept,
                also used as the token budget of one batch.
        """
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.tokenized_sequences = []

    def tokenize_sequences(self):
        """Encode every sequence, keeping those within the length limit."""
        print(f"Starting tokenization of {len(self.sequences)} sequences")
        # Start from an empty list so that running the pipeline twice does
        # not store every sequence twice.
        self.tokenized_sequences = []
        for i, seq in enumerate(tqdm(self.sequences)):
            # Special tokens are left to the tokenizer's ``encode``.
            tokens = self.tokenizer.encode(seq)
            if len(tokens) <= self.max_sequence_length:
                self.tokenized_sequences.append(tokens)
            if i % 10000 == 0:
                print(f"Processed {i} sequences. Current tokenized count: {len(self.tokenized_sequences)}")
        print(f"Tokenization complete. Final count: {len(self.tokenized_sequences)}")

    @staticmethod
    def _group_indices(lengths, batch_size, max_tokens):
        """Greedily bucket ``(length, index)`` pairs into index groups.

        A group is closed when adding the next item would push its summed
        length above ``max_tokens`` or when it already holds ``batch_size``
        items.
        """
        groups = []
        current = []
        current_length = 0
        for length, idx in lengths:
            if current_length + length > max_tokens or len(current) == batch_size:
                if current:
                    groups.append(current)
                current = [idx]
                current_length = length
            else:
                current.append(idx)
                current_length += length
        if current:
            groups.append(current)
        return groups

    def process_sequences(self, batch_size):
        """Tokenize, sort by length, batch, pad and return a ``Dataset``.

        Args:
            batch_size: maximum number of sequences in one batch.

        Returns:
            A ``Dataset`` with one row per batch and the columns
            ``attention_mask``, ``input_ids`` and ``labels``.
        """
        print("Starting sequence processing")
        self.tokenize_sequences()
        print("Sorting sequences by length")
        lengths = sorted((len(seq), i) for i, seq in enumerate(self.tokenized_sequences))
        print("Creating batches")
        groups = self._group_indices(tqdm(lengths), batch_size, self.max_sequence_length)
        batches = [[self.tokenized_sequences[i] for i in group] for group in groups]
        print(f"Created {len(batches)} batches")
        token_batch_fn = TokenizeBatch(self.tokenizer)
        print("Processing batches")
        processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]
        print("Creating final dataset")
        dataset = Dataset.from_dict({
            'attention_mask': [batch['attention_mask'] for batch in processed_batches],
            'input_ids': [batch['input_ids'] for batch in processed_batches],
            'labels': [batch['labels'] for batch in processed_batches],
        })
        print(f"Final dataset size: {len(dataset)}")
        return dataset
class PretrainSequenceDataModule(pl.LightningDataModule):
    """Lightning data module for tokenized, pre-batched sequence data.

    ``prepare_data`` builds the processed dataset from ``train.csv``,
    ``val.csv`` and ``test.csv`` (column ``Sequence``) unless the output path
    already exists; ``setup`` loads it back; the three ``*_dataloader``
    methods serve one pre-built batch per step.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self,
                 input_dataset_path,
                 output_dataset_path,
                 num_workers,
                 batch_size,
                 max_sequence_length=512,
                 model_name="facebook/esm2_t33_650M_UR50D"):
        """Load the tokenizer for ``model_name`` and record the settings.

        Args:
            input_dataset_path: directory containing the three CSV files.
            output_dataset_path: where the processed dataset is saved/loaded.
            num_workers: worker count handed to every DataLoader.
            batch_size: maximum number of sequences per pre-built batch.
            max_sequence_length: limit forwarded to ``SequenceDataset``.
            model_name: name passed to ``AutoTokenizer.from_pretrained``.
        """
        super().__init__()
        print(f"Initializing tokenizer from {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.input_path = input_dataset_path
        self.output_path = output_dataset_path
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.max_sequence_length = max_sequence_length

    def prepare_data(self):
        """Build and save the processed dataset if it is not on disk yet."""
        if os.path.exists(self.output_path):
            return
        print("Loading CSV files")
        frames = {
            split: pd.read_csv(f"{self.input_path}/{split}.csv")
            for split in self._SPLITS
        }
        banners = {
            'train': "Processing training data",
            'val': "Processing validation data",
            'test': "Processing test data",
        }
        # Construct all three wrappers first, then process them, matching
        # the order in which messages are emitted.
        builders = {}
        for split in self._SPLITS:
            print(banners[split])
            builders[split] = SequenceDataset(frames[split]['Sequence'].tolist(),
                                              self.tokenizer,
                                              self.max_sequence_length)
        processed = {
            split: builders[split].process_sequences(self.batch_size)
            for split in self._SPLITS
        }
        print("Combining datasets")
        combined_dataset = DatasetDict(processed)
        print(f"Saving dataset to {self.output_path}")
        combined_dataset.save_to_disk(self.output_path)

    def setup(self, stage: str):
        """Load the processed dataset and wrap each split.

        ``stage`` is accepted but not consulted: all three splits are
        always loaded.
        """
        print("Loading processed dataset")
        dataset = load_from_disk(self.output_path)
        for split in self._SPLITS:
            setattr(self, f"{split}_dataset", DynamicBatchingDataset(dataset[split]))
        print(f"Dataset sizes - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}, Test: {len(self.test_dataset)}")

    def _build_loader(self, dataset):
        """Create the DataLoader configuration shared by all three splits."""
        # batch_size=1 because every dataset item is already a full batch.
        return DataLoader(dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=DynamicBatchingDataset.collate_fn,
                          pin_memory=True)

    def train_dataloader(self):
        print("Creating training dataloader")
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        print("Creating validation dataloader")
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        print("Creating test dataloader")
        return self._build_loader(self.test_dataset)
if __name__ == '__main__':
    # Script entry point: build the processed peptide dataset (when it is
    # not on disk yet), load it, and create the three dataloaders once.
    settings = dict(
        input_dataset_path='/home/tc415/discrete-diffusion-guidance/dataset/peptide',
        output_dataset_path='/home/tc415/discrete-diffusion-guidance/dataset/tokenized_peptide',
        num_workers=8,
        batch_size=50,
        max_sequence_length=100,
        model_name="facebook/esm2_t33_650M_UR50D",
    )
    dm = PretrainSequenceDataModule(**settings)
    dm.prepare_data()
    dm.setup('fit')
    for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        make_loader()