import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler
from safetensors.torch import save_file
from data_loader import DUTSDataset, MSRADataset
from model import U2Net
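
# Train on GPU when available; GradScaler implements loss scaling for float16
# autocast (it skips optimizer steps whose gradients overflowed and adapts the scale).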
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()
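
# Soft Dice loss over flattened sigmoid probabilities:
#   dice = (2 * |P ∩ T| + smooth) / (|P| + |T| + smooth)
# The smooth term keeps the ratio defined when both prediction and target are empty.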
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
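
# Standard AMP training step: forward pass and loss under autocast, then scaled
# backward/step/update. U2Net returns several maps (the fused output plus side
# outputs), so the criterion is summed over all of them for deep supervision.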
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)
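
# Validation runs in full precision under torch.no_grad(); the same summed
# multi-output loss is averaged over the loader.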
def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validating', leave=False):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
    avg_loss = running_loss / len(loader)
    return avg_loss
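
# Persist the final weights as safetensors and the loss history as a pickle
# (binary, despite the .txt extension).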
def save(model, model_name, losses):
    # unwrap DataParallel so checkpoint keys are saved without the 'module.' prefix
    state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    save_file(state_dict, f'results/{model_name}.safetensors')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
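
# Reloading a checkpoint later (sketch, assuming the plain U2Net constructor):
#   from safetensors.torch import load_file
#   net = U2Net()
#   net.load_state_dict(load_file('results/u2net-duts-msra.safetensors'))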
if __name__ == '__main__':
    batch_size = 40
    valid_batch_size = 80
    epochs = 200
    lr = 1e-3
    loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
    loss_fn_dice = DiceLoss()
    alpha = 0.6
    # weighted combination of BCE (pixel-wise) and Dice (region overlap)
    loss_fn = lambda o, m: alpha * loss_fn_bce(o, m) + (1 - alpha) * loss_fn_dice(o, m)
    model_name = 'u2net-duts-msra'
    model = U2Net()
    model = torch.nn.parallel.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    os.makedirs('results', exist_ok=True)  # checkpoints and loss history land here
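
    # DUTS and MSRA are concatenated for both splits; pinned memory and
    # persistent workers keep host-to-GPU transfer from stalling the loop.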
    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=8, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=8, persistent_workers=True
    )
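
    # track the best validation loss so the strongest epoch gets its own checkpoint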
    best_val_loss = float('inf')
    losses = {'train': [], 'val': []}
    # training loop
    try:
        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
            val_loss = validate(model, valid_loader, loss_fn)
            losses['train'].append(train_loss)
            losses['val'].append(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # unwrap DataParallel here as well, matching save()
                save_file(model.module.state_dict(), f'results/best-{model_name}.safetensors')
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
    finally:
        # always write the final weights and loss history, even on interruption
        save(model, model_name, losses)