|
import torch |
|
import torch.nn as nn |
|
from tqdm import tqdm |
|
from utils import categorical_accuracy |
|
|
|
|
|
def loss_fn(outputs, targets):
    """Return the mean cross-entropy between raw logits and class indices.

    Uses the functional form so no loss module is instantiated per call; the
    defaults (mean reduction, no weights, no label smoothing) match
    ``nn.CrossEntropyLoss()``.
    """
    return nn.functional.cross_entropy(outputs, targets)
|
|
|
|
|
def train_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one training epoch and return the per-batch-averaged loss and accuracy.

    Args:
        data_loader: Iterable of mapping batches with "ids", "token_type_ids",
            "mask" and "targets" entries; must support ``len()`` (used for the
            progress bar total).
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``.
        optimizer: Optimizer over ``model``'s parameters; zeroed and stepped
            once per batch.
        device: Device every batch tensor is moved to.
        scheduler: Optional learning-rate scheduler stepped once per batch
            (not per epoch). ``None`` skips scheduler stepping.

    Returns:
        ``(train_loss, train_acc)``: the sum of per-batch values divided by the
        number of batches processed, so every batch is weighted equally.
        Both are ``nan`` when the loader yields no batches.
    """
    model.train()

    train_loss, train_acc = 0.0, 0.0
    num_batches = 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        ids = batch["ids"].to(device, dtype=torch.long)
        token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
        mask = batch["mask"].to(device, dtype=torch.long)
        targets = batch["targets"].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        # Per-batch scheduling, as in the original loop.
        if scheduler is not None:
            scheduler.step()

        train_loss += loss.item()
        # categorical_accuracy comes from utils; presumably it returns a
        # scalar tensor for the batch (``.item()`` is called on it).
        train_acc += categorical_accuracy(outputs, targets).item()
        num_batches += 1

    if num_batches == 0:
        # An empty epoch has no defined mean; avoid a bare ZeroDivisionError.
        return float("nan"), float("nan")
    return train_loss / num_batches, train_acc / num_batches
|
|
|
|
|
def eval_fn(data_loader, model, device):
    """Evaluate ``model`` without gradients and collect predictions with targets.

    Args:
        data_loader: Iterable of mapping batches with "ids", "token_type_ids",
            "mask" and "targets" entries; must support ``len()`` (used for the
            progress bar total).
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``.
        device: Device every batch tensor is moved to.

    Returns:
        ``(fin_outputs, fin_targets, eval_loss, eval_acc)``:
        ``fin_outputs`` is the flat list of argmax-over-dim-1 predictions,
        ``fin_targets`` the matching target values, and the loss/accuracy are
        the sum of per-batch values divided by the number of batches.
        Loss and accuracy are ``nan`` when the loader yields no batches.
    """
    model.eval()

    eval_loss, eval_acc = 0.0, 0.0
    num_batches = 0
    fin_targets = []
    fin_outputs = []

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            ids = batch["ids"].to(device, dtype=torch.long)
            token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
            mask = batch["mask"].to(device, dtype=torch.long)
            targets = batch["targets"].to(device, dtype=torch.long)

            outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

            loss = loss_fn(outputs, targets)
            eval_loss += loss.item()
            # categorical_accuracy comes from utils; presumably it returns a
            # scalar tensor for the batch (``.item()`` is called on it).
            eval_acc += categorical_accuracy(outputs, targets).item()
            num_batches += 1

            # Predicted class per row, computed once and reused below.
            pred_labels = torch.argmax(outputs, dim=1)
            fin_targets.extend(targets.cpu().tolist())
            fin_outputs.extend(pred_labels.cpu().tolist())

    if num_batches == 0:
        # An empty loader has no defined mean; avoid a bare ZeroDivisionError.
        return fin_outputs, fin_targets, float("nan"), float("nan")
    return fin_outputs, fin_targets, eval_loss / num_batches, eval_acc / num_batches
|
|
|
|
|
|
|
def predict_fn(data_loader, model, device, extract_features=False):
    """Predict a class index per row, optionally gathering extracted features.

    Args:
        data_loader: Iterable of mapping batches with "ids", "token_type_ids"
            and "mask" entries (no "targets" are read); must support ``len()``
            for the progress bar total.
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``.
            When ``extract_features`` is true it must also expose an
            ``extract_features`` method taking the same keywords.
        device: Device the input tensors are moved to.
        extract_features: When true, also collect ``model.extract_features``
            output for every batch.

    Returns:
        ``(predictions, features)``: argmax over dim 1 of the model output as
        a flat list, and the extracted features converted with ``.tolist()``
        (empty when ``extract_features`` is false).
    """
    model.eval()

    predictions = []
    features = []

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            inputs = {
                "ids": batch["ids"].to(device, dtype=torch.long),
                "token_type_ids": batch["token_type_ids"].to(
                    device, dtype=torch.long
                ),
                "mask": batch["mask"].to(device, dtype=torch.long),
            }

            logits = model(**inputs)

            if extract_features:
                # NOTE(review): second model call on the same inputs; whether
                # it repeats the forward computation depends on the model.
                batch_features = model.extract_features(**inputs)
                features.extend(batch_features.cpu().detach().numpy().tolist())

            predicted = torch.argmax(logits, dim=1)
            predictions.extend(predicted.cpu().detach().numpy().tolist())

    return predictions, features
|
|
|
|