"""Evaluation script for a TQN model on a ChestX-ray14 style CSV.

The script builds the image/text encoders and the TQN head, loads a
checkpoint, runs the model over the test set and stores the ground-truth
labels and the sigmoid predictions as ``gt.npy`` / ``pred.npy`` inside
``args.output_dir``.
"""
import sys
import argparse
import os
import logging
import random
import time
import datetime
import json
import math
from pathlib import Path
from functools import partial
from collections import OrderedDict

import yaml
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from transformers import AutoModel, BertConfig, AutoTokenizer

from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense
from factory import utils


# Text prompts fed to the text encoder, one per predicted class.
CHESTXRAY14_CLASS_NAMES = (
    "atelectasis", "cardiomegaly", "pleural effusion", "infiltration",
    "lung mass", "lung nodule", "pneumonia", "pneumothorax",
    "consolidation", "edema", "emphysema", "fibrosis",
    "pleural thicken", "hernia",
)

# Prefix that ``torch.nn.DataParallel`` puts in front of every parameter name.
_DATA_PARALLEL_PREFIX = 'module.'

# Accepted spellings for boolean command-line values.
_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', '1'})
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0', ''})


class Chestxray14_Dataset(Dataset):
    """Dataset backed by a CSV file listing image paths and class labels.

    Column 0 of the CSV is read as the image path and every column from
    index 3 onwards is read as the label vector of that image.
    """

    # Path prefix as stored in the CSV and the prefix it is rewritten to;
    # revise according to the actual circumstances.
    _CSV_ROOT = '/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/'
    _LOCAL_ROOT = '/remote-home/share/medical/public/ChestXray8/'

    def __init__(self, csv_path, image_res):
        """Read ``csv_path`` and build the resize/normalise transform.

        Args:
            csv_path: path of the CSV file to load with pandas.
            image_res: size forwarded to ``transforms.Resize``.
        """
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            transforms.Resize(image_res, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Return a dict with the image path, transformed image and label."""
        img_path = self.img_path_list[index].replace(self._CSV_ROOT, self._LOCAL_ROOT)
        class_label = self.class_list[index]
        # ``convert`` returns a fully loaded copy, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        image = self.transform(img)
        return {
            "img_path": img_path,
            "image": image,
            "label": class_label,
        }

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.img_path_list)


def get_text_features(model, text_list, tokenizer, device, max_length):
    """Tokenise ``text_list`` and encode it with ``model.encode_text``.

    Args:
        model: text encoder exposing ``encode_text``.
        text_list: iterable of strings to encode.
        tokenizer: tokenizer called with padding to ``max_length``.
        device: device the token batch is moved to.
        max_length: padding / truncation length.

    Returns:
        Whatever ``model.encode_text`` returns for the token batch.
    """
    text_token = tokenizer(
        list(text_list),
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device=device)
    return model.encode_text(text_token)


def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
    """Run the model over ``data_loader`` and save labels and predictions.

    For every batch the BCE-with-logits loss is written to ``writer`` under
    ``val_loss/loss``.  After the loop the stacked labels and the sigmoid of
    channel 0 of the predictions are saved to ``{args.output_dir}/gt.npy``
    and ``{args.output_dir}/pred.npy``.

    Args:
        model: head called as ``model(image_features, text_features)``.
        image_encoder: returns a pair whose first item feeds ``model``.
        text_encoder: encoder passed to ``get_text_features``.
        tokenizer: tokenizer passed to ``get_text_features``.
        data_loader: yields dicts with ``'image'`` and ``'label'`` entries.
        epoch: used only to offset the TensorBoard step counter.
        device: device batches are moved to.
        args: namespace providing ``max_length`` and ``output_dir``.
        config: unused here; kept for interface compatibility.
        writer: TensorBoard ``SummaryWriter``.
        total_test: unused here; kept for interface compatibility.
    """
    model.eval()
    image_encoder.eval()
    text_encoder.eval()

    val_scalar_step = epoch * len(data_loader)
    # Per-batch results are collected on the CPU and concatenated once at the
    # end; growing a tensor with torch.cat every step is quadratic.
    gt_chunks = []
    pred_chunks = []

    with torch.no_grad():
        text_features = get_text_features(
            text_encoder, CHESTXRAY14_CLASS_NAMES, tokenizer, device, max_length=args.max_length
        )
        # NOTE(review): the features are tiled once per visible GPU, presumably
        # so DataParallel can scatter one copy to each replica. This looks
        # wrong when several GPUs are visible but the model is not wrapped in
        # DataParallel - confirm against TQN_Model before changing.
        device_num = torch.cuda.device_count()
        text_features = text_features.repeat(int(device_num), 1)

        for sample in data_loader:
            image = sample['image'].to(device, non_blocking=True)
            label = sample['label'].long().to(device).float()

            image_features, _ = image_encoder(image)
            # Original author's note on the output: b,14,2/1
            pred_class = model(image_features, text_features)
            val_loss = F.binary_cross_entropy_with_logits(
                pred_class.view(-1, 1), label.view(-1, 1)
            )

            gt_chunks.append(label.cpu())
            pred_chunks.append(torch.sigmoid(pred_class)[:, :, 0].cpu())

            writer.add_scalar('val_loss/loss', val_loss.item(), val_scalar_step)
            val_scalar_step += 1

    if gt_chunks:
        gt_np = torch.cat(gt_chunks, 0).numpy()
        pred_np = torch.cat(pred_chunks, 0).numpy()
    else:
        # An empty loader still produces (empty) output files.
        gt_np = np.empty((0,), dtype=np.float32)
        pred_np = np.empty((0,), dtype=np.float32)

    np.save(f'{args.output_dir}/gt.npy', gt_np)
    np.save(f'{args.output_dir}/pred.npy', pred_np)
    return


def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
    """Evaluate on the test loader by delegating to ``valid_on``."""
    valid_on(
        model, image_encoder, text_encoder, tokenizer, test_dataloader,
        epoch, device, args, config, writer, total_test=total_test,
    )


def get_dataloader(args, config):
    """Build the test dataset and its non-shuffled ``DataLoader``.

    Reads ``test_file``, ``image_res``, ``test_batch_size`` and
    ``test_num_workers`` from ``config``.

    Returns:
        Tuple ``(test_dataloader, test_dataset)``.
    """
    test_dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config['test_batch_size'],
        num_workers=config["test_num_workers"],
        pin_memory=True,
        collate_fn=None,
        shuffle=False,
        drop_last=False,
    )
    return test_dataloader, test_dataset


# Substring of ``config['image_encoder_name']`` -> encoder class.  Order
# matters: the first matching entry wins.
_IMAGE_ENCODERS = (
    ('resnet', ModelRes),
    ('convnext', ModelConvNeXt),
    ('efficientnet', ModelEfficientV2),
    ('densenet', ModelDense),
)


def _build_image_encoder(encoder_name):
    """Instantiate the image encoder whose key occurs in ``encoder_name``."""
    for key, encoder_cls in _IMAGE_ENCODERS:
        if key in encoder_name:
            return encoder_cls(encoder_name).cuda()
    raise NotImplementedError(f"Unknown image encoder: {encoder_name}")


def get_model(args, config):
    """Create the TQN head, image encoder, text encoder and tokenizer.

    Args:
        args: namespace providing ``bert_model_name``, ``bert_pretrained``,
            ``freeze_bert``, ``class_num`` and ``distributed``.
        config: dict providing ``image_encoder_name`` and optionally ``lam``.

    Returns:
        Tuple ``(model, image_encoder, text_encoder, tokenizer)``.

    Raises:
        NotImplementedError: if the encoder name matches no known family.
    """
    image_encoder = _build_image_encoder(config['image_encoder_name'])

    tokenizer = AutoTokenizer.from_pretrained(
        args.bert_model_name, do_lower_case=True, local_files_only=True
    )
    text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()

    if args.bert_pretrained:
        checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
        text_encoder.load_state_dict(checkpoint["state_dict"], strict=False)
        if args.freeze_bert:
            for param in text_encoder.parameters():
                param.requires_grad = False

    head_kwargs = {'class_num': args.class_num}
    if 'lam' in config:
        head_kwargs['lam'] = config['lam']
    model = TQN_Model(**head_kwargs).cuda()

    if args.distributed:
        model = torch.nn.DataParallel(model)
        image_encoder = torch.nn.DataParallel(image_encoder)

    return model, image_encoder, text_encoder, tokenizer


def _match_module_prefix(state_dict, reference_keys):
    """Add or strip the DataParallel prefix so keys match ``reference_keys``.

    The decision is taken from the first key on each side.  Empty inputs are
    returned unchanged.
    """
    source_keys = list(state_dict.keys())
    target_keys = list(reference_keys)
    if not source_keys or not target_keys:
        return state_dict

    source_wrapped = source_keys[0].startswith(_DATA_PARALLEL_PREFIX)
    target_wrapped = target_keys[0].startswith(_DATA_PARALLEL_PREFIX)

    if target_wrapped and not source_wrapped:
        return OrderedDict(
            (_DATA_PARALLEL_PREFIX + key, value) for key, value in state_dict.items()
        )
    if source_wrapped and not target_wrapped:
        # Strip only the leading prefix; str.replace would also remove any
        # 'module.' that appears later inside a parameter name.
        prefix_len = len(_DATA_PARALLEL_PREFIX)
        return OrderedDict(
            (key[prefix_len:] if key.startswith(_DATA_PARALLEL_PREFIX) else key, value)
            for key, value in state_dict.items()
        )
    return state_dict


def load_checkpoint(model, image_encoder, args):
    """Load ``image_encoder`` and ``model`` weights from ``args.finetune``.

    The checkpoint is expected to hold the entries ``'image_encoder'`` and
    ``'model'``.  Keys are adapted to the presence or absence of the
    DataParallel prefix and loaded with ``strict=False``.  When the file does
    not exist nothing is loaded and a warning is printed.
    """
    if not os.path.isfile(args.finetune):
        print(f"WARNING: checkpoint '{args.finetune}' not found, weights were NOT loaded!")
        return

    checkpoint = torch.load(args.finetune, map_location='cpu')

    image_state_dict = _match_module_prefix(
        checkpoint['image_encoder'], image_encoder.state_dict().keys()
    )
    image_encoder.load_state_dict(image_state_dict, strict=False)

    model_state_dict = _match_module_prefix(
        checkpoint['model'], model.state_dict().keys()
    )
    model.load_state_dict(model_state_dict, strict=False)

    print("load model success!")


def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch and force deterministic cuDNN."""
    print('=====> Using fixed random seed: ' + str(seed))
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(args, config):
    """Build data and models, load the checkpoint and run the test pass."""
    # Data preparation
    test_dataloader, test_dataset = get_dataloader(args, config)
    test_dataloader.num_samples = len(test_dataset)
    test_dataloader.num_batches = len(test_dataloader)

    # Model preparation
    model, image_encoder, text_encoder, tokenizer = get_model(args, config)
    writer = SummaryWriter(os.path.join(args.output_dir, 'log'))
    try:
        load_checkpoint(model, image_encoder, args)
        test_all(
            model, image_encoder, text_encoder, tokenizer, test_dataloader,
            args.device, args, config, writer, epoch=0, total_test=True,
        )
    finally:
        # Flush pending events even if evaluation fails.
        writer.close()


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would treat every non-empty string, including "False", as
    True; this accepts the usual true/false spellings instead.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _build_arg_parser():
    """Create the command-line parser of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--momentum', default=False, type=_str2bool)
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--finetune', default='base_checkpoint.pt')

    parser.add_argument('--freeze_bert', default=True, type=_str2bool)
    parser.add_argument("--use_entity_features", default=True, type=_str2bool)

    parser.add_argument('--config', default='example.yaml')

    parser.add_argument('--fourier', default=True, type=_str2bool)
    parser.add_argument('--colourjitter', default=True, type=_str2bool)
    parser.add_argument('--class_num', default=1, type=int)  # FT1, FF2

    # Originally False; when extra data is added, -1 marks labels that are
    # excluded from the loss, so this was changed to True.
    parser.add_argument('--ignore_index', default=True, type=_str2bool)
    parser.add_argument('--add_dataset', default=False, type=_str2bool)

    time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--bert_pretrained', default='./pretrained_bert_weights/epoch_latest.pt')
    parser.add_argument('--bert_model_name', default='./pretrained_bert_weights/UMLSBert_ENG/')
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--loss_ratio', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)

    # distributed training parameters
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--distributed', action='store_true', default=False,
                        help='Use multi-processing distributed training to launch ')

    parser.add_argument('--rho', default=0, type=float, help='gpu')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    return parser


if __name__ == '__main__':
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = _build_arg_parser().parse_args()
    args.config = f'./configs/{args.config}'

    with open(args.config, 'r') as config_file:
        # NOTE: yaml.Loader can construct arbitrary Python objects; only use
        # this script with configuration files you trust.
        config = yaml.load(config_file, Loader=yaml.Loader)

    # A missing or empty 'finetune' entry keeps the command-line value.
    if config.get('finetune'):
        args.finetune = config['finetune']
        args.checkpoint = config['finetune']

    args.loss_ratio = config.get('loss_ratio', args.loss_ratio)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    seed_torch(args.seed)
    main(args, config)