# muppit/main.py
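"""Hydra entry point for training and evaluating the muPPIt diffusion models.

``config.mode`` selects the behaviour ('gen_ppl_eval', 'ppl_eval', or any
mode containing 'train', optionally with 'classifier' for guidance-classifier
training). Illustrative invocations (override keys beyond those referenced in
this file are assumptions; the real keys live in ``configs/``):

  python main.py mode=train
  python main.py mode=train_classifier
  python main.py mode=ppl_eval eval.checkpoint_path=/path/to/ckpt
"""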
import json
import os
import pdb

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
from datasets import load_from_disk
from tqdm import tqdm

import classifier
import dataloader
import diffusion
import eval_utils
import utils

omegaconf.OmegaConf.register_new_resolver(
'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
'if_then_else',
lambda condition, x, y: x if condition else y
)
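# The resolvers registered above can be referenced from the Hydra YAML
# configs, e.g. (illustrative snippet; these keys are not taken verbatim
# from configs/):
#
#   checkpointing:
#     save_dir: ${cwd:}/outputs
#   trainer:
#     devices: ${device_count:}
#   model:
#     attn_chunks: ${div_up:${model.length},512}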


def _load_from_checkpoint(config, tokenizer):
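  """Instantiate the diffusion model used for evaluation.

  HF backbones are built directly from the config; all other backbones are
  restored from ``config.eval.checkpoint_path``.
  """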
if 'hf' in config.backbone:
return diffusion.Diffusion(
config, tokenizer=tokenizer).to('cuda')
return diffusion.Diffusion.load_from_checkpoint(
config.eval.checkpoint_path,
tokenizer=tokenizer,
config=config, logger=False).to('cuda')


@L.pytorch.utilities.rank_zero_only
def _print_config(
config: omegaconf.DictConfig,
resolve: bool = True,
save_cfg: bool = True) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
Args:
config (DictConfig): Configuration composed by Hydra.
resolve (bool): Whether to resolve reference fields of DictConfig.
save_cfg (bool): Whether to save the configuration tree to a file.
"""
style = 'dim'
tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
fields = config.keys()
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, omegaconf.DictConfig):
branch_content = omegaconf.OmegaConf.to_yaml(
config_section, resolve=resolve)
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
rich.print(tree)
if save_cfg:
with fsspec.open(
'{}/config_tree.txt'.format(
config.checkpointing.save_dir), 'w') as fp:
rich.print(tree, file=fp)


@L.pytorch.utilities.rank_zero_only
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
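  """Decode and print the first/last ``k`` tokens of one train and one valid batch."""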
for dl_type, dl in [
('train', train_ds), ('valid', valid_ds)]:
print(f'Printing {dl_type} dataloader batch.')
batch = next(iter(dl))
print('Batch input_ids.shape', batch['input_ids'].shape)
first = batch['input_ids'][0, :k]
last = batch['input_ids'][0, -k:]
print(f'First {k} tokens:', tokenizer.decode(first))
print('ids:', first)
print(f'Last {k} tokens:', tokenizer.decode(last))
print('ids:', last)


def _train(config, logger, tokenizer,
train_classifier=False):
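  """Train the diffusion model or a guidance classifier.

  When ``train_classifier`` is True, a classifier is trained; for PPLM/NOS-
  style guidance it can reuse an (optionally frozen) pretrained diffusion
  backbone as its encoder.
  """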
logger.info('Starting Training.')
wandb_logger = None
if config.get('wandb', None) is not None:
wandb_logger = L.pytorch.loggers.WandbLogger(
config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)
if (config.checkpointing.resume_from_ckpt
and config.checkpointing.resume_ckpt_path is not None
and utils.fsspec_exists(
config.checkpointing.resume_ckpt_path)):
ckpt_path = config.checkpointing.resume_ckpt_path
else:
ckpt_path = None
# Lightning callbacks
callbacks = []
if 'callbacks' in config:
for _, callback in config.callbacks.items():
callbacks.append(hydra.utils.instantiate(callback))
# train_ds, valid_ds = dataloader.get_dataloaders(
# config, tokenizer)
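  # Load the pre-built HF datasets from disk (hard-coded paths) and wrap
  # them in the project's CustomDataModule.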
  train_dataset = load_from_disk(
    '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
  val_dataset = load_from_disk(
    '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
  test_dataset = load_from_disk(
    '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
  data_module = dataloader.CustomDataModule(
    train_dataset, val_dataset, test_dataset, tokenizer, config,
    batch_size=config.loader.batch_size)
train_ds = data_module.train_dataloader()
valid_ds = data_module.val_dataloader()
if not config.is_vision:
_print_batch(train_ds, valid_ds, tokenizer)
if train_classifier:
# This param indicates classifier will be used for
# PPLM / NOS-style guidance
# (see: https://arxiv.org/abs/2305.20009).
if getattr(config, 'is_pplm_classifier', False):
pretrained_model = _load_from_checkpoint(
config, tokenizer)
if (getattr(config.classifier_model, 'use_encoder_ema', True)
and pretrained_model.ema):
pretrained_model.load_ema_params()
pretrained_backbone = pretrained_model.backbone
# Remove the last layer for the classifier
      if hasattr(pretrained_backbone, 'output_layer'):  # DiT
        delattr(pretrained_backbone, 'output_layer')
      # DiMamba: the LM head lives on the nested `model` module, so the
      # dotted name has to be resolved one attribute at a time
      # (hasattr(obj, 'model.lm_head') is always False).
      if (hasattr(pretrained_backbone, 'model')
          and hasattr(pretrained_backbone.model, 'lm_head')):
        delattr(pretrained_backbone.model, 'lm_head')
if getattr(config.classifier_model, 'freeze_encoder', True):
for param in pretrained_backbone.parameters():
param.requires_grad = False
else:
pretrained_backbone = None
model = classifier.Classifier(
config,
      tokenizer=tokenizer,
pretrained_backbone=pretrained_backbone)
else:
model = diffusion.Diffusion(
config, tokenizer=tokenizer)
# model = diffusion.Diffusion(
# config, tokenizer=valid_ds.tokenizer)
trainer = hydra.utils.instantiate(
config.trainer,
default_root_dir=os.getcwd(),
callbacks=callbacks,
strategy=hydra.utils.instantiate(config.strategy),
logger=wandb_logger)
trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)


def _gen_ppl_eval(config, tokenizer):
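  """Sample from a trained model; report generative PPL and unigram entropy.

  Samples and metrics are written to ``config.eval.generated_samples_path``.
  """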
pretrained = _load_from_checkpoint(
config=config, tokenizer=tokenizer)
pretrained.eval()
samples = []
for _ in tqdm(range(config.sampling.num_sample_batches),
desc='Gen. batches', leave=False):
sample = pretrained.sample()
samples.extend(
pretrained.tokenizer.batch_decode(sample))
  # Use the BOS token if available, otherwise fall back to the CLS token,
  # and strip padding / mask tokens from the decoded samples.
  tok_bos_token = (tokenizer.bos_token
                   if tokenizer.bos_token is not None
                   else tokenizer.cls_token)
samples = [
s.replace('[PAD]', '').replace('[MASK]', '').strip()
for s in samples
]
# Add BOS token to the beginning of each sample (if not already present)
samples = [
s if s.startswith(tok_bos_token) else f"{tok_bos_token} {s}"
for s in samples
]
del pretrained # free up space for eval
print(f"Generated {len(samples)} samples.")
generative_ppl = eval_utils.compute_generative_ppl(
samples,
eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path,
gen_ppl_eval_batch_size=8,
max_length=config.model.length)
tokens = tokenizer.batch_encode_plus(
samples,
return_tensors='pt',
add_special_tokens=False,
max_length=config.model.length,
padding='max_length',
truncation=True)['input_ids']
  # `tokens` is already a tensor (return_tensors='pt'); count token
  # frequencies to compute the unigram entropy of the generated samples.
  _, counts = torch.unique(
    tokens, return_counts=True, sorted=False)
entropy = torch.special.entr(
counts.float() / counts.sum()).sum().item()
with open(config.eval.generated_samples_path, 'w') as f:
json.dump({
'generative_ppl': generative_ppl,
'entropy': entropy,
'generated_seqs': samples,
},
f, indent=4) # type: ignore
print(f"Entropy: {entropy:0.3f}")
print(f"Gen. PPL: {generative_ppl:0.3f}")


def _ppl_eval(config, tokenizer):
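  """Evaluate validation-set perplexity of a model loaded from checkpoint."""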
print(f"Evaluating perplexity on {config.data.valid}.")
pretrained = _load_from_checkpoint(
config=config, tokenizer=tokenizer)
pretrained.eval()
if not config.eval.disable_ema:
pretrained.load_ema_params()
_, valid_ds = dataloader.get_dataloaders(
config, tokenizer, skip_train=True, valid_seed=config.seed)
ppl = eval_utils.compute_ppl(pretrained, valid_ds)
print(f"PPL: {ppl:0.3f}")


@hydra.main(version_base=None, config_path='configs',
config_name='config')
def main(config):
"""Main entry point for training."""
L.seed_everything(config.seed)
_print_config(config, resolve=True, save_cfg=True)
logger = utils.get_logger(__name__)
tokenizer = dataloader.get_tokenizer(config)
if config.mode == 'gen_ppl_eval':
_gen_ppl_eval(config, tokenizer)
elif config.mode == 'ppl_eval':
_ppl_eval(config, tokenizer)
elif 'train' in config.mode:
_train(config, logger, tokenizer,
train_classifier='classifier' in config.mode)
else:
raise NotImplementedError(f"Mode {config.mode} not implemented.")


if __name__ == '__main__':
main()