# muPPIt/muppit/classifier.py
import itertools
import typing
import hydra.utils
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
import transformers
import dataloader
import models.dimamba
import models.dit
import noise_schedule
class MicroAveragingMetric(torchmetrics.Metric):
"""Micro-averaging metric.
Adapted from https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py#L12
"""
def __init__(self, class_idx: typing.Optional[int] = 1,
dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.class_idx = torch.tensor(class_idx) \
if class_idx is not None else None
self.add_state("numerator", default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.add_state("denominator", default=torch.tensor(0.0),
dist_reduce_fx="sum")
def _update(
self, numerator, denominator, preds, y) -> tuple:
raise NotImplementedError
def update(self, logits: torch.Tensor, y: torch.Tensor):
# update metric states
preds = torch.argmax(logits, dim=-1)
y = y.view(-1)
assert preds.shape == y.shape, \
f"preds shape {preds.shape} != y shape {y.shape}"
self.numerator, self.denominator = self._update(
self.numerator, self.denominator, preds, y)
def compute(self):
# compute final result
value = self.numerator.float() / self.denominator \
if self.denominator.item() > 0. else torch.tensor(0.0)
return value
def reset(self):
self.numerator = torch.tensor(0.0).to(self.device)
self.denominator = torch.tensor(0.0).to(self.device)
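# Usage sketch (illustrative): because numerator / denominator are accumulated
# across batches (and reduced with `dist_reduce_fx="sum"` across processes),
# `compute()` weights every sample equally (micro-averaging) rather than
# averaging per-batch values:
#
#   metric = Accuracy(class_idx=None)
#   metric.update(torch.randn(8, 2), torch.randint(0, 2, (8,)))  # batch of 8
#   metric.update(torch.randn(4, 2), torch.randint(0, 2, (4,)))  # batch of 4
#   metric.compute()  # total correct / 12, not the mean of two batch values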
class CrossEntropy(MicroAveragingMetric):
"""Calculates cross-entropy loss."""
def _update(
self, numerator, denominator, logits, y) -> tuple:
with torch.no_grad():
numerator += F.cross_entropy(
logits.view(-1, logits.size(-1)),
y.view(-1),
ignore_index=-100,
reduction='sum')
denominator += y.numel()
return numerator, denominator
# Overrides parent class to use logits and not (argmax) preds
def update(self, logits: torch.Tensor, y: torch.Tensor):
y = y.view(-1)
self.numerator, self.denominator = self._update(
self.numerator, self.denominator, logits, y)
class Accuracy(MicroAveragingMetric):
"""Calculates accuracy.
Can be used to calculate accuracy per class.
Copied from:
https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
"""
def _update(
self, numerator, denominator, preds, y) -> tuple:
if self.class_idx is None:
numerator += (preds == y).sum()
denominator += y.numel()
else:
class_idx = self.class_idx
relevant_idxs = (y == class_idx)
numerator += (preds[relevant_idxs] == class_idx).sum()
denominator += relevant_idxs.sum()
relevant_idxs = (y != class_idx)
numerator += (preds[relevant_idxs] != class_idx).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
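# Worked example for the per-class branch above: `class_idx=c` yields the
# *binary* accuracy of the "class c vs. not class c" decision, so correct
# negatives count too. With y = [c, c, d] and preds = [c, d, d] (c != d),
# positions 1 and 3 are scored correct: numerator = 2, denominator = 3.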
class Precision(MicroAveragingMetric):
"""Calculates precision.
Can be used to calculate precision per class.
Adapted from:
https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
"""
def _update(self, numerator, denominator, preds, y) -> tuple:
class_idx = self.class_idx
relevant_idxs = (preds == class_idx)
numerator += (y[relevant_idxs] == class_idx).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
class Recall(MicroAveragingMetric):
"""Calculate recall.
Can be used to calculate recall per class.
Adapted from:
https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
"""
def _update(self, numerator, denominator, preds, y) -> tuple:
class_idx = self.class_idx
relevant_idxs = (y == class_idx)
numerator += (preds[relevant_idxs] == class_idx).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
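# Contrast between the two metrics above: Precision conditions on the
# *predictions* (of the positions predicted as class c, how many truly are c),
# while Recall conditions on the *labels* (of the positions labeled c, how
# many were predicted as c). With y = [1, 1, 0] and preds = [1, 0, 0]:
#   Precision(class_idx=1): 1 correct / 1 predicted positive -> 1.0
#   Recall(class_idx=1):    1 recovered / 2 true positives   -> 0.5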
class Classifier(L.LightningModule):
def __init__(
self,
config,
tokenizer: transformers.PreTrainedTokenizer,
pretrained_backbone: typing.Optional[torch.nn.Module] = None):
super().__init__()
self.save_hyperparameters(ignore=['pretrained_backbone'])
self.config = config
    # This flag indicates whether the model will be used for
    # guidance (False) or for evaluation only (True).
self.is_eval_classifier = getattr(
config, 'is_eval_classifier', False)
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size
self.antithetic_sampling = config.training.antithetic_sampling
self.importance_sampling = config.training.importance_sampling
self.change_of_variables = config.training.change_of_variables
if (not hasattr(self.tokenizer, 'mask_token')
or self.tokenizer.mask_token is None):
self.mask_index = self.vocab_size
self.vocab_size += 1
else:
self.mask_index = self.tokenizer.mask_token_id
if config.classifier_backbone == 'dit':
self.classifier_model = models.dit.DITClassifier(
self.config, vocab_size=self.vocab_size)
elif self.config.classifier_backbone == 'dimamba':
self.classifier_model = models.dimamba.DiMambaClassifier(
self.config, vocab_size=self.vocab_size,
pad_token_id=self.tokenizer.pad_token_id)
elif config.classifier_backbone == 'hyenadna':
hyena_config = transformers.AutoConfig.from_pretrained(
config.classifier_model.hyena_model_name_or_path,
n_layer=config.classifier_model.n_layer,
trust_remote_code=True
)
self.classifier_model = transformers.AutoModelForSequenceClassification.from_config(
hyena_config,
pretrained=False,
num_labels=config.data.num_classes,
problem_type='single_label_classification',
trust_remote_code=True
)
else:
raise NotImplementedError(
f"Classifier backbone "
f"{self.config.classifier_backbone} not "
f"implemented.")
if pretrained_backbone is not None: # For PPLM / NOS
self.classifier_model.load_pretrained_encoder(
pretrained_backbone)
# Metrics are automatically reset at end of epoch
metrics = torchmetrics.MetricCollection({
'cross_entropy': CrossEntropy(),
'accuracy': Accuracy(class_idx=None),
})
if config.data.num_classes > 2:
for c in range(config.data.num_classes):
metrics.add_metrics(
{f"accuracy_class{c}": Accuracy(class_idx=c),
f"precision_class{c}": Precision(class_idx=c),
f"recall_class{c}": Recall(class_idx=c)})
else:
metrics.add_metrics(
{'precision': Precision(class_idx=1),
'recall': Recall(class_idx=1)})
metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    # `test_metrics` is referenced in `_compute_loss` for the 'test' prefix.
    self.test_metrics = metrics.clone(prefix='test/')
self.T = config.T
self.noise = noise_schedule.get_noise(config,
dtype=self.dtype)
self.sampling_eps = config.training.sampling_eps
self.lr = config.optim.lr
self.time_conditioning = config.time_conditioning
self.fast_forward_epochs = None
self.fast_forward_batches = None
def on_load_checkpoint(self, checkpoint):
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
self.fast_forward_epochs = checkpoint['loops'][
'fit_loop']['epoch_progress']['current']['completed']
self.fast_forward_batches = checkpoint['loops'][
'fit_loop']['epoch_loop.batch_progress'][
'current']['completed']
def on_save_checkpoint(self, checkpoint):
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
# ['epoch_loop.batch_progress']['total']['completed'] is
# 1 iteration behind, so we're using the optimizer's
# progress.
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['total'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total'][
'completed'] * self.trainer.accumulate_grad_batches
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['current'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['current'][
'completed'] * self.trainer.accumulate_grad_batches
# _batches_that_stepped tracks the number of global
# steps, not the number of local steps, so we don't
# multiply with self.trainer.accumulate_grad_batches
# here.
checkpoint['loops']['fit_loop'][
'epoch_loop.state_dict'][
'_batches_that_stepped'] = \
checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total']['completed']
if 'sampler' not in checkpoint.keys():
checkpoint['sampler'] = {}
if hasattr(self.trainer.train_dataloader.sampler,
'state_dict'):
      sampler_state_dict = (
        self.trainer.train_dataloader.sampler.state_dict())
checkpoint['sampler'][
'random_state'] = sampler_state_dict.get(
'random_state', None)
else:
checkpoint['sampler']['random_state'] = None
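  # Bookkeeping sketch: with accumulate_grad_batches = 4 and 10 completed
  # optimizer steps, the restored batch progress becomes 10 * 4 = 40 local
  # batches, while `_batches_that_stepped` stays at the 10 global steps.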
def on_train_start(self):
# Adapted from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
distributed = (
self.trainer._accelerator_connector.use_distributed_sampler
and self.trainer._accelerator_connector.is_distributed)
if distributed:
sampler_cls = dataloader.FaultTolerantDistributedSampler
else:
sampler_cls = dataloader.RandomFaultTolerantSampler
updated_dls = []
for dl in self.trainer.fit_loop._combined_loader.flattened:
if hasattr(dl.sampler, 'shuffle'):
dl_sampler = sampler_cls(
dl.dataset, shuffle=dl.sampler.shuffle)
else:
dl_sampler = sampler_cls(dl.dataset)
if (distributed
and self.fast_forward_epochs is not None
and self.fast_forward_batches is not None):
dl_sampler.load_state_dict({
'epoch': self.fast_forward_epochs,
'counter': (self.fast_forward_batches
* self.config.loader.batch_size)})
updated_dls.append(
torch.utils.data.DataLoader(
dl.dataset,
batch_size=self.config.loader.batch_size,
num_workers=self.config.loader.num_workers,
pin_memory=self.config.loader.pin_memory,
sampler=dl_sampler,
shuffle=False,
persistent_workers=self.config.loader.persistent_workers
))
self.trainer.fit_loop._combined_loader.flattened = updated_dls
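  # Resume arithmetic (sketch): the fault-tolerant sampler's `counter` is a
  # count of consumed *samples*, hence batches * batch_size above. E.g.
  # resuming after 120 completed batches with batch_size = 32 sets
  # counter = 120 * 32 = 3840, and the sampler skips those samples in the
  # current epoch's permutation.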
def forward(self, x, sigma=None, x_emb=None, attention_mask=None):
"""Returns logits.
x_emb can be provided during PPLM / NoS-style guidance
(see: https://arxiv.org/abs/2305.20009).
"""
if self.is_eval_classifier:
logits = self.classifier_model(x)
if hasattr(logits, 'logits'):
logits = logits.logits
else:
sigma = self._process_sigma(sigma) if sigma is not None else sigma
with torch.cuda.amp.autocast(dtype=torch.float32):
logits = self.classifier_model(x, sigma, x_emb=x_emb, attention_mask=attention_mask)
return logits
def get_log_probs(self, x, sigma, x_emb=None):
"""Returns log probabilities.
Use for CBG-style guidance.
"""
if self.is_eval_classifier:
raise NotImplementedError(
        '`get_log_probs` not implemented for classifiers '
'that are meant to be used for evaluation purposes '
'only.')
with torch.cuda.amp.autocast(dtype=torch.float32):
return torch.nn.functional.log_softmax(
self.forward(x, sigma, x_emb=x_emb), dim=-1)
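  # Guidance sketch (hypothetical caller, not part of this class): a
  # classifier-guided sampler would re-weight the diffusion model's proposal
  # with the class-conditional log-probability returned here, e.g.
  #   log_p_y = classifier.get_log_probs(xt_candidates, sigma)
  #   scores = diffusion_log_probs + gamma * log_p_y[:, target_class]
  # where `xt_candidates`, `diffusion_log_probs`, `gamma` (guidance strength)
  # and `target_class` are assumed, illustrative names.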
def training_step(self, batch, batch_idx):
loss = self._compute_loss(batch, prefix='train')
self.log(name='trainer/loss',
value=loss.item(),
on_step=True,
on_epoch=False,
sync_dist=True,
prog_bar=True)
    self.log(name='lr',
             value=self.trainer.optimizers[0].param_groups[0]['lr'],
             on_step=True,
             on_epoch=False,
             sync_dist=True,
             prog_bar=True,
             logger=False)
return loss
def validation_step(self, batch, batch_idx):
return self._compute_loss(batch, prefix='val')
def configure_optimizers(self):
# TODO(yair): Lightning currently giving this warning when using `fp16`:
# "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
# Not clear if this is a problem or not.
# See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
optimizer = torch.optim.AdamW(
itertools.chain(self.classifier_model.parameters(),
self.noise.parameters()),
lr=self.config.optim.lr,
betas=(self.config.optim.beta1,
self.config.optim.beta2),
eps=self.config.optim.eps,
weight_decay=self.config.optim.weight_decay)
scheduler = hydra.utils.instantiate(
self.config.lr_scheduler, optimizer=optimizer)
scheduler_dict = {
'scheduler': scheduler,
'interval': 'step',
'monitor': 'val/loss',
'name': 'trainer/lr',
}
return [optimizer], [scheduler_dict]
def _q_xt(self, x, move_chance):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
move_chance: float torch.Tensor with shape
(batch_size, 1).
"""
move_indices = torch.rand(
*x.shape, device=x.device) < move_chance
if self.config.diffusion == 'absorbing_state':
return torch.where(move_indices, self.mask_index, x)
if self.config.diffusion == 'uniform':
uniform_tensor = torch.randint(
0, self.vocab_size, x.shape, device=x.device)
return torch.where(move_indices, uniform_tensor, x)
raise NotImplementedError(
f'Diffusion type {self.config.diffusion} not '
'implemented.')
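  # Corruption sketch: with mask_index = M and move_chance = 0.5, an input
  # x = [5, 2, 9, 7] might become
  #   absorbing_state: [5, M, 9, M]  (moved tokens replaced by the mask)
  #   uniform:         [5, 4, 9, 0]  (moved tokens resampled from the vocab)
  # Only positions whose uniform draw falls below move_chance are altered.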
def _compute_loss(self, batch, prefix):
x0 = batch['input_ids']
attention_mask = batch['attention_mask']
t = None
if self.is_eval_classifier:
logits = self.forward(x0)
elif self.config.parameterization == 'ar':
# do not add noise for AR FUDGE and AR PPLM
logits = self.forward(
x0, attention_mask=attention_mask)
else:
t = self._sample_t(x0.shape[0])
if self.T > 0:
t = (t * self.T).to(torch.int)
t = t / self.T
# t \in {1/T, 2/T, ..., 1}
t += (1 / self.T)
if self.change_of_variables:
time_conditioning = t[:, None]
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
move_chance = torch.exp(f_0 + t * (f_T - f_0))
move_chance = move_chance[:, None]
else:
sigma, _ = self.noise(t)
time_conditioning = sigma[:, None]
move_chance = 1 - torch.exp(-sigma[:, None])
xt = self._q_xt(x0, move_chance)
logits = self.forward(xt, time_conditioning, attention_mask=attention_mask)
if hasattr(self.config.data, 'label_col'):
if f"{self.config.data.label_col}_threshold" in batch:
y = batch[f"{self.config.data.label_col}_threshold"]
else:
y = batch[self.config.data.label_col]
else:
y = batch['label']
    if (not self.is_eval_classifier
        and t is not None  # t is undefined for eval / AR classifiers
        and getattr(self.config.training, 'use_label_smoothing', False)):
      # Interpolate between the one-hot and uniform distributions,
      # smoothing more aggressively at higher noise levels t.
      labels = (torch.nn.functional.one_hot(y, self.config.data.num_classes) * (1 - t)[..., None]
                + (1 / self.config.data.num_classes) * t[..., None])
else:
labels = y.view(-1)
    if getattr(self.config, 'is_fudge_classifier', False):
      # FUDGE scores every prefix: expand the sequence-level label to each
      # token position, then keep only positions under the attention mask.
      expanded_y = y.unsqueeze(1).expand(-1, logits.shape[1])  # batch x seq
      logits = logits.view(-1, self.config.data.num_classes)[
        attention_mask.flatten() == 1, ...]
      y = expanded_y.flatten().long()[attention_mask.flatten() == 1]
loss = torch.nn.functional.cross_entropy(
logits,
y,
ignore_index=-100,
reduction='mean')
else:
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels,
ignore_index=-100,
reduction='mean')
if prefix == 'train':
self.train_metrics.update(logits, y)
metrics = self.train_metrics
elif prefix == 'val':
self.valid_metrics.update(logits, y)
metrics = self.valid_metrics
elif prefix == 'test':
self.test_metrics.update(logits, y)
metrics = self.test_metrics
else:
raise ValueError(f'Invalid prefix: {prefix}')
self.log_dict(metrics,
on_step=False,
on_epoch=True,
sync_dist=True)
return loss
def _sample_t(self, n):
_eps_t = torch.rand(n, device=self.device)
if self.antithetic_sampling:
offset = torch.arange(n, device=self.device) / n
_eps_t = (_eps_t / n + offset) % 1
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
if self.importance_sampling:
return self.noise.importance_sampling_transformation(
t)
return t
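  # Antithetic sampling sketch: the offsets i/n place one jittered sample in
  # each of the n equal sub-intervals of [0, 1), e.g. for n = 4 the timesteps
  # land in [0, .25), [.25, .5), [.5, .75), [.75, 1) rather than clustering
  # by chance, which lowers the variance of the loss across noise levels.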
def _process_sigma(self, sigma):
if sigma.ndim > 1:
sigma = sigma.squeeze(-1)
if not self.time_conditioning:
sigma = torch.zeros_like(sigma)
assert sigma.ndim == 1, sigma.shape
return sigma
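  # Shape note: `sigma` arrives from `_compute_loss` as (batch_size, 1)
  # (built via `sigma[:, None]`) and is squeezed back to (batch_size,).
  # With `time_conditioning` disabled, the zeroed sigma makes the classifier
  # effectively time-independent while keeping the same forward signature.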