import torch
from tqdm import tqdm

# Wrap the dataloader in a tqdm progress bar that reports the running loss.
iterator = tqdm(dataloader, desc="Training", postfix={"train_loss": 0.0})

for item in iterator:
    # Surround each raw text sample with the tokenizer's BOS/EOS tokens,
    # encode it, and move the token ids to the GPU.
    item = tokenizer.bos_token + " " + item[0] + " " + tokenizer.eos_token
    encoded_inp = tokenizer(item, return_tensors='pt').input_ids.to("cuda")

    # Forward pass; logits is expected to have shape (batch, seq_len, vocab_size).
    logits = mamba_model(encoded_inp)

    # Causal-LM objective: the logits at position t are scored against the
    # token at position t + 1, so drop the last logit and the first label.
    labels = encoded_inp.to(logits.device)
    shift_logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

    # Backpropagate and update the parameters.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Detach the step's tensors and move them to the CPU so the GPU memory
    # they reference can be released before the next iteration.
    loss = loss.detach().cpu().numpy()
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    encoded_inp = encoded_inp.detach().cpu().numpy()
    shift_logits = shift_logits.detach().cpu().numpy()

    # Display the latest loss on the progress bar.
    iterator.set_postfix({"train_loss": loss.item()}, refresh=False)
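
# --- Hedged setup sketch (not part of the original loop) --------------------
# The loop above assumes that `dataloader`, `tokenizer`, `mamba_model`, and
# `optimizer` already exist. If they are not defined elsewhere, a minimal
# stand-in setup could look like the following. The tokenizer name, toy
# dataset, learning rate, and TinyLM class are illustrative assumptions; the
# real `mamba_model` is whatever Mamba implementation is being trained, and
# the loop only relies on its forward call returning raw logits of shape
# (batch, seq_len, vocab_size).
from torch.utils.data import DataLoader
from transformers import AutoTokenizer  # assumed dependency

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer

# Toy text dataset; batch_size=1 matches the loop's use of item[0] as one string.
texts = ["a short example sentence", "another short example sentence"]
dataloader = DataLoader(texts, batch_size=1, shuffle=True)

# Stand-in language model: embedding plus linear head, returning raw logits.
class TinyLM(torch.nn.Module):
    def __init__(self, vocab_size, d_model=64):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, d_model)
        self.head = torch.nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq_len, vocab_size)

mamba_model = TinyLM(len(tokenizer)).to("cuda")
optimizer = torch.optim.AdamW(mamba_model.parameters(), lr=1e-4)  # assumed learning rate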