# CHE-Master / test_example.py
# (Hugging Face upload-page residue removed: "Elfenreigen's picture / Upload 2 files / 9da3725 verified")
import sys
import argparse
import os
import logging
import yaml
import numpy as np
import random
import time
import datetime
import json
import math
from pathlib import Path
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModel, BertConfig, AutoTokenizer
from torch.utils.data import Dataset
from torchvision import transforms
import PIL
from PIL import Image
from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense
import numpy as np
import pandas as pd
from factory import utils
class Chestxray14_Dataset(Dataset):
    """ChestX-ray14 evaluation dataset.

    Reads a CSV whose column 0 holds image paths and whose columns 3+ hold
    the per-class labels; yields RGB images resized/normalized for the
    image encoder.
    """

    def __init__(self, csv_path, image_res):
        info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(info.iloc[:, 0])
        self.class_list = np.asarray(info.iloc[:, 3:])
        # ImageNet mean/std normalization, matching the pretrained backbones.
        self.transform = transforms.Compose([
            transforms.Resize(image_res, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __getitem__(self, index):
        # Remap the path recorded at dataset-creation time onto the local
        # storage layout; revise according to the actual circumstances.
        img_path = self.img_path_list[index].replace(
            '/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/',
            '/remote-home/share/medical/public/ChestXray8/')
        img = Image.open(img_path).convert('RGB')
        return {
            "img_path": img_path,
            "image": self.transform(img),
            "label": self.class_list[index],
        }

    def __len__(self):
        return len(self.img_path_list)
def get_text_features(model, text_list, tokenizer, device, max_length):
    """Tokenize the class-name prompts and encode them with the text encoder.

    Returns the features produced by ``model.encode_text`` for the padded,
    truncated token batch moved onto ``device``.
    """
    token_batch = tokenizer(
        list(text_list),
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device=device)
    return model.encode_text(token_batch)
def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
    """Run one evaluation pass over `data_loader` and dump labels/scores to disk.

    Encodes the 14 ChestX-ray14 class names once as text prompts, scores every
    image batch against them, logs per-step BCE loss to TensorBoard, and saves
    the accumulated ground truth and sigmoid scores to
    `{args.output_dir}/gt.npy` and `{args.output_dir}/pred.npy`.
    `config` and `total_test` are accepted but not used in this body.
    """
    model.eval()
    image_encoder.eval()
    text_encoder.eval()
    # Prompt list: the 14 ChestX-ray14 pathology names (order fixes class index).
    text_list = ["atelectasis","cardiomegaly","pleural effusion","infiltration","lung mass","lung nodule","pneumonia","pneumothorax","consolidation","edema","emphysema","fibrosis","pleural thicken","hernia"]
    text_features = get_text_features(text_encoder,text_list,tokenizer,device,max_length=args.max_length)
    device_num = torch.cuda.device_count()
    # Repeat the prompt features once per visible GPU — presumably so a
    # DataParallel-wrapped `model` can scatter one copy per replica; confirm.
    text_features = text_features.repeat(int(device_num),1)
    val_scalar_step = epoch*len(data_loader)
    val_losses = []
    # NOTE(review): accumulators are hard-coded to .cuda(), ignoring `device` —
    # assumes a CUDA device is always present.
    gt = torch.FloatTensor()
    gt = gt.cuda()
    pred = torch.FloatTensor()
    pred = pred.cuda()
    for i, sample in enumerate(data_loader):
        image = sample['image'].to(device,non_blocking=True)
        label = sample['label'].long().to(device)
        label = label.float()
        gt = torch.cat((gt, label), 0)
        with torch.no_grad():
            image_features,image_features_pool = image_encoder(image)
            pred_class = model(image_features,text_features)#b,14,2/1
            # Flatten to (N,1) so every (sample, class) pair is scored as an
            # independent binary label.
            val_loss = F.binary_cross_entropy_with_logits(pred_class.view(-1,1),label.view(-1, 1))
            pred_class = torch.sigmoid(pred_class)
            # Keep only channel 0 of the last dim as the positive-class score.
            pred = torch.cat((pred, pred_class[:,:,0]), 0)
            val_losses.append(val_loss.item())
        writer.add_scalar('val_loss/loss', val_loss, val_scalar_step)
        val_scalar_step += 1
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    # Persist raw labels/scores; metrics (e.g. AUC) are computed elsewhere.
    np.save(f'{args.output_dir}/gt.npy', gt_np)
    np.save(f'{args.output_dir}/pred.npy', pred_np)
    return
def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
    """Convenience wrapper: run a full evaluation pass over the test split."""
    valid_on(
        model,
        image_encoder,
        text_encoder,
        tokenizer,
        test_dataloader,
        epoch,
        device,
        args,
        config,
        writer,
        total_test=True,
    )
def get_dataloader(args, config):
    """Build the ChestX-ray14 test dataset and its deterministic loader.

    Returns (dataloader, dataset); the loader never shuffles nor drops the
    final partial batch, so predictions align with the CSV row order.
    """
    dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])
    loader = DataLoader(
        dataset,
        batch_size=config['test_batch_size'],
        num_workers=config["test_num_workers"],
        pin_memory=True,
        collate_fn=None,
        shuffle=False,
        drop_last=False,
    )
    return loader, dataset
def get_model(args, config):
    """Instantiate the TQN model, image encoder, text encoder and tokenizer.

    The image backbone is selected by substring match on
    ``config['image_encoder_name']``.  BERT weights are optionally restored
    from ``args.bert_pretrained`` and frozen when ``args.freeze_bert`` is set.

    Returns:
        (model, image_encoder, text_encoder, tokenizer); the first two are
        wrapped in ``torch.nn.DataParallel`` when ``args.distributed`` is set.

    Raises:
        NotImplementedError: if no known backbone substring matches.
    """
    # Ordered substring -> backbone-class dispatch (replaces the if/elif chain;
    # the dead `preprocess = None` locals from the original were removed).
    backbone_dispatch = (
        ('resnet', ModelRes),
        ('convnext', ModelConvNeXt),
        ('efficientnet', ModelEfficientV2),
        ('densenet', ModelDense),
    )
    image_encoder = None
    for key, backbone_cls in backbone_dispatch:
        if key in config['image_encoder_name']:
            image_encoder = backbone_cls(config['image_encoder_name']).cuda()
            break
    if image_encoder is None:
        raise NotImplementedError(f"Unknown image encoder: {config['image_encoder_name']}")

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model_name, do_lower_case=True, local_files_only=True)
    text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()
    if args.bert_pretrained:
        checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
        # strict=False tolerates heads present in pretraining but absent here.
        text_encoder.load_state_dict(checkpoint["state_dict"], strict=False)
        if args.freeze_bert:
            for param in text_encoder.parameters():
                param.requires_grad = False

    # Pass `lam` through only when the config provides it, so TQN_Model's own
    # default applies otherwise.
    model_kwargs = {'class_num': args.class_num}
    if 'lam' in config:
        model_kwargs['lam'] = config['lam']
    model = TQN_Model(**model_kwargs).cuda()

    if args.distributed:
        model = torch.nn.DataParallel(model)
        image_encoder = torch.nn.DataParallel(image_encoder)
    return model, image_encoder, text_encoder, tokenizer
def _match_module_prefix(target_keys, state_dict):
    """Return `state_dict` with 'module.' key prefixes adjusted to `target_keys`.

    ``torch.nn.DataParallel`` stores parameters under a 'module.' prefix, so a
    checkpoint and a live module may disagree.  Adds or strips the prefix so
    the checkpoint keys match the target module's keys.
    """
    target_wrapped = 'module.' in target_keys[0]
    source_wrapped = 'module.' in list(state_dict.keys())[0]
    if target_wrapped and not source_wrapped:
        return OrderedDict(('module.' + k, v) for k, v in state_dict.items())
    if source_wrapped and not target_wrapped:
        # str.replace also drops interior 'module.' occurrences, matching the
        # original behaviour for nested DataParallel wrapping.
        return OrderedDict((k.replace('module.', ''), v) for k, v in state_dict.items())
    return state_dict


def load_checkpoint(model, image_encoder, args):
    """Restore `model` and `image_encoder` weights from ``args.finetune``.

    Expects a checkpoint dict with 'image_encoder' and 'model' state dicts;
    both are loaded with strict=False after prefix alignment.  Silently-wrong
    "success" print from the original is now only emitted on an actual load;
    a missing file is reported instead.
    """
    if not os.path.isfile(args.finetune):
        print(f"checkpoint not found: {args.finetune}")
        return
    checkpoint = torch.load(args.finetune, map_location='cpu')
    image_encoder.load_state_dict(
        _match_module_prefix(list(image_encoder.state_dict().keys()),
                             checkpoint['image_encoder']),
        strict=False)
    model.load_state_dict(
        _match_module_prefix(list(model.state_dict().keys()),
                             checkpoint['model']),
        strict=False)
    print("load model success!")
def seed_torch(seed=42):
    """Seed every relevant RNG (hash, random, numpy, torch CPU/CUDA) and
    force deterministic cuDNN kernel selection for reproducible runs."""
    print('=====> Using fixed random seed: ' + str(seed))
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Trade autotuned-kernel speed for bitwise reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def main(args, config):
    """Entry point: build data and models, restore the checkpoint, evaluate."""
    # --- data ---
    test_dataloader, test_dataset = get_dataloader(args, config)
    test_dataloader.num_samples = len(test_dataset)
    # NOTE(review): num_batches is set to the *dataset* length, not
    # len(test_dataloader); preserved as-is — confirm no consumer relies on it.
    test_dataloader.num_batches = len(test_dataset)
    # --- model ---
    model, image_encoder, text_encoder, tokenizer = get_model(args, config)
    writer = SummaryWriter(os.path.join(args.output_dir, 'log'))
    load_checkpoint(model, image_encoder, args)
    test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader,
             args.device, args, config, writer, epoch=0, total_test=True)
if __name__ == '__main__':
    abs_file_path = os.path.abspath(__file__)
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    def _str2bool(value):
        """Parse a boolean CLI value.

        The original used ``type=bool``, which is a known argparse trap:
        ``bool('False')`` is True, so *any* value passed on the command line
        produced True.  Defaults are unchanged (argparse only applies `type`
        to strings).
        """
        if isinstance(value, bool):
            return value
        if value.lower() in ('true', 't', 'yes', 'y', '1'):
            return True
        if value.lower() in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--momentum', default=False, type=_str2bool)
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--finetune', default='base_checkpoint.pt')
    parser.add_argument('--freeze_bert', default=True, type=_str2bool)
    parser.add_argument("--use_entity_features", default=True, type=_str2bool)
    parser.add_argument('--config', default='example.yaml')
    parser.add_argument('--fourier', default=True, type=_str2bool)
    parser.add_argument('--colourjitter', default=True, type=_str2bool)
    parser.add_argument('--class_num', default=1, type=int)  # FT1, FF2
    # Originally False; when extra datasets are added, a label of -1 marks
    # entries excluded from the loss, hence the True default.
    parser.add_argument('--ignore_index', default=True, type=_str2bool)
    parser.add_argument('--add_dataset', default=False, type=_str2bool)

    time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--bert_pretrained', default= './pretrained_bert_weights/epoch_latest.pt')
    parser.add_argument('--bert_model_name', default= './pretrained_bert_weights/UMLSBert_ENG/')
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--loss_ratio', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    # distributed training parameters
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--distributed', action='store_true', default=False, help='Use multi-processing distributed training to launch ')
    parser.add_argument('--rho', default=0, type=float, help='gpu')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')

    args = parser.parse_args()
    args.config = f'./configs/{args.config}'
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; the
    # config file is assumed trusted.  Switch to yaml.SafeLoader if not.
    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    # A non-empty 'finetune' entry in the config overrides the CLI checkpoint.
    # .get() avoids a KeyError on configs that omit the key entirely.
    if config.get('finetune', '') != '':
        args.finetune = config['finetune']
        args.checkpoint = config['finetune']
        args.loss_ratio = config['loss_ratio']

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)
    # Snapshot the effective config next to the results for reproducibility.
    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    seed_torch(args.seed)
    main(args, config)