|
import sys |
|
import argparse |
|
import os |
|
import logging |
|
import yaml |
|
import numpy as np |
|
import random |
|
import time |
|
import datetime |
|
import json |
|
import math |
|
from pathlib import Path |
|
from functools import partial |
|
from collections import OrderedDict |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
import torch.distributed as dist |
|
from torch.utils.data import DataLoader |
|
from torch.utils.tensorboard import SummaryWriter |
|
from transformers import AutoModel, BertConfig, AutoTokenizer |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
import PIL |
|
from PIL import Image |
|
from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense |
|
import numpy as np |
|
import pandas as pd |
|
from factory import utils |
|
|
|
class Chestxray14_Dataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    The first CSV column is read as the image path and every column from the
    fourth onward is read as the label vector of that image.

    Args:
        csv_path: Path of the CSV file, read with ``pandas.read_csv``.
        image_res: Size argument forwarded to ``transforms.Resize``.
        src_prefix: Substring of the CSV paths that should be rewritten
            (``str.replace`` semantics). Defaults to the location baked into
            the index this script was written for. An empty value disables
            the rewrite.
        dst_prefix: Replacement text for ``src_prefix``.
    """

    # Path fragment stored in the CSV index and the location it is mapped to.
    # They used to be hard-coded in ``__getitem__``; they stay the defaults so
    # existing callers behave exactly as before.
    DEFAULT_SRC_PREFIX = '/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/'
    DEFAULT_DST_PREFIX = '/remote-home/share/medical/public/ChestXray8/'

    def __init__(self, csv_path, image_res,
                 src_prefix=DEFAULT_SRC_PREFIX, dst_prefix=DEFAULT_DST_PREFIX):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])
        self.src_prefix = src_prefix
        self.dst_prefix = dst_prefix

        # Per-channel mean/std (values match the widely used ImageNet statistics).
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            transforms.Resize(image_res, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def _resolve_path(self, path):
        """Return ``path`` as a plain ``str`` with ``src_prefix`` rewritten to ``dst_prefix``."""
        # Entries come out of a NumPy array; convert so the result is always a
        # built-in str regardless of whether a replacement happens.
        path = str(path)
        if not self.src_prefix:
            # str.replace('', x) would insert x between every character.
            return path
        return path.replace(self.src_prefix, self.dst_prefix)

    def __getitem__(self, index):
        """Return a dict with the resolved ``img_path``, transformed ``image`` and ``label`` row."""
        img_path = self._resolve_path(self.img_path_list[index])
        class_label = self.class_list[index]
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        return {
            "img_path": img_path,
            "image": image,
            "label": class_label
        }

    def __len__(self):
        """Number of rows in the CSV index."""
        return len(self.img_path_list)
|
|
|
|
|
def get_text_features(model,text_list,tokenizer,device,max_length):
    """Tokenize ``text_list`` and return ``model.encode_text`` of the result.

    Args:
        model: Object exposing ``encode_text(tokens)``.
        text_list: Iterable of strings; materialised as a list before tokenizing.
        tokenizer: Callable tokenizer whose output supports ``.to(device=...)``.
        device: Device the tokenized batch is moved to.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        Whatever ``model.encode_text`` returns for the tokenized batch.
    """
    tokenizer_options = dict(
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    prompts = list(text_list)
    encoded_batch = tokenizer(prompts, **tokenizer_options)
    encoded_batch = encoded_batch.to(device=device)
    return model.encode_text(encoded_batch)
|
|
|
def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
    """Run one evaluation pass and dump labels and predictions to disk.

    For every batch the image encoder output and the encoded class prompts are
    fed to ``model``; the BCE-with-logits loss is logged to ``writer`` under
    ``val_loss/loss`` and the sigmoid scores are collected. After the loop the
    labels and scores are written to ``{args.output_dir}/gt.npy`` and
    ``{args.output_dir}/pred.npy``.

    Args:
        model: Query model called as ``model(image_features, text_features)``.
        image_encoder: Called as ``image_encoder(image)``; must return a pair
            whose first element is passed to ``model``.
        text_encoder: Encoder handed to ``get_text_features``.
        tokenizer: Tokenizer handed to ``get_text_features``.
        data_loader: Iterable of dicts with ``'image'`` and ``'label'`` tensors.
        epoch: Used only to offset the logging step (``epoch * len(data_loader)``).
        device: Device the batches and accumulators live on.
        args: Namespace providing ``max_length`` and ``output_dir``.
        config: Unused here; accepted for interface compatibility.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        total_test: Unused here; accepted for interface compatibility.

    Returns:
        None.
    """
    model.eval()
    image_encoder.eval()
    text_encoder.eval()

    # Class prompts. NOTE(review): presumably ordered like the label columns of
    # the test CSV -- confirm against the data file.
    text_list = ["atelectasis","cardiomegaly","pleural effusion","infiltration","lung mass","lung nodule","pneumonia","pneumothorax","consolidation","edema","emphysema","fibrosis","pleural thicken","hernia"]

    # Evaluation only: no autograd graph is needed for the prompt embeddings.
    with torch.no_grad():
        text_features = get_text_features(text_encoder,text_list,tokenizer,device,max_length=args.max_length)

    # DataParallel splits every positional input along dim 0, so the prompt
    # embeddings are tiled once per replica. When the model is not wrapped
    # (or no GPU is present) a single copy is the correct input; tiling by
    # torch.cuda.device_count() in that case would yield zero rows on CPU or
    # too many queries on an unwrapped multi-GPU run.
    if isinstance(model, torch.nn.DataParallel) and model.device_ids:
        replica_count = len(model.device_ids)
    else:
        replica_count = 1
    text_features = text_features.repeat(replica_count, 1)

    val_scalar_step = epoch*len(data_loader)

    # Collect per-batch results and concatenate once at the end instead of
    # re-allocating a growing tensor on every iteration.
    gt_batches = []
    pred_batches = []

    for sample in data_loader:
        image = sample['image'].to(device, non_blocking=True)
        label = sample['label'].long().to(device).float()

        with torch.no_grad():
            image_features, _ = image_encoder(image)
            pred_class = model(image_features, text_features)
            val_loss = F.binary_cross_entropy_with_logits(pred_class.view(-1, 1), label.view(-1, 1))
            pred_class = torch.sigmoid(pred_class)

        gt_batches.append(label.cpu())
        # The indexing requires a 3-D output; only slot 0 of the last axis is kept.
        pred_batches.append(pred_class[:, :, 0].cpu())

        writer.add_scalar('val_loss/loss', val_loss.item(), val_scalar_step)
        val_scalar_step += 1

    # An empty loader still produces (empty) output files, as before.
    gt = torch.cat(gt_batches, 0) if gt_batches else torch.empty(0)
    pred = torch.cat(pred_batches, 0) if pred_batches else torch.empty(0)

    np.save(f'{args.output_dir}/gt.npy', gt.numpy())
    np.save(f'{args.output_dir}/pred.npy', pred.numpy())

    return
|
|
|
|
|
|
|
def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
    """Evaluate on the test loader by delegating to ``valid_on``.

    All arguments are forwarded unchanged; note that ``valid_on`` takes the
    data loader and epoch before the device. ``total_test`` is now passed
    through instead of being overwritten with ``True``.

    Returns:
        None.
    """
    valid_on(
        model,
        image_encoder,
        text_encoder,
        tokenizer,
        test_dataloader,
        epoch,
        device,
        args,
        config,
        writer,
        total_test=total_test,
    )
|
|
|
|
|
def get_dataloader(args, config):
    """Create the test dataset and a sequential ``DataLoader`` over it.

    Reads ``test_file``, ``image_res``, ``test_batch_size`` and
    ``test_num_workers`` from ``config``; ``args`` is not used.

    Returns:
        Tuple ``(test_dataloader, test_dataset)``.
    """
    dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])

    # Fixed order, no dropped tail batch: every sample is evaluated exactly once.
    loader_options = {
        'batch_size': config['test_batch_size'],
        'num_workers': config["test_num_workers"],
        'pin_memory': True,
        'collate_fn': None,
        'shuffle': False,
        'drop_last': False,
    }
    loader = DataLoader(dataset, **loader_options)

    return loader, dataset
|
|
|
|
|
def get_model(args, config):
    """Instantiate the image encoder, text encoder, tokenizer and query model.

    The image backbone is chosen by the first keyword found in
    ``config['image_encoder_name']`` (checked in the order resnet, convnext,
    efficientnet, densenet). All modules are moved to CUDA. When
    ``args.distributed`` is set, the query model and image encoder are wrapped
    in ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(model, image_encoder, text_encoder, tokenizer)``.

    Raises:
        NotImplementedError: If no known backbone keyword matches.
    """
    encoder_name = config['image_encoder_name']

    # Order matters: the first matching keyword wins.
    backbone_table = (
        ('resnet', ModelRes),
        ('convnext', ModelConvNeXt),
        ('efficientnet', ModelEfficientV2),
        ('densenet', ModelDense),
    )
    backbone_cls = next(
        (cls for keyword, cls in backbone_table if keyword in encoder_name),
        None,
    )
    if backbone_cls is None:
        raise NotImplementedError(f"Unknown image encoder: {encoder_name}")
    image_encoder = backbone_cls(encoder_name).cuda()

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model_name, do_lower_case=True, local_files_only=True)
    text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()

    if args.bert_pretrained:
        bert_checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
        # strict=False: keys missing from either side are ignored.
        text_encoder.load_state_dict(bert_checkpoint["state_dict"], strict=False)
        if args.freeze_bert:
            for weight in text_encoder.parameters():
                weight.requires_grad = False

    # Only pass ``lam`` when the config defines it, so the model default applies otherwise.
    tqn_options = {'class_num': args.class_num}
    if 'lam' in config:
        tqn_options['lam'] = config['lam']
    model = TQN_Model(**tqn_options).cuda()

    if args.distributed:
        model = torch.nn.DataParallel(model)
        image_encoder = torch.nn.DataParallel(image_encoder)

    return model, image_encoder, text_encoder, tokenizer
|
|
|
def _align_module_prefix(target_keys, state_dict, prefix='module.'):
    """Return ``state_dict`` with its keys adapted to the target's ``module.`` prefix usage.

    The first key of each side decides whether that side carries the prefix
    added by ``torch.nn.DataParallel``. If both sides agree (or either is
    empty) the state dict is returned unchanged.
    """
    target_keys = list(target_keys)
    source_keys = list(state_dict.keys())
    if not target_keys or not source_keys:
        return state_dict

    # Prefix test, not substring test: a key such as "head.module.weight"
    # belongs to an unwrapped model with an inner attribute named "module".
    target_wrapped = target_keys[0].startswith(prefix)
    source_wrapped = source_keys[0].startswith(prefix)
    if target_wrapped == source_wrapped:
        return state_dict

    if target_wrapped:
        return OrderedDict((prefix + key, value) for key, value in state_dict.items())

    # Strip only the leading prefix so inner "module." segments survive.
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_checkpoint(model, image_encoder, args):
    """Load ``image_encoder`` and ``model`` weights from ``args.finetune``.

    The checkpoint must contain the entries ``'image_encoder'`` and
    ``'model'``. Keys are adapted so that checkpoints saved with or without a
    ``DataParallel`` wrapper load into either kind of module. Loading uses
    ``strict=False``, so non-matching keys are ignored.

    If ``args.finetune`` is not an existing file nothing is loaded; a warning
    is printed so the run does not silently evaluate untrained weights.

    Returns:
        None.
    """
    if not os.path.isfile(args.finetune):
        print(f"WARNING: checkpoint '{args.finetune}' not found, no weights were loaded.")
        return

    checkpoint = torch.load(args.finetune, map_location='cpu')

    image_state_dict = _align_module_prefix(
        image_encoder.state_dict().keys(), checkpoint['image_encoder']
    )
    image_encoder.load_state_dict(image_state_dict, strict=False)

    model_state_dict = _align_module_prefix(
        model.state_dict().keys(), checkpoint['model']
    )
    model.load_state_dict(model_state_dict, strict=False)

    print("load model success!")
|
|
|
def seed_torch(seed=42):
    """Seed every random number generator used by the script.

    Sets ``PYTHONHASHSEED``, seeds ``random``, NumPy and PyTorch (CPU and all
    CUDA devices), then switches cuDNN to deterministic mode with the
    benchmark flag turned off.

    Args:
        seed: Value applied to every generator.
    """
    print(f'=====> Using fixed random seed: {seed!s}')
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same order as always: Python, NumPy, then the PyTorch CPU/CUDA generators.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
def main(args, config):
    """Build the test pipeline, restore the checkpoint and run the evaluation.

    Args:
        args: Parsed command-line namespace (reads ``output_dir`` and
            ``device`` here; the rest is consumed by the helpers it is passed to).
        config: Configuration mapping loaded from the YAML file.
    """
    # Data preparation
    test_dataloader, test_dataset = get_dataloader(args, config)

    test_dataloader.num_samples = len(test_dataset)
    # Number of batches is the loader's length, not the dataset's.
    test_dataloader.num_batches = len(test_dataloader)

    # Model preparation
    model, image_encoder, text_encoder, tokenizer = get_model(args, config)

    writer = SummaryWriter(os.path.join(args.output_dir, 'log'))
    try:
        load_checkpoint(model, image_encoder, args)

        test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, args.device, args, config, writer, epoch=0, total_test=True)
    finally:
        # Flush pending events and release the log file even if evaluation fails.
        writer.close()
|
|
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` with ``type=bool`` treats every non-empty string -- including
    ``"False"`` -- as ``True``. This converter accepts the usual spellings
    (case-insensitive): true/t/yes/y/1 and false/f/no/n/0; an empty string is
    treated as ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    abs_file_path = os.path.abspath(__file__)
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    parser = argparse.ArgumentParser()
    parser.add_argument('--momentum', default=False, type=str2bool)
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--finetune', default='base_checkpoint.pt')

    parser.add_argument('--freeze_bert', default=True, type=str2bool)
    parser.add_argument("--use_entity_features", default=True, type=str2bool)

    parser.add_argument('--config', default='example.yaml')

    parser.add_argument('--fourier', default=True, type=str2bool)
    parser.add_argument('--colourjitter', default=True, type=str2bool)
    parser.add_argument('--class_num', default=1, type=int)

    parser.add_argument('--ignore_index', default=True, type=str2bool)
    parser.add_argument('--add_dataset', default=False, type=str2bool)

    # Default output folders are stamped with the launch time (minute resolution).
    time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--bert_pretrained', default= './pretrained_bert_weights/epoch_latest.pt')
    parser.add_argument('--bert_model_name', default= './pretrained_bert_weights/UMLSBert_ENG/')
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--loss_ratio', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)

    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--distributed', action='store_true', default=False, help='Use multi-processing distributed training to launch ')

    parser.add_argument('--rho', default=0, type=float, help='gpu')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    args = parser.parse_args()

    # --config names a file inside ./configs/.
    args.config = f'./configs/{args.config}'

    # NOTE(security): yaml.Loader can construct arbitrary Python objects.
    # Only point --config at trusted files, or switch to yaml.safe_load if the
    # configs contain plain data only.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # A non-empty ``finetune`` entry in the config overrides the CLI values;
    # a missing or empty entry keeps them.
    finetune_path = config.get('finetune', '')
    if finetune_path:
        args.finetune = finetune_path
        args.checkpoint = finetune_path

    # The config value wins when present; otherwise keep the CLI/default value.
    args.loss_ratio = config.get('loss_ratio', args.loss_ratio)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective config next to the results.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    seed_torch(args.seed)

    main(args, config)
|
|
|
|
|
|