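# Trains U2Net for salient-object detection on the concatenated DUTS and MSRA
# datasets with AMP mixed precision, AdamW, a weighted BCE + Dice loss summed
# over all side outputs, and safetensors checkpointing of the best model.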
import os
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler
from safetensors.torch import save_file

from data_loader import DUTSDataset, MSRADataset
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()  # AMP gradient scaler; disables itself when CUDA is unavailable


class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities; `smooth` avoids division
    by zero when both prediction and target are empty."""

    def forward(self, inputs, targets, smooth=1.):
        inputs = torch.sigmoid(inputs).view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
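
# Sanity check with hypothetical shapes: strong logits for an all-ones mask
# give a near-zero loss, e.g.
#   DiceLoss()(torch.full((1, 1, 8, 8), 10.), torch.ones(1, 1, 8, 8))  # ~5e-5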


def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type='cuda'):  # mixed-precision forward pass
            outputs = model(images)
            # deep supervision: sum the loss over every side output
            loss = sum(criterion(output, masks) for output in outputs)
        # scale the loss to avoid fp16 gradient underflow, then unscale and step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)

def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validating', leave=False):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
    return running_loss / len(loader)

def save(model, model_name, losses):
    # unwrap DataParallel so checkpoint keys are not prefixed with 'module.'
    state_dict = getattr(model, 'module', model).state_dict()
    save_file(state_dict, f'results/{model_name}.safetensors')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)  # binary pickle, despite the .txt extension
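
# Loading a checkpoint later (sketch, assuming an unwrapped U2Net):
#   from safetensors.torch import load_file
#   net = U2Net()
#   net.load_state_dict(load_file('results/u2net-duts-msra.safetensors'))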


if __name__ == '__main__':
    batch_size = 40
    valid_batch_size = 80  # eval stores no gradients, so a larger batch fits in memory
    epochs = 200
    
    lr = 1e-3
    loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
    loss_fn_dice = DiceLoss()
    alpha = 0.6  # mixing weight: alpha * BCE + (1 - alpha) * Dice
    loss_fn = lambda o, m: alpha * loss_fn_bce(o, m) + (1 - alpha) * loss_fn_dice(o, m)
    
    model_name = 'u2net-duts-msra'
    model = U2Net()
    model = nn.DataParallel(model.to(device))  # split each batch across available GPUs
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]), 
        batch_size=batch_size, shuffle=True, pin_memory=True, 
        num_workers=8, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]), 
        batch_size=valid_batch_size, shuffle=False, pin_memory=True, 
        num_workers=8, persistent_workers=True
    )
    
    best_val_loss = float('inf')
    losses = {'train': [], 'val': []}
    os.makedirs('results', exist_ok=True)  # make sure the output directory exists

    # training loop; the finally block saves the latest weights even if training is interrupted
    try:
        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
            val_loss = validate(model, valid_loader, loss_fn)
            losses['train'].append(train_loss)
            losses['val'].append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # checkpoint the best weights with DataParallel's 'module.' prefix stripped
                save_file(getattr(model, 'module', model).state_dict(), f'results/best-{model_name}.safetensors')
                
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
    finally:
        save(model, model_name, losses)