| | import sys |
| | import argparse |
| | import os |
| | import logging |
| | import yaml |
| | import numpy as np |
| | import random |
| | import time |
| | import datetime |
| | import json |
| | import math |
| | from pathlib import Path |
| | from functools import partial |
| | from collections import OrderedDict |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.backends.cudnn as cudnn |
| | import torch.distributed as dist |
| | from torch.utils.data import DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| | from transformers import AutoModel, BertConfig, AutoTokenizer |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| | import PIL |
| | from PIL import Image |
| | from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense |
| | import numpy as np |
| | import pandas as pd |
| | from factory import utils |
| |
|
class Chestxray14_Dataset(Dataset):
    """ChestX-ray14 evaluation dataset backed by a CSV index.

    The CSV's first column holds the image path and every column from index 3
    onward is used as the per-sample label vector.
    """

    # (old_prefix, new_prefix) applied with str.replace to every image path.
    # These are the prefixes the original code hard-coded; pass ``path_remap``
    # to point at a different data location, or None to use paths verbatim.
    DEFAULT_PATH_REMAP = (
        '/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/',
        '/remote-home/share/medical/public/ChestXray8/',
    )

    def __init__(self, csv_path, image_res, path_remap=DEFAULT_PATH_REMAP):
        """
        Args:
            csv_path: path of the CSV index read with ``pandas.read_csv``.
            image_res: size argument forwarded to ``transforms.Resize``.
            path_remap: optional ``(old, new)`` pair; each image path has
                ``old`` replaced by ``new``. ``None`` disables rewriting.
        """
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])
        self.path_remap = path_remap

        # Commonly used ImageNet channel mean/std.
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            # InterpolationMode replaces the deprecated PIL integer constant.
            transforms.Resize(image_res, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Return ``{"img_path", "image", "label"}`` for sample ``index``."""
        img_path = self.img_path_list[index]
        if self.path_remap is not None:
            old_prefix, new_prefix = self.path_remap
            img_path = img_path.replace(old_prefix, new_prefix)

        class_label = self.class_list[index]
        # Context manager releases the underlying file handle deterministically.
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        return {
            "img_path": img_path,
            "image": image,
            "label": class_label,
        }

    def __len__(self):
        """Number of rows in the CSV index."""
        return len(self.img_path_list)
| |
|
| |
|
def get_text_features(model, text_list, tokenizer, device, max_length):
    """Tokenize ``text_list`` and return ``model.encode_text`` of the tokens.

    Every prompt is padded/truncated to exactly ``max_length`` tokens and the
    tokenizer output is moved to ``device`` before encoding.
    """
    encoded = tokenizer(
        list(text_list),
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = encoded.to(device=device)
    return model.encode_text(encoded)
| |
|
def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
    """Run inference over ``data_loader`` and save ground truth / predictions.

    The class prompts are encoded once. Each batch is scored against them,
    and its BCE-with-logits loss is logged to ``writer`` under
    ``val_loss/loss``. Finally ``gt.npy`` (labels) and ``pred.npy`` (sigmoid
    probabilities) are written to ``args.output_dir``.

    Args:
        model: head called as ``model(image_features, text_features)``; may be
            wrapped in ``nn.DataParallel``.
        image_encoder: returns a 2-tuple whose first item is used as features.
        text_encoder: object exposing ``encode_text``.
        tokenizer: tokenizer used to encode the class prompts.
        data_loader: yields dicts with ``'image'`` and ``'label'`` tensors.
        epoch: offsets the TensorBoard step (``epoch * len(data_loader)``).
        device: device batches and prompts are moved to.
        args: must provide ``max_length`` and ``output_dir``.
        config: unused; kept for interface compatibility.
        writer: TensorBoard ``SummaryWriter``.
        total_test: unused; kept for interface compatibility.
    """
    model.eval()
    image_encoder.eval()
    text_encoder.eval()

    # NOTE(review): order presumably must match the CSV label columns — confirm.
    text_list = ["atelectasis","cardiomegaly","pleural effusion","infiltration","lung mass","lung nodule","pneumonia","pneumothorax","consolidation","edema","emphysema","fibrosis","pleural thicken","hernia"]

    val_scalar_step = epoch * len(data_loader)
    val_losses = []
    gt_batches = []
    pred_batches = []

    # Inference only: no_grad also covers the prompt encoding.
    with torch.no_grad():
        text_features = get_text_features(text_encoder, text_list, tokenizer, device, max_length=args.max_length)
        # nn.DataParallel splits every tensor argument along dim 0 across its
        # replicas, so the prompt embeddings are tiled once per replica to give
        # each replica the full class set. A bare module must receive them
        # untiled, otherwise the class dimension is multiplied.
        if isinstance(model, nn.DataParallel):
            text_features = text_features.repeat(max(1, len(model.device_ids)), 1)

        for sample in data_loader:
            image = sample['image'].to(device, non_blocking=True)
            label = sample['label'].long().to(device).float()

            image_features, _ = image_encoder(image)
            pred_class = model(image_features, text_features)
            val_loss = F.binary_cross_entropy_with_logits(pred_class.view(-1, 1), label.view(-1, 1))

            gt_batches.append(label)
            # assumes model output is (batch, num_classes, 1) — TODO confirm
            pred_batches.append(torch.sigmoid(pred_class)[:, :, 0])

            loss_value = val_loss.item()
            val_losses.append(loss_value)
            writer.add_scalar('val_loss/loss', loss_value, val_scalar_step)
            val_scalar_step += 1

    # Concatenate once instead of re-allocating the accumulators every batch.
    gt = torch.cat(gt_batches, 0) if gt_batches else torch.empty(0)
    pred = torch.cat(pred_batches, 0) if pred_batches else torch.empty(0)
    np.save(f'{args.output_dir}/gt.npy', gt.cpu().numpy())
    np.save(f'{args.output_dir}/pred.npy', pred.cpu().numpy())

    return
| |
|
| |
|
| |
|
def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
    """Evaluate the whole test split by delegating to ``valid_on``.

    All arguments, including ``epoch`` and ``total_test``, are forwarded
    unchanged.
    """
    valid_on(model, image_encoder, text_encoder, tokenizer, test_dataloader, epoch, device, args, config, writer, total_test=total_test)
| | |
| | |
def get_dataloader(args, config):
    """Build the ChestX-ray14 test dataset and a sequential loader over it.

    Reads ``test_file``, ``image_res``, ``test_batch_size`` and
    ``test_num_workers`` from ``config``; ``args`` is not used.

    Returns:
        ``(test_dataloader, test_dataset)``
    """
    dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])
    loader_options = dict(
        batch_size=config['test_batch_size'],
        num_workers=config["test_num_workers"],
        pin_memory=True,
        collate_fn=None,
        shuffle=False,  # keep CSV order so saved predictions align with rows
        drop_last=False,
    )
    loader = DataLoader(dataset, **loader_options)
    return loader, dataset
| |
|
| |
|
def get_model(args, config):
    """Build the image encoder, text encoder, tokenizer and TQN head on CUDA.

    Args:
        args: needs ``bert_model_name``, ``bert_pretrained``, ``freeze_bert``,
            ``class_num`` and ``distributed``.
        config: needs ``image_encoder_name``; optional ``lam`` is forwarded to
            ``TQN_Model``.

    Returns:
        ``(model, image_encoder, text_encoder, tokenizer)``. With
        ``args.distributed``, ``model`` and ``image_encoder`` are wrapped in
        ``nn.DataParallel``.

    Raises:
        NotImplementedError: if ``image_encoder_name`` matches no known family.
    """
    encoder_name = config['image_encoder_name']
    # Checked in order; the first key that is a substring of the name wins.
    encoder_families = (
        ('resnet', ModelRes),
        ('convnext', ModelConvNeXt),
        ('efficientnet', ModelEfficientV2),
        ('densenet', ModelDense),
    )
    encoder_cls = next((cls for key, cls in encoder_families if key in encoder_name), None)
    if encoder_cls is None:
        raise NotImplementedError(f"Unknown image encoder: {config['image_encoder_name']}")
    image_encoder = encoder_cls(encoder_name).cuda()

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model_name, do_lower_case=True, local_files_only=True)
    text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()

    if args.bert_pretrained:
        # torch.load unpickles the file; only point this at trusted checkpoints.
        checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
        text_encoder.load_state_dict(checkpoint["state_dict"], strict=False)
        # Freezing is applied only when pretrained text weights were loaded.
        if args.freeze_bert:
            for param in text_encoder.parameters():
                param.requires_grad = False

    tqn_kwargs = {'lam': config['lam']} if 'lam' in config else {}
    model = TQN_Model(class_num=args.class_num, **tqn_kwargs).cuda()

    if args.distributed:
        model = torch.nn.DataParallel(model)
        image_encoder = torch.nn.DataParallel(image_encoder)

    return model, image_encoder, text_encoder, tokenizer
| | |
def _align_module_prefix(state_dict, target_state_dict, prefix='module.'):
    """Add or strip a leading ``prefix`` on ``state_dict`` keys to match ``target_state_dict``.

    Only a leading prefix is touched, so inner names such as
    ``encoder.submodule.weight`` are preserved. The decision uses the first key
    of each dict. Empty dicts are returned unchanged.
    """
    if not state_dict or not target_state_dict:
        return state_dict
    target_wrapped = next(iter(target_state_dict)).startswith(prefix)
    source_wrapped = next(iter(state_dict)).startswith(prefix)
    if target_wrapped and not source_wrapped:
        return OrderedDict((prefix + key, value) for key, value in state_dict.items())
    if source_wrapped and not target_wrapped:
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    return state_dict


def load_checkpoint(model, image_encoder, args):
    """Load ``image_encoder`` and ``model`` weights from the file ``args.finetune``.

    The checkpoint must be a dict with ``'image_encoder'`` and ``'model'``
    state dicts. A leading ``module.`` (added by ``nn.DataParallel``) is added
    or stripped so checkpoint keys match each target module. Loading is
    non-strict; any missing/unexpected keys are printed rather than ignored.
    If ``args.finetune`` is not an existing file, a warning is printed and both
    modules are left unchanged.
    """
    if not os.path.isfile(args.finetune):
        print(f"Checkpoint not found, weights not loaded: {args.finetune}")
        return

    # torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.finetune, map_location='cpu')

    for key, module in (('image_encoder', image_encoder), ('model', model)):
        state_dict = _align_module_prefix(checkpoint[key], module.state_dict())
        result = module.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"[{key}] missing keys: {result.missing_keys}; unexpected keys: {result.unexpected_keys}")
    print("load model success!")
| |
|
def seed_torch(seed=42):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and forces deterministic cuDNN kernels
    (disabling the cuDNN autotuner).
    """
    print(f'=====> Using fixed random seed: {seed!s}')
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
| |
|
| |
|
def main(args, config):
    """Evaluate the checkpoint ``args.finetune`` on the ChestX-ray14 test split.

    Builds the data pipeline and models, loads the checkpoint and runs
    ``test_all``, logging to TensorBoard under ``<output_dir>/log``.
    """
    # Data preparation
    test_dataloader, test_dataset = get_dataloader(args, config)
    test_dataloader.num_samples = len(test_dataset)
    # len(DataLoader) is the number of batches, not samples.
    test_dataloader.num_batches = len(test_dataloader)

    # Model preparation
    model, image_encoder, text_encoder, tokenizer = get_model(args, config)

    writer = SummaryWriter(os.path.join(args.output_dir, 'log'))
    try:
        load_checkpoint(model, image_encoder, args)
        test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, args.device, args, config, writer, epoch=0, total_test=True)
    finally:
        # Flush pending TensorBoard events even if evaluation fails.
        writer.close()
| |
|
| |
|
| |
|
if __name__ == '__main__':
    # NOTE(review): torch is already imported here, so OMP_NUM_THREADS may not
    # affect torch's thread pool — confirm whether this has any effect.
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` calls ``bool("False")``, which is True, so any
        explicitly passed value became True. This accepts the usual spellings
        and rejects anything else.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--momentum', default=False, type=_str2bool)
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--finetune', default='base_checkpoint.pt')

    parser.add_argument('--freeze_bert', default=True, type=_str2bool)
    parser.add_argument("--use_entity_features", default=True, type=_str2bool)

    parser.add_argument('--config', default='example.yaml')

    parser.add_argument('--fourier', default=True, type=_str2bool)
    parser.add_argument('--colourjitter', default=True, type=_str2bool)
    parser.add_argument('--class_num', default=1, type=int)

    parser.add_argument('--ignore_index', default=True, type=_str2bool)
    parser.add_argument('--add_dataset', default=False, type=_str2bool)

    time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--bert_pretrained', default='./pretrained_bert_weights/epoch_latest.pt')
    parser.add_argument('--bert_model_name', default='./pretrained_bert_weights/UMLSBert_ENG/')
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--loss_ratio', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)

    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--distributed', action='store_true', default=False, help='Use multi-processing distributed training to launch ')

    parser.add_argument('--rho', default=0, type=float, help='gpu')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    args = parser.parse_args()
    args.config = f'./configs/{args.config}'

    # yaml.Loader can construct arbitrary Python objects; only use trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # A non-empty 'finetune' entry in the config overrides the CLI checkpoint path.
    if config.get('finetune'):
        args.finetune = config['finetune']
        args.checkpoint = config['finetune']

    args.loss_ratio = config.get('loss_ratio', args.loss_ratio)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)

    # Snapshot the effective config next to the results; `with` ensures it is flushed.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    seed_torch(args.seed)

    main(args, config)
| |
|
| |
|
| |
|