|
import sys |
|
sys.path.insert(0, './pytorch-image-models-main') |
|
|
|
|
|
from moe import Moe,all_loss |
|
|
|
|
|
import os |
|
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7" |
|
|
|
import torch |
|
import cv2 |
|
from albumentations.pytorch import ToTensorV2 |
|
from albumentations import ( |
|
HorizontalFlip, VerticalFlip, ShiftScaleRotate, CLAHE, RandomRotate90, |
|
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, |
|
GaussNoise, MotionBlur, MedianBlur, PiecewiseAffine, RandomResizedCrop, |
|
RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, CoarseDropout, |
|
ShiftScaleRotate, CenterCrop, Resize, SmallestMaxSize |
|
) |
|
import time |
|
|
|
import torch.multiprocessing as mp |
|
import torch.distributed as dist |
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts |
|
from torch.cuda.amp import autocast, GradScaler |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import Adam, SGD, AdamW, RMSprop |
|
from torch import nn |
|
import random |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import numpy as np |
|
import logging |
|
from sklearn.model_selection import GroupKFold, StratifiedKFold |
|
import pandas as pd |
|
import math |
|
|
|
# Global experiment configuration, read by the data pipeline, the model
# construction and the training loop further down in this file.
CFG = {
    'seed': 42,  # passed to seed_everything() at start-up
    'model_arch': 'convnext_large_mlp',  # handed to Moe(); also part of the log and checkpoint file names

    'patch': 16,  # side length in pixels of one square patch for the 'tokenmix' augmentation

    # Per-channel statistics given to the albumentations Normalize transform.
    'mean':[0.485, 0.456, 0.406] ,
    'std':[0.229, 0.224, 0.225],

    'mix_type': 'cutmix',  # default mix mode of train_one_epoch_mix ('cutmix', 'tokenmix', 'mixup' or 'randommix')
    'mix_prob': 0.7,  # a training batch is mixed when a uniform draw falls below this value

    'img_size': 512,  # output height and width of the crop/resize transforms

    'class_num': 1784,  # passed to Moe() and all_loss(); also the length of venomous_mask

    'warmup_epochs': 1,  # LinearLR total_iters and the SequentialLR milestone
    'warmup_lr_factor': 0.01,  # LinearLR start_factor
    'epochs': 11,  # number of training epochs; cosine T_max is epochs - warmup_epochs

    'train_bs': 24,  # training DataLoader batch size
    'valid_bs': 64,  # validation DataLoader batch size

    'lr': 7.5e-5,  # AdamW learning rate (backbone group when differLR is enabled)
    'min_lr': 1e-5,  # eta_min of the cosine annealing schedule

    'differLR': False,  # when True, build explicit parameter groups instead of optimising model.parameters()

    'head_lr': 0,  # learning rate of the non-backbone parameters; only read when differLR is True (must be > 0 to include them)
    'head_wd': 0.05,  # weight decay of the non-backbone parameter group
    'num_workers': 8,  # worker processes for both DataLoaders
    'device': 'cuda',  # argument of torch.device() for model and batches
    'smoothing': 0.1,  # label_smoothing of the CrossEntropyLoss used during validation

    'weight_decay': 2e-5,  # AdamW weight decay
    'accum_iter': 1,  # optimizer step is taken every `accum_iter` batches (and on the last batch)
    'verbose_step': 1,  # progress-bar description is refreshed every `verbose_step` batches

}
|
|
|
# Module-level file logger used to record the configuration and the
# per-epoch training/validation metrics.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

# logging.FileHandler opens the file immediately and does not create missing
# parent directories, so make sure the log directory exists first.
os.makedirs('logs', exist_ok=True)
handler = logging.FileHandler(f"logs/{CFG['model_arch']}_train_moe.log")
handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
|
|
|
|
|
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic
    behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # The script trains on several GPUs; seed every device explicitly.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run, which makes
    # results vary between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
def get_img(path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        path: File system path of the image.

    Returns:
        The image as loaded by ``cv2.imread`` with the last axis reversed
        (BGR -> RGB). The result is a view of the decoded array.

    Raises:
        FileNotFoundError: If OpenCV could not read the file (missing path
            or undecodable content), instead of failing later with an
            opaque ``TypeError`` on ``None``.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f'Could not read image (missing or unreadable): {path}')
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb
|
|
|
|
|
# Image roots; FGVCDataset prepends them to each row's 'image_path'.
train_data_root = '/data1/dataset/SnakeCLEF2024/'
val_data_root = '/data1/dataset/SnakeCLEF2023/val/SnakeCLEF2023-large_size/'

# Metadata tables for the training and validation splits.
train_df = pd.read_csv('./metadata/train_full.csv')
valid_df = pd.read_csv('./metadata/SnakeCLEF2023-ValMetadata.csv')

# Per-class venomous status ('MIVS' column), turned into
#   * venomous_mask: tensor of length class_num, defaulting to 1 for classes
#     that are absent from the list (last row wins on duplicates), and
#   * class_id2venomous: class_id -> status (first row wins on duplicates).
is_venomous_df = pd.read_csv('./metadata/venomous_status_list.csv')
class_id2venomous = {}
venomous_mask = torch.ones(CFG['class_num'])
for class_id, is_venomous in zip(is_venomous_df['class_id'], is_venomous_df['MIVS']):
    venomous_mask[class_id] = is_venomous
    class_id2venomous.setdefault(class_id, is_venomous)

# Attach the venomous status to every sample via its class id.
train_df['MIVS'] = train_df['class_id'].map(class_id2venomous)
valid_df['MIVS'] = valid_df['class_id'].map(class_id2venomous)
|
|
|
class FGVCDataset(Dataset):
    """Image dataset backed by a metadata DataFrame.

    Each row must provide an ``image_path`` column (relative to
    ``data_root``); when labels are requested the frame must also provide
    ``class_id`` and ``MIVS`` columns.

    Args:
        df: Metadata table; a copy with a fresh 0..n-1 index is stored.
        data_root: String prefix prepended to every ``image_path``.
        transforms: Optional callable invoked as ``transforms(image=img)``
            whose result's ``'image'`` entry replaces the image.
        output_label: When truthy, ``__getitem__`` returns
            ``(img, class_id, MIVS)``; otherwise only ``img``.
        one_hot_label: When truthy (and labels are enabled), class ids are
            replaced by one-hot rows of width ``max(class_id) + 1``.
    """

    def __init__(self, df, data_root,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False
                 ):

        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root

        self.output_label = output_label
        self.one_hot_label = one_hot_label

        # Flags are tested by truthiness everywhere in the class so that
        # __init__ and __getitem__ can never disagree on whether labels exist.
        if output_label:
            self.labels = self.df['class_id'].values
            self.is_venomous = self.df['MIVS']
            if one_hot_label:
                self.labels = np.eye(self.df['class_id'].max() + 1)[self.labels]

    def __len__(self):
        """Return the number of samples (rows of the metadata table)."""
        return self.df.shape[0]

    def __getitem__(self, index: int):
        """Load, transform and return the sample at ``index``."""
        # Labels are fetched first so an out-of-range index raises IndexError
        # before any image I/O happens.
        if self.output_label:
            target = self.labels[index]
            venomous = self.is_venomous[index]

        # Scalar lookup; avoids materialising the whole row as a Series.
        image_path = self.data_root + self.df.at[index, 'image_path']

        img = get_img(image_path)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label:
            return img, target, venomous
        return img
|
|
|
|
|
def get_train_transforms():
    """Build the albumentations pipeline applied to training images.

    The pipeline crops/resizes to ``CFG['img_size']``, applies random
    geometric and brightness/contrast augmentations, normalises with
    ``CFG['mean']`` / ``CFG['std']`` and converts the result to a tensor.
    """
    side = CFG['img_size']
    pipeline = [
        # NOTE(review): the upper scale bound is above 1.0 — confirm this is intended.
        RandomResizedCrop(side, side, interpolation=cv2.INTER_CUBIC, scale=(0.5, 1.3)),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.3),
        PiecewiseAffine(p=0.5),
        RandomBrightnessContrast(
            brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=1.0),
        # Half of the time, apply exactly one of the two lens/grid distortions.
        OneOf([
            OpticalDistortion(distort_limit=1.0),
            GridDistortion(num_steps=5, distort_limit=1.),
        ], p=0.5),
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(pipeline, p=1.)
|
|
|
|
|
|
|
def get_valid_transforms():
    """Build the deterministic albumentations pipeline for validation images.

    Images are resized to ``CFG['img_size']`` x ``CFG['img_size']``,
    normalised with ``CFG['mean']`` / ``CFG['std']`` and converted to tensors.
    """
    side = CFG['img_size']
    steps = [
        Resize(side, side, interpolation=cv2.INTER_CUBIC),
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
|
|
|
|
|
def prepare_dataloader(train_df, val_df, train_idx, val_idx):
    """Create the training and validation DataLoaders.

    Args:
        train_df: Metadata table of the training split.
        val_df: Metadata table of the validation split.
        train_idx: Index labels selecting the training rows to use.
        val_idx: Index labels selecting the validation rows to use.

    Returns:
        ``(train_loader, val_loader)``; the training loader is shuffled, the
        validation loader is not.
    """
    train_rows = train_df.loc[train_idx, :].reset_index(drop=True)
    valid_rows = val_df.loc[val_idx, :].reset_index(drop=True)

    train_ds = FGVCDataset(train_rows, train_data_root, transforms=get_train_transforms())
    valid_ds = FGVCDataset(valid_rows, val_data_root, transforms=get_valid_transforms())

    # Settings shared by both loaders.
    common = dict(pin_memory=False, num_workers=CFG['num_workers'])

    train_loader = DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        drop_last=False,
        shuffle=True,
        **common,
    )
    val_loader = DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        shuffle=False,
        **common,
    )
    return train_loader, val_loader
|
|
|
def rand_bbox(size, lam):
    """Sample a random CutMix box for a batch of the given size.

    The box covers a ``sqrt(1 - lam)`` fraction of each of the two trailing
    axes, is centred on a uniformly drawn point and clipped to the image.

    Args:
        size: Indexable batch size; ``size[2]`` and ``size[3]`` are the
            extents of the two axes the box is defined on.
        lam: Mixing coefficient; the unclipped box area is ``1 - lam`` of
            the image.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` where the ``x`` bounds refer to axis 2
        and the ``y`` bounds to axis 3.
    """
    W, H = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_w = np.int32(W * side_ratio) // 2
    half_h = np.int32(H * side_ratio) // 2

    # Box centre; x is drawn before y so the RNG stream is consumed in a
    # fixed order.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1, bbx2 = np.clip(cx - half_w, 0, W), np.clip(cx + half_w, 0, W)
    bby1, bby2 = np.clip(cy - half_h, 0, H), np.clip(cy + half_h, 0, H)
    return bbx1, bby1, bbx2, bby2
|
|
|
|
|
def generate_mask_random(imgs, patch=CFG['patch'], mask_token_num_start=14, lam=0.5):
    """Choose which patches of a patch grid get replaced for TokenMix.

    The grid has ``(W // patch) ** 2`` cells, i.e. it is computed from the
    first spatial axis only.

    Args:
        imgs: 4-D batch; only its shape is read, and both trailing axes must
            be divisible by ``patch``.
        patch: Side length of one square patch.
        mask_token_num_start: Number of patches masked in addition to the
            ``(1 - lam)`` share.
        lam: Requested mixing coefficient.

    Returns:
        ``(mask_idx, lam)``: the flat indices of the selected patches and the
        coefficient recomputed from the number of patches actually selected.
    """
    _batch, _channels, W, H = imgs.shape
    assert W % patch == 0
    assert H % patch == 0

    total_patches = (W // patch) ** 2
    requested = int((1 - lam) * total_patches) + mask_token_num_start
    num_masking_patches = min(total_patches, requested)

    mask_idx = np.random.permutation(total_patches)[:num_masking_patches]
    return mask_idx, 1 - num_masking_patches / total_patches
|
|
|
|
|
def get_mixed_data(imgs, image_labels, is_venomous,mix_type):
    """Mix a batch with a shuffled copy of itself.

    Args:
        imgs: 4-D image batch. For ``'cutmix'`` and ``'tokenmix'`` it is
            modified in place; ``'mixup'`` returns a new tensor.
        image_labels: Per-sample class labels, indexable along the batch axis.
        is_venomous: Per-sample venomous flags, indexable along the batch axis.
        mix_type: One of ``'cutmix'``, ``'tokenmix'``, ``'mixup'`` or
            ``'randommix'`` (which picks ``'cutmix'`` or ``'tokenmix'`` at
            random).

    Returns:
        ``(imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam)``
        where the ``_a`` values belong to the original samples, the ``_b``
        values to the shuffled partners and ``lam`` is the weight of the
        ``_a`` part.

    Raises:
        AssertionError: If ``mix_type`` is not supported.
    """
    mix_lst = ['cutmix', 'tokenmix', 'mixup', 'randommix']
    assert mix_type in mix_lst, f'Not Supported mix type: {mix_type}'
    if mix_type == 'randommix':
        # Only the two region-based mixes are candidates here.
        mix_type = random.choice(mix_lst[:-2])

    # Partner permutation, created on the batch's own device instead of a
    # hard-coded .cuda() so the function also works for CPU batches.
    rand_index = torch.randperm(imgs.size()[0]).to(imgs.device)

    # Targets are needed by every mix type. Previously the venomous pair was
    # only assigned in the cutmix branch, so 'mixup' and 'tokenmix' crashed
    # with UnboundLocalError at the return statement.
    target_a = image_labels
    target_b = image_labels[rand_index]
    is_venomous_a = is_venomous
    is_venomous_b = is_venomous[rand_index]

    if mix_type == 'mixup':
        alpha = 2.0
        lam = np.random.beta(alpha, alpha)
        imgs = imgs * lam + imgs[rand_index] * (1 - lam)
    elif mix_type == 'cutmix':
        beta = 1.0
        lam = np.random.beta(beta, beta)
        bbx1, bby1, bbx2, bby2 = rand_bbox(imgs.size(), lam)
        imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # Recompute lam from the clipped box so it matches the pasted area.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (imgs.size()[-1] * imgs.size()[-2]))
    elif mix_type == 'tokenmix':
        B, C, W, H = imgs.shape
        mask_idx, lam = generate_mask_random(imgs)
        p = W // CFG['patch']
        patch_w = CFG['patch']
        patch_h = CFG['patch']
        for idx in mask_idx:
            # Flat patch index -> top-left corner of the patch.
            x1 = patch_w * (idx // p)
            x2 = x1 + patch_w
            y1 = patch_h * (idx % p)
            y2 = y1 + patch_h
            imgs[:, :, x1:x2, y1:y2] = imgs[rand_index, :, x1:x2, y1:y2]

    return imgs, target_a, target_b, is_venomous_a,is_venomous_b,lam
|
|
|
|
|
def train_one_epoch_mix(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False, mix_type=CFG['mix_type']):
    """Train ``model`` for one epoch with optional batch mixing and AMP.

    Relies on the module-level ``scaler`` (a ``GradScaler``) that the
    ``__main__`` section creates before calling this function.

    Args:
        epoch: Epoch number, used for progress and log messages only.
        model: Called as ``model(imgs)``; must return
            ``(y_hat, expert_pred, alpha, image_preds)``.
        loss_fn: Called as ``loss_fn(y_hat, expert_pred, alpha, image_preds,
            labels, is_venomous)``.
        optimizer: Optimizer stepped through the gradient scaler.
        train_loader: Yields ``(imgs, image_labels, is_venomous)`` batches.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler.
        schd_batch_update: Step the scheduler after every optimizer step
            instead of once at the end of the epoch.
        mix_type: Mix mode forwarded to ``get_mixed_data``.
    """
    model.train()

    running_loss = None
    image_preds_all = []
    image_targets_all = []

    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_batches,ncols=70)
    for step, (imgs, image_labels,is_venomous) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        is_venomous = is_venomous.to(device).float()

        # Decide per batch whether to mix; the draw happens before any RNG
        # use inside get_mixed_data.
        apply_mix = bool(np.random.rand(1) < CFG['mix_prob'])
        if apply_mix:
            imgs, target_a, target_b,is_venomous_a,is_venomous_b ,lam = get_mixed_data(imgs, image_labels, is_venomous,mix_type)

        with autocast():
            y_hat,expert_pred,alpha,image_preds = model(imgs)
            if apply_mix:
                # Mixed batch: weight the loss of both label sets by lam.
                loss = loss_fn(y_hat,expert_pred,alpha,image_preds,target_a,is_venomous_a)*lam+loss_fn(y_hat,expert_pred,alpha,image_preds,target_b,is_venomous_b)*(1.0-lam)
            else:
                loss = loss_fn(y_hat,expert_pred,alpha,image_preds,image_labels,is_venomous)
        scaler.scale(loss).backward()

        # Accuracy is always tracked against the un-mixed labels.
        image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
        image_targets_all.append(image_labels.detach().cpu().numpy())

        # Exponential moving average of the batch loss for display/logging.
        batch_loss = loss.item()
        if running_loss is None:
            running_loss = batch_loss
        else:
            running_loss = running_loss * .99 + batch_loss * .01

        is_last_batch = (step + 1) == num_batches

        if ((step + 1) % CFG['accum_iter'] == 0) or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if ((step + 1) % CFG['verbose_step'] == 0) or is_last_batch:
            pbar.set_description(f'epoch {epoch} loss: {running_loss:.4f}')

    all_preds = np.concatenate(image_preds_all)
    all_targets = np.concatenate(image_targets_all)
    accuracy = (all_preds == all_targets).mean()

    print('Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train loss = {:.4f}'.format(running_loss))

    # Epoch-level scheduling when the scheduler is not stepped per batch.
    if scheduler is not None and not schd_batch_update:
        scheduler.step()
|
|
|
|
|
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on the validation loader and return its accuracy.

    Gradient tracking is not disabled here; the caller is expected to wrap
    the call in ``torch.no_grad()``.

    Args:
        epoch: Epoch number, used for progress and log messages only.
        model: Called as ``model(imgs)``; must return
            ``(y_hat, expert_pred, alpha, image_preds)``.
        loss_fn: Called as ``loss_fn(image_preds, image_labels)``.
        val_loader: Yields ``(imgs, image_labels, is_venomous)`` batches.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler stepped once after evaluation.
        schd_loss_update: Pass the mean validation loss to
            ``scheduler.step``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    num_batches = len(val_loader)
    pbar = tqdm(enumerate(val_loader), total=num_batches,ncols=70)
    for step, (imgs, image_labels,is_venomous) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        # Moved to the device for parity with training; not used below.
        is_venomous = is_venomous.to(device).float()

        y_hat,expert_pred,alpha,image_preds = model(imgs)
        image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
        image_targets_all.append(image_labels.detach().cpu().numpy())

        # Labels of -1 are remapped to class 0 for the loss only; the targets
        # for the accuracy were recorded above.
        image_labels[image_labels == -1] = 0
        loss = loss_fn(image_preds, image_labels)

        batch_size = image_labels.shape[0]
        loss_sum += loss.item() * batch_size
        sample_num += batch_size

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == num_batches):
            pbar.set_description(f'epoch {epoch} loss: {loss_sum / sample_num:.4f}')

    all_preds = np.concatenate(image_preds_all)
    all_targets = np.concatenate(image_targets_all)

    accuracy = (all_preds == all_targets).mean()
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' validation multi-class accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum / sample_num)
        else:
            scheduler.step()
    return accuracy
|
|
|
|
|
|
|
# Script entry point: builds the model, data loaders, optimizer and LR
# schedule, then trains for CFG['epochs'] epochs and keeps the checkpoint
# with the best validation accuracy.
#
# NOTE: this code intentionally stays at module level (not inside a main()
# function) because train_one_epoch_mix reads `scaler` as a global.
if __name__ == '__main__':

    seed_everything(CFG['seed'])
    logger.info(CFG)

    # Use every row of both metadata tables.
    trn_idx = np.arange(train_df.shape[0])
    val_idx = np.arange(valid_df.shape[0])

    # Per-class sample counts, consumed by the training loss.
    df_class_id = np.array(train_df['class_id'])
    class_counts = np.bincount(df_class_id)
    device = torch.device(CFG['device'])

    model = Moe(CFG['model_arch'],CFG['class_num'],venomous_mask)
    model = nn.DataParallel(model)
    model.to(device)
    # Tensor.to() returns a new tensor, so the previous bare
    # `mask.to(device)` calls had no effect. Re-assign the moved tensor, but
    # only when model.to(device) has not already placed it on the target
    # device type (as it does for registered buffers/parameters).
    for mask_name in ('not_venomous_mask', 'venomous_mask'):
        mask = getattr(model.module, mask_name)
        if mask.device.type != device.type:
            setattr(model.module, mask_name, mask.to(device))

    train_loader, val_loader = prepare_dataloader(train_df, valid_df, trn_idx, val_idx)

    # Read as a global by train_one_epoch_mix.
    scaler = GradScaler()

    if CFG['differLR']:
        # Split parameters into backbone and everything else by identity;
        # a set keeps the membership test O(1) per parameter.
        backbone_params = {id(p) for p in model.module.backbone.parameters()}
        head_params = [p for p in model.parameters() if id(p) not in backbone_params]

        if CFG['head_lr']>0:
            lr_cfg = [ {'params': model.module.backbone.parameters(), 'lr': CFG['lr'] ,'weight_decay':CFG['weight_decay']},
                       {'params': head_params , 'lr': CFG['head_lr'],'weight_decay':CFG['head_wd']}]
            optimizer = torch.optim.AdamW(lr_cfg, lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        else:
            # No head learning rate: only the backbone is optimised.
            print('frozen center')

            model.module.center.requires_grad = False
            lr_cfg = [
                {'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']}]

            optimizer = torch.optim.AdamW(lr_cfg, lr=CFG['lr'], weight_decay=CFG['weight_decay'])
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])

    # Linear warm-up for `warmup_epochs`, then cosine annealing down to
    # `min_lr`; the combined scheduler is stepped once per epoch.
    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CFG['epochs'] - CFG['warmup_epochs'], eta_min=CFG['min_lr']
    )
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=CFG['warmup_lr_factor'], total_iters=CFG['warmup_epochs']
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[CFG['warmup_epochs']]
    )

    # Training loss (project-defined) and plain cross-entropy for validation.
    loss_tr = all_loss(class_counts,CFG['class_num']).to(device)

    loss_fn = nn.CrossEntropyLoss(label_smoothing=CFG['smoothing']).to(device)

    checkpoint_path = './checkpoints_moe/moe_{}_mix_{}_mixprob_{}_seed_{}_ls_{}_epochs_{}_differLR_{}_imsize{}.pth'.format(
        CFG['model_arch'],
        CFG['mix_type'],
        CFG['mix_prob'],
        CFG['seed'],
        CFG['smoothing'],
        CFG['epochs'],
        CFG['differLR'],
        CFG['img_size'])
    # torch.save does not create directories; create it up front so a missing
    # folder fails (or is fixed) before the first epoch rather than after it.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_acc = 0.0
    for epoch in range(CFG['epochs']):
        print(optimizer.param_groups[0]['lr'])
        train_one_epoch_mix(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler)

        with torch.no_grad():
            temp_acc = valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)

        # Keep only the best-scoring weights.
        if temp_acc > best_acc:
            torch.save(model.state_dict(), checkpoint_path)
            best_acc = temp_acc

    del model, optimizer, train_loader, val_loader, scaler, scheduler
    print(best_acc)
    logger.info('BEST-Valid-ACC: ' + str(best_acc))
    torch.cuda.empty_cache()
|
|