Upload model

Files changed:
- configuration_cxrmate_ed.py  +0 -164
- modelling_cxrmate_ed.py  +267 -445

configuration_cxrmate_ed.py
CHANGED
@@ -46,167 +46,3 @@ class CXRMateEDConfig(transformers.PretrainedConfig):
             text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
 
         self.text_config = text_config
-
-# class CXRMateEDConfig(transformers.PretrainedConfig):
-
-#     model_type = 'cxrmate-ed'
-
-#     # def __init__(
-#     #     self,
-#     #     index_value_encoder_intermediate_size: int = 2048,
-#     #     include_time_delta: bool = True,
-#     #     time_delta_monotonic_inversion: bool = True,
-#     #     add_time_deltas: bool = True,
-#     #     history: int = 0,
-#     #     tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
-#     #     prompt_report_sections_filter: list = ['indication', 'history'],
-#     #     pad_token_id: int = 4,
-#     #     **kwargs: Any,
-#     # ) -> None:
-#     #     super().__init__(**kwargs)
-#     #     self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
-#     #     self.include_time_delta = include_time_delta
-#     #     self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
-#     #     self.add_time_deltas = add_time_deltas
-#     #     self.history = history
-#     #     self.tables_filter = tables_filter
-#     #     self.prompt_report_sections_filter = prompt_report_sections_filter
-#     #     self.pad_token_id = pad_token_id
-
-#     #     self.hidden_size = self.text_config.hidden_size
-
-#     def __init__(
-#         self,
-#         vision_config=None,
-#         text_config=None,
-#         # ignore_index=-100,
-#         # image_token_index=32000,
-#         # projector_hidden_act="gelu",
-#         # vision_feature_select_strategy="default",
-#         # vision_feature_layer=-2,
-#         # image_seq_length=576,
-#         index_value_encoder_intermediate_size: int = 2048,
-#         include_time_delta: bool = True,
-#         time_delta_monotonic_inversion: bool = True,
-#         add_time_deltas: bool = True,
-#         history: int = 0,
-#         tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
-#         prompt_report_sections_filter: list = ['indication', 'history'],
-#         pad_token_id: int = 4,
-#         **kwargs,
-#     ):
-#         transformers.PretrainedConfig.__init__(self, **kwargs)
-
-#         self.vision_config = vision_config
-#         self.text_config = text_config
-#         self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
-#         self.include_time_delta = include_time_delta
-#         self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
-#         self.add_time_deltas = add_time_deltas
-#         self.history = history
-#         self.tables_filter = tables_filter
-#         self.prompt_report_sections_filter = prompt_report_sections_filter
-#         self.pad_token_id = pad_token_id
-
-#         self.ignore_index = ignore_index
-#         self.image_token_index = image_token_index
-#         self.projector_hidden_act = projector_hidden_act
-#         self.image_seq_length = image_seq_length
-
-#         if vision_feature_select_strategy not in ["default", "full"]:
-#             raise ValueError(
-#                 "vision_feature_select_strategy should be one of 'default', 'full'."
-#                 f"Got: {vision_feature_select_strategy}"
-#             )
-
-#         self.vision_feature_select_strategy = vision_feature_select_strategy
-#         self.vision_feature_layer = vision_feature_layer
-
-#         if isinstance(vision_config, dict):
-#             vision_config["model_type"] = (
-#                 vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
-#             )
-#             vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
-#         elif vision_config is None:
-#             vision_config = CONFIG_MAPPING["clip_vision_model"](
-#                 intermediate_size=4096,
-#                 hidden_size=1024,
-#                 patch_size=14,
-#                 image_size=336,
-#                 num_hidden_layers=24,
-#                 num_attention_heads=16,
-#                 vocab_size=32000,
-#                 projection_dim=768,
-#             )
-
-
-#         if isinstance(text_config, dict):
-#             text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
-#             text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
-#         elif text_config is None:
-#             text_config = CONFIG_MAPPING["llama"]()
-
-#         super().__init__(**kwargs)
-
-
-# import transformers
-# from transformers.configuration_utils import PretrainedConfig
-# from transformers.utils import logging
-
-# logger = logging.get_logger(__name__)
-
-
-# class CXRMateEDConfig(PretrainedConfig):
-
-#     model_type = "cxrmate-ed"
-
-#     def __init__(self, **kwargs):
-#         super().__init__(**kwargs)
-
-#         if 'decoder' not in kwargs:
-
-#             self.decoder = transformers.LlamaConfig(
-#                 vocab_size=30000,
-#                 hidden_size=768,
-#                 intermediate_size=3072,
-#                 num_attention_heads=12,
-#                 num_hidden_layers=6,
-#                 max_position_embeddings=2048,
-#             )
-#             self.decoder.is_decoder = True
-
-#             self.decoder.index_value_encoder_intermediate_size = 2048
-#             self.decoder.include_time_delta = True
-#             self.decoder.time_delta_monotonic_inversion = True
-#             self.decoder.add_time_deltas = True
-#             self.decoder.history = 0
-#             self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
-#             self.decoder.prompt_report_sections_filter = ["indication", "history"]
-#             self.decoder.pad_token_id = 4
-
-#         else:
-#             self.decoder = kwargs.pop("decoder")
-
-
-#         if 'encoder' not in kwargs:
-#             self.encoder = transformers.AutoConfig.from_pretrained(
-#                 'aehrc/uniformer_base_tl_384',
-#                 projection_size=768,
-#                 trust_remote_code=True,
-#             )
-#         else:
-#             self.encoder = kwargs.pop("encoder")
-
-
-#         self.is_encoder_decoder = True
-
-#     @classmethod
-#     def from_encoder_decoder_configs(
-#         cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
-#     ) -> PretrainedConfig:
-
-#         logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
-#         decoder_config.is_decoder = True
-#         decoder_config.add_cross_attention = True
-
-#         return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
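Illustrative note (not part of the commit): after the deletion above, the retained constructor code builds the text sub-config by mapping a plain dict through CONFIG_MAPPING. A minimal sketch of instantiation, assuming the module can be imported directly and that no other arguments are required (the direct import and the 'llama' hyperparameters below are placeholders, not taken from the repository):

    import transformers
    from configuration_cxrmate_ed import CXRMateEDConfig  # hypothetical direct import of the remote-code module

    # The kept lines do: text_config = CONFIG_MAPPING[text_config['model_type']](**text_config),
    # so passing a dict with a 'model_type' key is enough to build the sub-config object.
    config = CXRMateEDConfig(text_config={'model_type': 'llama', 'hidden_size': 768, 'num_hidden_layers': 6})
    assert isinstance(config.text_config, transformers.LlamaConfig)  # constructed via CONFIG_MAPPING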
modelling_cxrmate_ed.py
CHANGED
@@ -8,13 +8,13 @@ import datasets
 import torch
 import transformers
 from huggingface_hub import hf_hub_download
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import Subset
 from torchvision.io import decode_image
-from …
-from transformers …
+from torchvision.transforms import v2
+from transformers import PreTrainedTokenizerFast
 from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
-from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import check_min_version, logging
 
 from .configuration_cxrmate_ed import CXRMateEDConfig

@@ -187,162 +187,39 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
         self.inf_time_delta_value = self.time_delta_map(float('inf'))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
-        #     Information necessary to initiate the text decoder. Can be either:
-
-        #         - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-        #         - A path to a *directory* containing model weights saved using
-        #           [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
-        #         - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
-        #           this case, `from_tf` should be set to `True` and a configuration object should be provided as
-        #           `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
-        #           PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
-
-        # model_args (remaining positional arguments, *optional*):
-        #     All remaning positional arguments will be passed to the underlying model's `__init__` method.
-
-        # kwargs (remaining dictionary of keyword arguments, *optional*):
-        #     Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-        #     `output_attentions=True`).
-
-        #     - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
-        #     - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
-        #     - To update the parent model configuration, do not use a prefix for each configuration parameter.
-
-        #     Behaves differently depending on whether a `config` is provided or automatically loaded.
-
-        # Example:
-
-        # ```python
-        # >>> from transformers import VisionEncoderDecoderModel
-
-        # >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
-        # >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-        # ...     "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
-        # ... )
-        # >>> # saving model after fine-tuning
-        # >>> model.save_pretrained("./vit-bert")
-        # >>> # load fine-tuned model
-        # >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
-        # ```"""
-
-        # kwargs_encoder = {
-        #     argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
-        # }
-
-        # kwargs_decoder = {
-        #     argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
-        # }
-
-        # # remove encoder, decoder kwargs from kwargs
-        # for key in kwargs_encoder.keys():
-        #     del kwargs["encoder_" + key]
-        # for key in kwargs_decoder.keys():
-        #     del kwargs["decoder_" + key]
-
-        # # Load and initialize the encoder and decoder
-        # # The distinction between encoder and decoder at the model level is made
-        # # by the value of the flag `is_decoder` that we need to set correctly.
-        # encoder = kwargs_encoder.pop("model", None)
-        # if encoder is None:
-        #     if encoder_pretrained_model_name_or_path is None:
-        #         raise ValueError(
-        #             "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
-        #             "to be defined."
-        #         )
-
-        #     if "config" not in kwargs_encoder:
-        #         encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
-        #             encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
-        #         )
-
-        #         if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
-        #             logger.info(
-        #                 f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
-        #                 "from a decoder model. Cross-attention and casual mask are disabled."
-        #             )
-        #             encoder_config.is_decoder = False
-        #             encoder_config.add_cross_attention = False
-
-        #         kwargs_encoder["config"] = encoder_config
-
-        #     encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
-
-        # decoder = kwargs_decoder.pop("model", None)
-        # if decoder is None:
-        #     if decoder_pretrained_model_name_or_path is None:
-        #         raise ValueError(
-        #             "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
-        #             "to be defined."
-        #         )
-
-        #     if "config" not in kwargs_decoder:
-        #         decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
-        #             decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
-        #         )
-
-        #         if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
-        #             logger.info(
-        #                 f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
-        #                 f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
-        #                 f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
-        #             )
-        #             decoder_config.is_decoder = True
-        #             decoder_config.add_cross_attention = False
-
-        #         kwargs_decoder["config"] = decoder_config
-
-        #     if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
-        #         logger.warning(
-        #             f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
-        #             f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
-        #             "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
-        #             "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
-        #             "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
-        #         )
-
-        #     decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
-
-        # # instantiate config with corresponding kwargs
-        # config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
-
-        # # make sure input & output embeddings is not tied
-        # config.tie_word_embeddings = False
 
-
-
-        # return cls(encoder=encoder, decoder=decoder, config=config)
+        # Image transformations:
+        self.train_transforms = v2.Compose(
+            [
+                v2.Grayscale(num_output_channels=3),
+                v2.Resize(
+                    size=self.config.vision_config.image_size,
+                    antialias=True,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.RandomCrop(
+                    size=[self.config.vision_config.image_size, self.config.vision_config.image_size],
+                    pad_if_needed=True,
+                ),
+                v2.RandomRotation(degrees=5),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+            ]
+        )
+        self.test_transforms = v2.Compose(
+            [
+                v2.Grayscale(num_output_channels=3),
+                v2.Resize(
+                    size=self.config.vision_config.image_size,
+                    antialias=True,
+                    interpolation=v2.InterpolationMode.BICUBIC,
+                ),
+                v2.CenterCrop(size=[self.config.vision_config.image_size, self.config.vision_config.image_size]),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+            ]
+        )
+
+        self.post_init()
 
     def forward(
         self,

@@ -712,80 +589,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
                 sections[j].append(section_string)
 
         return tuple(sections.values())
[-715..-788: the tokenize_text_prompt() method is deleted from this position; it is re-added unchanged at new lines 722-793 in the @@ -914,7 +718,219 @@ hunk below (a move, not a removal).]
 
     def prepare_inputs(
         self,
         images,

@@ -914,7 +718,219 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         assert inputs_embeds.shape[1] == token_type_ids.shape[1]
 
         return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
+
+    def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
+        """
+        Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
+        Time deltas for the input_ids are also prepared here.
+
+        Argument/s:
+            tokenizer - Hugging Face tokenizer.
+
+        Returns:
+            ed - dictionary containing the input_ids, token_type_ids, attention_mask and time_deltas for the ED module columns.
+            cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
+        """
+
+        batch_size = len(kwargs['study_id'])
+
+        tokenized = {
+            'input_ids': {i: [] for i in range(batch_size)},
+            'token_type_ids': {i: [] for i in range(batch_size)},
+            'time_delta': {i: [] for i in range(batch_size)},
+            'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
+        }
+
+        prompt_text_columns = [f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j for k, v in self.tables.items() if 'text_columns' in v for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])] + ['prior_findings', 'prior_impression']
+
+        for i in prompt_text_columns:
+            if i in kwargs:
+                if f'{i}_time_delta' not in kwargs:
+                    kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
+                for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
+                    if y is not None:
+                        assert isinstance(y, list)
+                        assert isinstance(z, list)
+                        for text, time_delta in zip(y, z):
+                            if text is not None:
+                                tokenized['input_ids'][x].append(
+                                    tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
+                                )
+                                tokenized['token_type_ids'][x].append(
+                                    torch.full(
+                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
+                                        self.token_type_to_token_type_id[i],
+                                        dtype=torch.long,
+                                        device=self.device,
+                                    )
+                                )
+                                tokenized['time_delta'][x].append(
+                                    torch.full(
+                                        (1, tokenized['input_ids'][x][-1].shape[-1]),
+                                        time_delta,
+                                        dtype=torch.float32,
+                                        device=self.device,
+                                    )
+                                )
+
+        tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
+        tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
+        tokenized['time_delta'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, device=self.device) for j in tokenized['time_delta'].values()]
+
+        tokenized['input_ids'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
+        )[:, :, 0]
+        tokenized['token_type_ids'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['token_type_ids'], batch_first=True, padding_value=0,
+        )[:, :, 0]
+
+        tokenized['attention_mask'] = (tokenized['input_ids'] != tokenizer.pad_token_id).int()
+
+        tokenized['time_delta'] = torch.nn.utils.rnn.pad_sequence(
+            tokenized['time_delta'], batch_first=True, padding_value=0,
+        )
+
+        return tokenized
 
+    def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
+        mask_value = torch.finfo(time_deltas.dtype).max if self.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min
+
+        masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
+        _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
+
+        num_rows, num_cols, _ = time_deltas.shape
+
+        row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
+        position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
+        position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
+        position_ids.masked_fill_(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
+
+        return position_ids
+
+    def prepare_index_value_feats(self, table, batch):
+
+        index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
+        index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
+
+        # Map to indices with lookup table:
+        if 'index_columns' in self.tables[table]:
+            for i in self.tables[table]['index_columns']:
+                k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                batch[k] = [
+                    [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
+                ]
+
+        batch_index_value_feats_list = []
+        batch_token_type_ids_list = []
+        batch_time_deltas_list = []
+
+        for batch_idx in range(len(batch['study_id'])):
+
+            if any([batch[k][batch_idx] for k in index_value_columns]):
+
+                num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
+                assert all(x == num_rows[0] for x in num_rows)
+                num_rows = num_rows[0]
+
+                # The y-index and the datetime for each group:
+                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
+                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
+                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
+                    assert len(set(y_indices)) == len(datetime)
+                else:
+                    y_indices = [0] * num_rows
+                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
+
+                time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
+
+                tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
+
+                # Index columns to feats:
+                if 'index_columns' in self.tables[table]:
+
+                    for i in self.tables[table]['index_columns']:
+                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                        y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
+                        x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
+
+                        tensor[y_indices_column, x_indices_column] = 1.0
+
+                if 'value_columns' in self.tables[table]:
+                    for i in self.tables[table]['value_columns']:
+
+                        k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
+                        y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
+                        x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
+                        values = [value for value in batch[k][batch_idx] if value is not None]
+
+                        tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
+                        assert not torch.isnan(tensor).any()
+            else:
+                tensor = torch.empty(0, self.luts[table]['total'])
+                time_deltas = torch.empty(0, 1)
+
+            batch_index_value_feats_list.append(tensor)
+            batch_token_type_ids_list.append(torch.full(
+                [tensor.shape[0]],
+                self.token_type_to_token_type_id[table],
+                dtype=torch.long,
+                )
+            )
+            batch_time_deltas_list.append(time_deltas)
+
+            assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
+            assert tensor.shape[0] == time_deltas.shape[0]
+
+        batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1)  # Pad value of -1 is not ideal. Need to use something else.
+        batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
+        batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
+
+        batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
+
+        return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
+
+    def prepare_text_prompt(self, table, column, batch):
+
+        key = f'{table}_{column}' if not table == 'mimic_cxr_sectioned' else column
+
+        batch_text_list = []
+        batch_time_deltas_list = []
+
+        for batch_idx in range(len(batch['study_id'])):
+            if batch[key][batch_idx]:
+
+                num_rows = len(batch[key][batch_idx])
+
+                # The y-index and the datetime for each group:
+                if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
+                    y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
+                    datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
+                    assert len(set(y_indices)) == len(datetime)
+                else:
+                    y_indices = [0] * num_rows
+                    datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
+
+                # Remove None values:
+                text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
+                y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
+                text_rows = [i for i in text_rows if i is not None]
+                datetime = [datetime[i] for i in set(y_indices)]
+                if text_rows:
+
+                    # Those in the same group (or those with the same y-index) get joined as the same string:
+                    batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
+                    batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
+
+                    assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
+                else:
+                    batch_text_list.append([])
+                    batch_time_deltas_list.append([])
+            else:
+                batch_text_list.append([])
+                batch_time_deltas_list.append([])
+
+        return batch_text_list, batch_time_deltas_list
+
     @staticmethod
     def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):
 
@@ -983,86 +999,24 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
 
         return mixed_causality_4d_attention_mask
-
-    # @staticmethod
-    # def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
-
-    #     prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
-    #     report_seq_len = causal_2d_attention_mask.shape[-1]
-
-    #     non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
-    #     causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
-
-    #     # Upper left of attention matrix:
-    #     upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
-    #     upper_left = upper_left * non_causal_2d_attention_mask
-    #     upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
-
-    #     causal_mask = torch.tril(
-    #         torch.ones(
-    #             (
-    #                 report_seq_len,
-    #                 report_seq_len,
-    #             ),
-    #             dtype=torch.long,
-    #             device=causal_2d_attention_mask.device,
-    #         ),
-    #     )
-
-    #     # Lower right of attention matrix:
-    #     lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
-    #     lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
-    #     lower_right = lower_right * causal_mask
-
-    #     # Upper right of attention matrix:
-    #     upper_right = torch.zeros(
-    #         causal_2d_attention_mask.shape[0],
-    #         1,
-    #         prompt_seq_len,
-    #         report_seq_len,
-    #         dtype=torch.long,
-    #         device=causal_2d_attention_mask.device,
-    #     )
-
-    #     # Lower left of attention matrix:
-    #     lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
-    #     lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
-
-    #     left = torch.cat((upper_left, lower_left), dim=2)
-    #     right = torch.cat((upper_right, lower_right), dim=2)
 
-
-
-
-
-
-
-    #     non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
-    #     causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
-
-    #     mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
-    #     return mixed_causality_4d_attention_mask
 
-
-
-        masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
-        _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)
 
-
-        position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
-        position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
-        position_ids.masked_fill_(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
 
-
-    def get_dataset(self, dataset_path, train_transforms=None, test_transforms=None, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
 
-        assert max_train_images_per_study is not None, 'max_train_images_per_study must be defined.'
-        assert test_transforms is not None, 'test_transforms must be defined.'
+    @staticmethod
+    def collate_fn(batch):
+        keys = set().union(*(d.keys() for d in batch))
+        batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
+        batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
+        return batch
+
+    @staticmethod
+    def prepare_dataset(physionet_dir: str, database_dir: str):
+        prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
+
+    def get_dataset(self, database_dir, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
+
+        dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
+
+        assert max_train_images_per_study is not None or test_set_only, 'max_train_images_per_study must be defined if training.'
 
         def train_set_transform(batch):
 

@@ -1081,7 +1035,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
             # Sort based on ViewPosition:
             batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-            batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
+            batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
             max_size = max(i.shape[0] for i in batch['images'])
 
             batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]

@@ -1104,7 +1058,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
            # Sort based on ViewPosition:
            batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-           batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
+           batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
            max_size = max(i.shape[0] for i in batch['images'])
            batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
            batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)

@@ -1177,7 +1131,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         else:
             return test_set
 
-    def get_stage_1_dataset(self, …
+    def get_stage_1_dataset(self, database_dir, max_train_images_per_study):
+
+        dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')
 
         def train_set_transform(batch):
 

@@ -1192,7 +1148,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
            # Sort based on ViewPosition:
            batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-           batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
+           batch['images'] = [torch.stack([self.train_transforms(j) for j in i]) for i in batch['images']]
            max_size = max(i.shape[0] for i in batch['images'])
            batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
            batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)

@@ -1204,7 +1160,7 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
 
            # Sort based on ViewPosition:
            batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
-           batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
+           batch['images'] = [torch.stack([self.test_transforms(j) for j in i]) for i in batch['images']]
            max_size = max(i.shape[0] for i in batch['images'])
            batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
            batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)

@@ -1256,138 +1212,4 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
             test_set = Subset(test_set, indices)
 
         return train_set, val_set, test_set
+
[-1259..-1393: prepare_index_value_feats(), prepare_text_prompt(), collate_fn() and prepare_dataset() are deleted from the end of the class; they are re-added unchanged earlier in the file (new lines 810-933 and 1003-1013 in the hunks above) — a move, not a removal.]
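Illustrative note (not part of the commit): a standalone sketch of the ranking trick used by the new position_ids_from_time_deltas_and_attention_mask() — tokens are given position ids according to the sorted order of their time deltas, and padded positions are pushed to the end of the sort and then overwritten with 1. The tensor values below are made up.

    import torch

    time_deltas = torch.tensor([[[0.7], [0.1], [0.4], [0.0]]])   # (batch=1, seq=4, 1)
    attention_mask = torch.tensor([[1, 1, 1, 0]])                # last token is padding

    mask_value = torch.finfo(time_deltas.dtype).max              # the time_delta_monotonic_inversion=True case
    masked = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
    _, col_indices = torch.sort(masked, descending=False)        # ascending: smallest time delta first

    num_rows, num_cols, _ = time_deltas.shape
    row_indices = torch.arange(num_rows).view(-1, 1).repeat(1, num_cols).view(-1)
    position_ids = torch.zeros_like(col_indices)
    position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols)[None, :].expand(num_rows, -1).flatten()
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)  # tensor([[2, 0, 1, 1]]): the token with the smallest time delta gets position 0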
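Illustrative note (not part of the commit): after this change the image transforms live on the model (self.train_transforms / self.test_transforms), so callers no longer pass transforms into get_dataset(); they pass the database directory instead. A rough usage sketch, in which the checkpoint name, directory paths and max_train_images_per_study value are placeholders and the trust_remote_code loading path is assumed:

    import torch
    import transformers

    model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-ed', trust_remote_code=True)  # checkpoint name assumed

    # prepare_dataset() and the new get_dataset() signature as added in this commit:
    model.prepare_dataset(physionet_dir='/path/to/physionet', database_dir='/path/to/database')
    train_set, val_set, test_set = model.get_dataset(database_dir='/path/to/database', max_train_images_per_study=5)

    # collate_fn is a staticmethod on the model and can be handed to a DataLoader directly:
    loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True, collate_fn=model.collate_fn)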