hyzhou commited on
Commit
cca9b7e
·
1 Parent(s): 2d71a54

upload everything

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +24 -0
  2. MedVersa/pytorch_model.bin +3 -0
  3. README.md +58 -0
  4. __pycache__/utils.cpython-39.pyc +0 -0
  5. demo.py +544 -0
  6. demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png +0 -0
  7. demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png +0 -0
  8. demo_ex/Case_00840_0000.nii.gz +3 -0
  9. demo_ex/Case_01013_0000.nii.gz +3 -0
  10. demo_ex/ISIC_0032258.jpg +0 -0
  11. demo_ex/ISIC_0033730.jpg +0 -0
  12. demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png +0 -0
  13. demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png +0 -0
  14. demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png +0 -0
  15. demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png +0 -0
  16. environment.yml +479 -0
  17. inference.py +107 -0
  18. medomni/__init__.py +31 -0
  19. medomni/__pycache__/__init__.cpython-311.pyc +0 -0
  20. medomni/__pycache__/__init__.cpython-39.pyc +0 -0
  21. medomni/common/__init__.py +0 -0
  22. medomni/common/__pycache__/__init__.cpython-39.pyc +0 -0
  23. medomni/common/__pycache__/config.cpython-39.pyc +0 -0
  24. medomni/common/__pycache__/dist_utils.cpython-39.pyc +0 -0
  25. medomni/common/__pycache__/logger.cpython-39.pyc +0 -0
  26. medomni/common/__pycache__/optims.cpython-39.pyc +0 -0
  27. medomni/common/__pycache__/registry.cpython-39.pyc +0 -0
  28. medomni/common/__pycache__/utils.cpython-39.pyc +0 -0
  29. medomni/common/config.py +468 -0
  30. medomni/common/dist_utils.py +137 -0
  31. medomni/common/gradcam.py +24 -0
  32. medomni/common/logger.py +200 -0
  33. medomni/common/optims.py +119 -0
  34. medomni/common/registry.py +327 -0
  35. medomni/common/utils.py +424 -0
  36. medomni/configs/datasets/medinterp/align.yaml +5 -0
  37. medomni/configs/default.yaml +5 -0
  38. medomni/configs/models/medomni.yaml +12 -0
  39. medomni/conversation/__init__.py +0 -0
  40. medomni/conversation/__pycache__/__init__.cpython-39.pyc +0 -0
  41. medomni/conversation/__pycache__/conversation.cpython-39.pyc +0 -0
  42. medomni/conversation/conversation.py +222 -0
  43. medomni/datasets/__init__.py +0 -0
  44. medomni/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  45. medomni/datasets/__pycache__/data_utils.cpython-39.pyc +0 -0
  46. medomni/datasets/builders/__init__.py +71 -0
  47. medomni/datasets/builders/__pycache__/__init__.cpython-39.pyc +0 -0
  48. medomni/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc +0 -0
  49. medomni/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc +0 -0
  50. medomni/datasets/builders/base_dataset_builder.py +234 -0
LICENSE ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright – President and Fellows of Harvard College, 2024. All Rights Reserved.
2
+
3
+ Redistribution and use in source and binary forms, with or without
4
+ modification, are permitted provided that the following conditions are met:
5
+
6
+ Redistributions of source code must retain the above copyright notice, this
7
+ list of conditions and the following disclaimer. Redistributions in binary
8
+ form must reproduce the above copyrightnotice, this list of conditions and the
9
+ following disclaimer in the documentation and/or other materials provided with
10
+ the distribution. Neither the name "Harvard" nor the names of its contributors
11
+ may be used to endorse or promote products derived from this software without
12
+ specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOTLIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17
+ ARE DISCLAIMED. IN NO EVENT SHALLTHECOPYRIGHT HOLDER OR CONTRIBUTORS BE
18
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19
+ CONSEQUENTIAL DAMAGES(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21
+ INTERRUPTION)HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23
+ OTHERWISE)ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
24
+ OF THE POSSIBILITY OF SUCH DAMAGE.
MedVersa/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3ce596897168d79649e6d6df128a1b409a0cc878092f00667873be6f4b8c9d3
3
+ size 13993804625
README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: hyzhouMedVersa
3
+ app_file: demo_inter.py
4
+ sdk: gradio
5
+ sdk_version: 4.24.0
6
+ ---
7
+ # MedVersa: An orchestrated medical AI system
8
+ MedVersa is a compound medical AI system that can coordinate multimodal inputs, orchestrate models and tools for varying tasks, and generate multimodal outputs.
9
+
10
+ ## Environment
11
+ MedVersa is written in [Python](https://www.python.org/). It is recommended to configure/manage your python environment using conda. To do this, you need to install the [miniconda](https://docs.anaconda.com/free/miniconda/index.html) or [anaconda](https://www.anaconda.com/) first.
12
+
13
+ After installing conda, you need to set up a new conda environment for MedVersa using the provided `environment.yml`:
14
+ ``` shell
15
+ conda env create -f environment.yml
16
+ conda activate medversa
17
+ ```
18
+
19
+ ## Inference
20
+ ``` python
21
+ from utils import *
22
+
23
+ # --- Launch Model ---
24
+ device = 'cuda:0'
25
+ model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
26
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
27
+ model.eval()
28
+
29
+ # --- Define examples ---
30
+ examples = [
31
+ [
32
+ ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
33
+ "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
34
+ "How would you characterize the findings from <img0>?",
35
+ "cxr",
36
+ "report generation",
37
+ ],
38
+ ]
39
+ # --- Define hyperparams ---
40
+ num_beams = 1
41
+ do_sample = True
42
+ min_length = 1
43
+ top_p = 0.9
44
+ repetition_penalty = 1
45
+ length_penalty = 1
46
+ temperature = 0.1
47
+
48
+ # --- Generate a report for an chest X-ray image ---
49
+ index = 0
50
+ demo_ex = examples[index]
51
+ images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
52
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
53
+ print(output_text)
54
+ ```
55
+ For more details and examples, please refer to `inference.py`.
56
+
57
+ ## Demo
58
+ `CUDA_VISIBLE_DEVICES=0 python demo.py --cfg-path medversa.yaml`
__pycache__/utils.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
demo.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import argparse
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms.functional as TF
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ import skimage.morphology, skimage.io
9
+ import cv2
10
+ import numpy as np
11
+ import random
12
+ from transformers import StoppingCriteria, StoppingCriteriaList
13
+ from copy import deepcopy
14
+ from medomni.common.config import Config
15
+ from medomni.common.dist_utils import get_rank
16
+ from medomni.common.registry import registry
17
+ import torchio as tio
18
+ import nibabel as nib
19
+ from scipy import ndimage, misc
20
+ import time
21
+ import ipdb
22
+
23
+ # Function to parse command line arguments
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser(description="Demo")
26
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
27
+ parser.add_argument(
28
+ "--options",
29
+ nargs="+",
30
+ help="override some settings in the used config, the key-value pair in xxx=yyy format will be merged into config file (deprecate), change to --cfg-options instead.",
31
+ )
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+ device = 'cuda:0'
36
+ # Launch model
37
+ args = parse_args()
38
+ cfg = Config(args)
39
+
40
+ model_config = cfg.model_cfg
41
+ model_cls = registry.get_model_class(model_config.arch)
42
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
43
+ model.eval()
44
+ global global_images
45
+ global_images = None
46
+
47
+ def seg_2d_process(image_path, pred_mask, img_size=224):
48
+ image = cv2.imread(image_path[0])
49
+ if pred_mask.sum() != 0:
50
+ labels = skimage.morphology.label(pred_mask)
51
+ labelCount = np.bincount(labels.ravel())
52
+ largest_label = np.argmax(labelCount[1:]) + 1
53
+ pred_mask[labels != largest_label] = 0
54
+ pred_mask[labels == largest_label] = 255
55
+ pred_mask = pred_mask.astype(np.uint8)
56
+ contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
57
+ if contours:
58
+ contours = np.vstack(contours)
59
+ binary_array = np.zeros((img_size, img_size))
60
+ binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED)
61
+ binary_array = cv2.resize(binary_array, (image.shape[1], image.shape[0]), interpolation = cv2.INTER_NEAREST) / 255
62
+ image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
63
+ mask = [binary_array]
64
+ else:
65
+ image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
66
+ mask = [np.zeros((image.shape[1], image.shape[0]))]
67
+ else:
68
+ mask = [np.zeros((image.shape[1], image.shape[0]))]
69
+ image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
70
+ # output_image = cv2.drawContours(binary_array, contours, -1, (110, 0, 255), 2)
71
+ # output_image_pil = Image.fromarray(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
72
+ return image, mask
73
+
74
+ def seg_3d_process(image_path, seg_mask):
75
+ img = nib.load(image_path[0]).get_fdata()
76
+ image = window_scan(img).transpose(2,0,1).astype(np.uint8)
77
+ if seg_mask.sum() != 0:
78
+ seg_mask = resize_back_volume_abd(seg_mask, image.shape).astype(np.uint8)
79
+ image_slices = []
80
+ contour_slices = []
81
+ for i in range(seg_mask.shape[0]):
82
+ slice_img = np.fliplr(np.rot90(image[i]))
83
+ slice_mask = np.fliplr(np.rot90(seg_mask[i]))
84
+ contours, _ = cv2.findContours(slice_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
85
+ image_slices.append(Image.fromarray(slice_img))
86
+ if contours:
87
+ binary_array = np.zeros(seg_mask.shape[1:])
88
+ binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED) / 255
89
+ binary_array = cv2.resize(binary_array, slice_img.shape, interpolation = cv2.INTER_NEAREST)
90
+ contour_slices.append(binary_array)
91
+ else:
92
+ contour_slices.append(np.zeros_like(slice_img))
93
+ else:
94
+ image_slices = []
95
+ contour_slices = []
96
+ slice_img = np.fliplr(np.rot90(image[i]))
97
+ image_slices.append(Image.fromarray(slice_img))
98
+ contour_slices.append(np.zeros_like(slice_img))
99
+
100
+ return image_slices, contour_slices
101
+
102
+ def det_2d_process(image_path, box):
103
+ image_slices = []
104
+ image = cv2.imread(image_path[0])
105
+ if box is not None:
106
+ hi,wd,_ = image.shape
107
+ color = tuple(np.random.random(size=3) * 256)
108
+ x1, y1, x2, y2 = int(box[0]*wd), int(box[1]*hi), int(box[2]*wd), int(box[3]*hi)
109
+ image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 10)
110
+ image_slices.append(Image.fromarray(image))
111
+ return image_slices
112
+
113
+ def window_scan(scan, window_center=50, window_width=400):
114
+ """
115
+ Apply windowing to a scan.
116
+
117
+ Parameters:
118
+ scan (numpy.ndarray): 3D numpy array of the CT scan
119
+ window_center (int): The center of the window
120
+ window_width (int): The width of the window
121
+
122
+ Returns:
123
+ numpy.ndarray: Windowed CT scan
124
+ """
125
+ lower_bound = window_center - (window_width // 2)
126
+ upper_bound = window_center + (window_width // 2)
127
+
128
+ windowed_scan = np.clip(scan, lower_bound, upper_bound)
129
+ windowed_scan = (windowed_scan - lower_bound) / (upper_bound - lower_bound)
130
+ windowed_scan = (windowed_scan * 255).astype(np.uint8)
131
+
132
+ return windowed_scan
133
+
134
+ def task_seg_2d(model, preds, hidden_states, image):
135
+ token_mask = preds == model.seg_token_idx_2d
136
+ indices = torch.where(token_mask == True)[0].cpu().numpy()
137
+ feats = model.model_seg_2d.encoder(image.unsqueeze(0)[:, 0])
138
+ last_feats = feats[-1]
139
+ target_states = [hidden_states[ind][-1] for ind in indices]
140
+ if target_states:
141
+ target_states = torch.cat(target_states).squeeze()
142
+ seg_states = model.text2seg_2d(target_states).unsqueeze(0)
143
+ last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1)
144
+ last_feats = model.text2seg_2d_gn(last_feats)
145
+ feats[-1] = last_feats
146
+ seg_feats = model.model_seg_2d.decoder(*feats)
147
+ seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
148
+ seg_probs = F.sigmoid(seg_preds)
149
+ seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
150
+ return seg_mask
151
+ else:
152
+ return None
153
+
154
+ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
155
+ new_img_embeds_list = deepcopy(img_embeds_list)
156
+ token_mask = preds == model.seg_token_idx_3d
157
+ indices = torch.where(token_mask == True)[0].cpu().numpy()
158
+ target_states = [hidden_states[ind][-1] for ind in indices]
159
+ if target_states:
160
+ target_states = torch.cat(target_states).squeeze().unsqueeze(0)
161
+ seg_states = model.text2seg_3d(target_states)
162
+ last_feats = new_img_embeds_list[-1]
163
+ last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
164
+ last_feats = model.text2seg_3d_gn(last_feats)
165
+ new_img_embeds_list[-1] = last_feats
166
+ seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
167
+ seg_probs = F.sigmoid(seg_preds)
168
+ seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
169
+ return seg_mask
170
+
171
+ def task_det_2d(model, preds, hidden_states):
172
+ token_mask = preds == model.det_token_idx
173
+ indices = torch.where(token_mask == True)[0].cpu().numpy()
174
+ target_states = [hidden_states[ind][-1] for ind in indices]
175
+ if target_states:
176
+ target_states = torch.cat(target_states).squeeze()
177
+ det_states = model.text_det(target_states).detach().cpu()
178
+ return det_states.numpy()
179
+ return torch.zeros_like(indices)
180
+
181
+ class StoppingCriteriaSub(StoppingCriteria):
182
+ def __init__(self, stops=[]):
183
+ super().__init__()
184
+ self.stops = stops
185
+
186
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
187
+ for stop in self.stops:
188
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
189
+ return True
190
+ return False
191
+
192
+ def resize_back_volume_abd(img, target_size):
193
+ desired_depth = target_size[0]
194
+ desired_width = target_size[1]
195
+ desired_height = target_size[2]
196
+
197
+ current_depth = img.shape[0] # [d, w, h]
198
+ current_width = img.shape[1]
199
+ current_height = img.shape[2]
200
+
201
+ depth = current_depth / desired_depth
202
+ width = current_width / desired_width
203
+ height = current_height / desired_height
204
+
205
+ depth_factor = 1 / depth
206
+ width_factor = 1 / width
207
+ height_factor = 1 / height
208
+
209
+ img = ndimage.zoom(img, (depth_factor, width_factor, height_factor), order=0)
210
+ return img
211
+
212
+ def resize_volume_abd(img):
213
+ img[img<=-200] = -200
214
+ img[img>=300] = 300
215
+
216
+ desired_depth = 64
217
+ desired_width = 192
218
+ desired_height = 192
219
+
220
+ current_width = img.shape[0] # [w, h, d]
221
+ current_height = img.shape[1]
222
+ current_depth = img.shape[2]
223
+
224
+ depth = current_depth / desired_depth
225
+ width = current_width / desired_width
226
+ height = current_height / desired_height
227
+
228
+ depth_factor = 1 / depth
229
+ width_factor = 1 / width
230
+ height_factor = 1 / height
231
+
232
+ img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=0)
233
+ return img
234
+
235
+ def load_and_preprocess_image(image):
236
+ mean = (0.48145466, 0.4578275, 0.40821073)
237
+ std = (0.26862954, 0.26130258, 0.27577711)
238
+ transform = transforms.Compose([
239
+ transforms.Resize([224, 224]),
240
+ transforms.ToTensor(),
241
+ transforms.Normalize(mean, std)
242
+ ])
243
+ image = transform(image).type(torch.bfloat16).cuda().unsqueeze(0)
244
+ return image
245
+
246
+ def load_and_preprocess_volume(image):
247
+ img = nib.load(image).get_fdata()
248
+ image = torch.from_numpy(resize_volume_abd(img)).permute(2,0,1)
249
+ transform = tio.Compose([
250
+ tio.ZNormalization(masking_method=tio.ZNormalization.mean),
251
+ ])
252
+ image = transform(image.unsqueeze(0)).type(torch.bfloat16).cuda()
253
+ return image
254
+
255
+ def read_image(image_path):
256
+ if image_path.endswith(('.jpg', '.jpeg', '.png')):
257
+ return load_and_preprocess_image(Image.open(image_path).convert('RGB'))
258
+ elif image_path.endswith('.nii.gz'):
259
+ return load_and_preprocess_volume(image_path)
260
+ else:
261
+ raise ValueError("Unsupported file format")
262
+
263
+ def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
264
+ if (len(context) != 0 and ('report' in prompt or 'finding' in prompt or 'impression' in prompt)) or (len(context) != 0 and modal=='derm' and ('diagnosis' in prompt or 'issue' in prompt or 'problem' in prompt)):
265
+ prompt = '<context>' + context + '</context>' + prompt
266
+ if modal == 'ct' and 'segment' in prompt.lower():
267
+ if 'liver' in prompt:
268
+ prompt = 'Segment the liver.'
269
+ if 'spleen' in prompt:
270
+ prompt = 'Segment the spleen.'
271
+ if 'kidney' in prompt:
272
+ prompt = 'Segment the kidney.'
273
+ if 'pancrea' in prompt:
274
+ prompt = 'Segment the pancreas.'
275
+ img_embeds, atts_img, img_embeds_list = model.encode_img(image.unsqueeze(0), [modal])
276
+ placeholder = ['<ImageHere>'] * 9
277
+ prefix = '###Human:' + ''.join([f'<img{i}>' + ''.join(placeholder) + f'</img{i}>' for i in range(num_imgs)])
278
+ img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img, [prefix], [num_imgs])
279
+ prompt += '###Assistant:'
280
+ prompt_tokens = model.llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(image.device)
281
+ new_img_embeds, new_atts_img = model.prompt_concat(img_embeds, atts_img, prompt_tokens)
282
+
283
+ outputs = model.llama_model.generate(
284
+ inputs_embeds=new_img_embeds,
285
+ max_new_tokens=450,
286
+ stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub(stops=[
287
+ torch.tensor([835]).type(torch.bfloat16).to(image.device),
288
+ torch.tensor([2277, 29937]).type(torch.bfloat16).to(image.device)
289
+ ])]),
290
+ num_beams=num_beams,
291
+ do_sample=do_sample,
292
+ min_length=min_length,
293
+ top_p=top_p,
294
+ repetition_penalty=repetition_penalty,
295
+ length_penalty=length_penalty,
296
+ temperature=temperature,
297
+ output_hidden_states=True,
298
+ return_dict_in_generate=True,
299
+ )
300
+
301
+ hidden_states = outputs.hidden_states
302
+ preds = outputs.sequences[0]
303
+ output_image = None
304
+ seg_mask_2d = None
305
+ seg_mask_3d = None
306
+ if sum(preds == model.seg_token_idx_2d):
307
+ seg_mask = task_seg_2d(model, preds, hidden_states, image)
308
+ output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
309
+ if sum(preds == model.seg_token_idx_3d):
310
+ seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
311
+ output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
312
+ if sum(preds == model.det_token_idx):
313
+ det_box = task_det_2d(model, preds, hidden_states)
314
+ output_image = det_2d_process(image_path, det_box)
315
+
316
+ if preds[0] == 0: # Remove unknown token <unk> at the beginning
317
+ preds = preds[1:]
318
+ if preds[0] == 1: # Remove start token <s> at the beginning
319
+ preds = preds[1:]
320
+
321
+ output_text = model.llama_tokenizer.decode(preds, add_special_tokens=False)
322
+ output_text = output_text.split('###')[0].split('Assistant:')[-1].strip()
323
+
324
+ if 'mel' in output_text and modal == 'derm':
325
+ output_text = 'The main diagnosis is melanoma.'
326
+ return output_image, seg_mask_2d, seg_mask_3d, output_text
327
+
328
+ def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
329
+ num_imgs = len(images)
330
+ modal = modality.lower()
331
+ image_tensors = [read_image(img) for img in images]
332
+ if modality == 'ct':
333
+ time.sleep(2)
334
+ else:
335
+ time.sleep(1)
336
+ image_tensor = torch.cat(image_tensors)
337
+
338
+ with torch.autocast("cuda"):
339
+ with torch.no_grad():
340
+ generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
341
+
342
+ return generated_image, seg_mask_2d, seg_mask_3d, output_text
343
+
344
+ my_dict = {}
345
+ def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
346
+ global global_images
347
+ if not images:
348
+ image = np.zeros((224, 224, 3), dtype=np.uint8)
349
+ blank_image = Image.fromarray(image)
350
+ snapshot = (blank_image, [])
351
+ global_images = 'none'
352
+ return [(prompt, "At least one image is required to proceed.")], snapshot, gr.update(maximum=0)
353
+ if not prompt or not modality:
354
+ image = np.zeros((224, 224, 3), dtype=np.uint8)
355
+ blank_image = Image.fromarray(image)
356
+ snapshot = (blank_image, [])
357
+ global_images = 'none'
358
+ return [(prompt, "Please provide prompt and modality to proceed.")], snapshot, gr.update(maximum=0)
359
+
360
+ generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
361
+ output_images = []
362
+ input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
363
+ if generated_images is not None:
364
+ for generated_image in generated_images:
365
+ output_images.append(np.asarray(generated_image).astype(np.uint8))
366
+ snapshot = (output_images[0], [])
367
+ if seg_mask_2d is not None:
368
+ snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
369
+ if seg_mask_3d is not None:
370
+ snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
371
+ else:
372
+ output_images = input_images.copy()
373
+ snapshot = (output_images[0], [])
374
+
375
+ my_dict['image'] = output_images
376
+ my_dict['mask'] = None
377
+ if seg_mask_2d is not None:
378
+ my_dict['mask'] = seg_mask_2d
379
+ if seg_mask_3d is not None:
380
+ my_dict['mask'] = seg_mask_3d
381
+
382
+ if global_images != images and (global_images is not None):
383
+ chatbot = []
384
+ chatbot.append((prompt, output_text))
385
+ else:
386
+ chatbot.append((prompt, output_text))
387
+ global_images = images
388
+
389
+ return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
390
+
391
+ # my_dict = {}
392
+ # def gradio_interface(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
393
+ # if not images:
394
+ # return None, "Error: At least one image is required to proceed."
395
+ # if not prompt or not task or not modality:
396
+ # return None, "Error: Please provide prompt, select task and modality to proceed."
397
+
398
+ # generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
399
+ # output_images = []
400
+
401
+ # input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
402
+ # if generated_images is not None:
403
+ # for generated_image in generated_images:
404
+ # output_images.append(np.asarray(generated_image).astype(np.uint8))
405
+ # snapshot = (output_images[0], [])
406
+ # if seg_mask_2d is not None:
407
+ # snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
408
+ # if seg_mask_3d is not None:
409
+ # snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
410
+ # else:
411
+ # output_images = input_images.copy()
412
+ # snapshot = (output_images[0], [])
413
+
414
+ # my_dict['image'] = output_images
415
+ # my_dict['mask'] = None
416
+ # if seg_mask_2d is not None:
417
+ # my_dict['mask'] = seg_mask_2d
418
+ # if seg_mask_3d is not None:
419
+ # my_dict['mask'] = seg_mask_3d
420
+
421
+ # return output_text, snapshot, gr.update(maximum=len(output_images)-1)
422
+
423
+ def render(x):
424
+ if x > len(my_dict['image'])-1:
425
+ x = len(my_dict['image'])-1
426
+ if x < 0:
427
+ x = 0
428
+ image = my_dict['image'][x]
429
+ if my_dict['mask'] is None:
430
+ return (image,[])
431
+ else:
432
+ mask = my_dict['mask'][x]
433
+ value = (image,[(mask, "Mask")])
434
+ return value
435
+
436
+ def update_context_visibility(task):
437
+ if task == "report generation" or task == 'classification':
438
+ return gr.update(visible=True)
439
+ else:
440
+ return gr.update(visible=False)
441
+
442
+ def reset_chatbot():
443
+ return []
444
+
445
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
446
+ # with gr.Row():
447
+ # gr.Markdown("<link href='https://fonts.googleapis.com/css2?family=Libre+Franklin:wght@400;700&display=swap' rel='stylesheet'>")
448
+ gr.Markdown("# MedVersa")
449
+ with gr.Row():
450
+ with gr.Column():
451
+ image_input = gr.File(label="Upload Images", file_count="multiple", file_types=["image", "numpy"])
452
+ # task_input = gr.Dropdown(choices=["report generation", "vqa", "localization", "classification"], label="Task")
453
+ context_input = gr.Textbox(label="Context", placeholder="Enter context here...", lines=3, visible=True)
454
+ modality_input = gr.Dropdown(choices=["cxr", "derm", "ct"], label="Modality")
455
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter prompt here... (images should be referred as <img0>, <img1>, ...)", lines=3)
456
+ submit_button = gr.Button("Generate Predictions")
457
+ with gr.Accordion("Advanced Settings", open=False):
458
+ num_beams = gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
459
+ do_sample = gr.Checkbox(label="Do Sample", value=True)
460
+ min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
461
+ top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
462
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
463
+ length_penalty = gr.Slider(label="Length Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
464
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.1)
465
+
466
+ with gr.Column():
467
+ # output_text = gr.Textbox(label="Generated Text", lines=10, elem_classes="output-textbox")
468
+ chatbot = gr.Chatbot(label="Chatbox")
469
+ slider = gr.Slider(minimum=0, maximum=64, value=1, step=1)
470
+ output_image = gr.AnnotatedImage(height=448, label="Images")
471
+
472
+ # task_input.change(
473
+ # fn=update_context_visibility,
474
+ # inputs=task_input,
475
+ # outputs=context_input
476
+ # )
477
+
478
+ submit_button.click(
479
+ fn=gradio_interface,
480
+ inputs=[chatbot, image_input, context_input, prompt_input, modality_input, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature],
481
+ outputs=[chatbot, output_image, slider]
482
+ )
483
+
484
+ slider.change(
485
+ render,
486
+ inputs=[slider],
487
+ outputs=[output_image],
488
+ )
489
+
490
+ examples = [
491
+ [
492
+ ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
493
+ "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
494
+ "How would you characterize the findings from <img0>?",
495
+ "cxr",
496
+ ],
497
+ [
498
+ ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
499
+ "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
500
+ "How would you characterize the findings from <img0>?",
501
+ "cxr",
502
+ ],
503
+ [
504
+ ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
505
+ "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
506
+ "How would you characterize the findings from <img0><img1>?",
507
+ "cxr",
508
+ ],
509
+ [
510
+ ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
511
+ "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
512
+ "How would you characterize the findings from <img0>?",
513
+ "cxr",
514
+ ],
515
+ [
516
+ ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
517
+ "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
518
+ "How would you characterize the findings from <img0>?",
519
+ "cxr",
520
+ ],
521
+ [
522
+ ["./demo_ex/ISIC_0032258.jpg"],
523
+ "Age:70.\nGender:female.\nLocation:back.",
524
+ "What is primary diagnosis?",
525
+ "derm",
526
+ ],
527
+ [
528
+ ["./demo_ex/Case_01013_0000.nii.gz"],
529
+ "",
530
+ "Segment the liver.",
531
+ "ct",
532
+ ],
533
+ [
534
+ ["./demo_ex/Case_00840_0000.nii.gz"],
535
+ "",
536
+ "Segment the liver.",
537
+ "ct",
538
+ ],
539
+ ]
540
+
541
+ gr.Examples(examples, inputs=[image_input, context_input, prompt_input, modality_input])
542
+
543
+ # Run Gradio app
544
+ demo.launch(share=True)
demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png ADDED
demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png ADDED
demo_ex/Case_00840_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27d91a51f4f792740aab30da1416e2a200f637a53e9aa842cf47f2dd96519216
3
+ size 30618190
demo_ex/Case_01013_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63f597a81e594aa0b5d5b67551658f4a8be831ac6189f2f3f644b0a1098fbb09
3
+ size 30845920
demo_ex/ISIC_0032258.jpg ADDED
demo_ex/ISIC_0033730.jpg ADDED
demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png ADDED
demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png ADDED
demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png ADDED
demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png ADDED
environment.yml ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: medversa
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - anaconda
7
+ - defaults
8
+ dependencies:
9
+ - _libgcc_mutex=0.1=main
10
+ - _openmp_mutex=5.1=1_gnu
11
+ - abseil-cpp=20211102.0=h27087fc_1
12
+ - aiosignal=1.3.1=pyhd8ed1ab_0
13
+ - arrow-cpp=11.0.0=py39h613000e_0
14
+ - asttokens=2.2.1=pyhd8ed1ab_0
15
+ - async-timeout=4.0.2=pyhd8ed1ab_0
16
+ - atk-1.0=2.36.0=ha1a6a79_0
17
+ - aws-c-common=0.4.57=he6710b0_1
18
+ - aws-c-event-stream=0.1.6=h2531618_5
19
+ - aws-checksums=0.1.9=he6710b0_0
20
+ - aws-sdk-cpp=1.8.185=hce553d0_0
21
+ - blas=1.0=mkl
22
+ - boost-cpp=1.70.0=ha2d47e9_1
23
+ - bottleneck=1.3.5=py39h7deecbd_0
24
+ - brotlipy=0.7.0=py39h27cfd23_1003
25
+ - bzip2=1.0.8=h7b6447c_0
26
+ - c-ares=1.18.1=h7f8727e_0
27
+ - ca-certificates=2023.12.12=h06a4308_0
28
+ - cairo=1.16.0=hb05425b_5
29
+ - certifi=2023.11.17=py39h06a4308_0
30
+ - cffi=1.15.1=py39h5eee18b_3
31
+ - colorama=0.4.6=pyhd8ed1ab_0
32
+ - cryptography=41.0.2=py39h774aba0_0
33
+ - cuda-cudart=11.7.99=0
34
+ - cuda-cupti=11.7.101=0
35
+ - cuda-libraries=11.7.1=0
36
+ - cuda-nvrtc=11.7.99=0
37
+ - cuda-nvtx=11.7.91=0
38
+ - cuda-runtime=11.7.1=0
39
+ - cudatoolkit=11.7.0=hd8887f6_10
40
+ - curl=7.87.0=h5eee18b_0
41
+ - dataclasses=0.8=pyhc8e2a94_3
42
+ - datasets=2.14.3=pyhd8ed1ab_0
43
+ - dill=0.3.7=pyhd8ed1ab_0
44
+ - executing=1.2.0=pyhd8ed1ab_0
45
+ - expat=2.4.9=h6a678d5_0
46
+ - ffmpeg=4.3=hf484d3e_0
47
+ - filelock=3.9.0=py39h06a4308_0
48
+ - flit-core=3.6.0=pyhd3eb1b0_0
49
+ - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
50
+ - font-ttf-inconsolata=2.001=hcb22688_0
51
+ - font-ttf-source-code-pro=2.030=hd3eb1b0_0
52
+ - font-ttf-ubuntu=0.83=h8b1ccd4_0
53
+ - fontconfig=2.14.1=h52c9d5c_1
54
+ - fonts-anaconda=1=h8fa9717_0
55
+ - fonts-conda-ecosystem=1=hd3eb1b0_0
56
+ - freetype=2.12.1=h4a9f257_0
57
+ - fribidi=1.0.10=h7b6447c_0
58
+ - frozenlist=1.3.3=py39h5eee18b_0
59
+ - gdbm=1.18=hd4cb3f1_4
60
+ - gdk-pixbuf=2.42.10=h5eee18b_0
61
+ - gettext=0.21.0=hf68c758_0
62
+ - gflags=2.2.2=he1b5a44_1004
63
+ - giflib=5.2.1=h5eee18b_3
64
+ - git=2.34.1=pl5262hc120c5b_0
65
+ - glib=2.69.1=he621ea3_2
66
+ - glog=0.5.0=h48cff8f_0
67
+ - gmp=6.2.1=h295c915_3
68
+ - gmpy2=2.1.2=py39heeb90bb_0
69
+ - gnutls=3.6.15=he1e5248_0
70
+ - gobject-introspection=1.72.0=py39hbb6d50b_2
71
+ - graphite2=1.3.14=h295c915_1
72
+ - graphviz=2.50.0=h1b29801_1
73
+ - grpc-cpp=1.46.1=h33aed49_1
74
+ - gtk2=2.24.33=h73c1081_2
75
+ - gts=0.7.6=hb67d8dd_3
76
+ - harfbuzz=4.3.0=hf52aaf7_1
77
+ - huggingface_hub=0.16.4=pyhd8ed1ab_0
78
+ - icu=58.2=he6710b0_3
79
+ - idna=3.4=py39h06a4308_0
80
+ - importlib_metadata=6.8.0=hd8ed1ab_0
81
+ - intel-openmp=2023.1.0=hdb19cb5_46305
82
+ - jinja2=3.1.2=py39h06a4308_0
83
+ - jpeg=9e=h5eee18b_1
84
+ - krb5=1.19.4=h568e23c_0
85
+ - lame=3.100=h7b6447c_0
86
+ - lcms2=2.12=h3be6417_0
87
+ - ld_impl_linux-64=2.38=h1181459_1
88
+ - lerc=3.0=h295c915_0
89
+ - libbrotlicommon=1.0.9=h166bdaf_7
90
+ - libbrotlidec=1.0.9=h166bdaf_7
91
+ - libbrotlienc=1.0.9=h166bdaf_7
92
+ - libcublas=11.10.3.66=0
93
+ - libcufft=10.7.2.124=h4fbf590_0
94
+ - libcufile=1.7.1.12=0
95
+ - libcurand=10.3.3.129=0
96
+ - libcurl=7.87.0=h91b91d3_0
97
+ - libcusolver=11.4.0.1=0
98
+ - libcusparse=11.7.4.91=0
99
+ - libdeflate=1.17=h5eee18b_0
100
+ - libedit=3.1.20221030=h5eee18b_0
101
+ - libev=4.33=h7f8727e_1
102
+ - libevent=2.1.10=h9b69904_4
103
+ - libffi=3.4.2=h6a678d5_6
104
+ - libgcc=7.2.0=h69d50b8_2
105
+ - libgcc-ng=11.2.0=h1234567_1
106
+ - libgd=2.3.3=h6a678d5_3
107
+ - libgomp=11.2.0=h1234567_1
108
+ - libiconv=1.16=h7f8727e_2
109
+ - libidn2=2.3.4=h5eee18b_0
110
+ - libnghttp2=1.46.0=hce63b2e_0
111
+ - libnpp=11.7.4.75=0
112
+ - libnvjpeg=11.8.0.2=0
113
+ - libpng=1.6.39=h5eee18b_0
114
+ - libprotobuf=3.20.3=he621ea3_0
115
+ - librsvg=2.54.4=h36cc946_2
116
+ - libssh2=1.10.0=h8f2d780_0
117
+ - libstdcxx-ng=11.2.0=h1234567_1
118
+ - libtasn1=4.19.0=h5eee18b_0
119
+ - libthrift=0.15.0=he6d91bd_0
120
+ - libtiff=4.5.0=h6a678d5_2
121
+ - libtool=2.4.6=h6a678d5_1009
122
+ - libunistring=0.9.10=h27cfd23_0
123
+ - libuuid=1.41.5=h5eee18b_0
124
+ - libwebp=1.2.4=h11a3e52_1
125
+ - libwebp-base=1.2.4=h5eee18b_1
126
+ - libxcb=1.15=h7f8727e_0
127
+ - libxml2=2.9.14=h74e7548_0
128
+ - lz4-c=1.9.4=h6a678d5_0
129
+ - mkl=2023.1.0=h6d00ec8_46342
130
+ - mkl-service=2.4.0=py39h5eee18b_1
131
+ - mkl_fft=1.3.6=py39h417a72b_1
132
+ - mkl_random=1.2.2=py39h417a72b_1
133
+ - mpc=1.1.0=h10f8cd9_1
134
+ - mpfr=4.0.2=hb69a4c5_1
135
+ - mpmath=1.3.0=py39h06a4308_0
136
+ - ncurses=6.4=h6a678d5_0
137
+ - nettle=3.7.3=hbbd107a_1
138
+ - networkx=3.1=py39h06a4308_0
139
+ - ninja-base=1.10.2=hd09550d_5
140
+ - numexpr=2.8.4=py39hc78ab66_1
141
+ - numpy-base=1.25.0=py39hb5e798b_0
142
+ - openh264=2.1.1=h4ff587b_0
143
+ - openjpeg=2.4.0=h3ad879b_0
144
+ - openssl=1.1.1w=h7f8727e_0
145
+ - orc=1.7.4=hb3bc3d3_1
146
+ - pango=1.50.7=h05da053_0
147
+ - pcre=8.45=h295c915_0
148
+ - pcre2=10.37=he7ceb23_1
149
+ - perl=5.34.0=h5eee18b_2
150
+ - pip=23.0.1=py39h06a4308_0
151
+ - pixman=0.40.0=h7f8727e_1
152
+ - poppler=0.81.0=h01f5e8b_2
153
+ - poppler-data=0.4.11=h06a4308_1
154
+ - pycparser=2.21=pyhd3eb1b0_0
155
+ - pyopenssl=23.2.0=py39h06a4308_0
156
+ - pysocks=1.7.1=py39h06a4308_0
157
+ - python=3.9.16=h7a1cb2a_2
158
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
159
+ - python-devtools=0.11.0=pyhd8ed1ab_0
160
+ - python-graphviz=0.20.1=py39h06a4308_0
161
+ - python-xxhash=2.0.2=py39h5eee18b_1
162
+ - python_abi=3.9=2_cp39
163
+ - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
164
+ - pytorch-cuda=11.7=h778d358_5
165
+ - pytorch-mutex=1.0=cuda
166
+ - pytz=2023.3=pyhd8ed1ab_0
167
+ - pyyaml=6.0=py39hb9d737c_4
168
+ - re2=2022.04.01=h27087fc_0
169
+ - readline=8.2=h5eee18b_0
170
+ - sacremoses=0.0.53=pyhd8ed1ab_0
171
+ - setuptools=66.0.0=py39h06a4308_0
172
+ - six=1.16.0=pyh6c4a22f_0
173
+ - snappy=1.1.9=h295c915_0
174
+ - sqlite=3.41.2=h5eee18b_0
175
+ - sympy=1.11.1=py39h06a4308_0
176
+ - tbb=2021.8.0=hdb19cb5_0
177
+ - tk=8.6.12=h1ccaba5_0
178
+ - tmux=3.2a=h385fc29_0
179
+ - tokenizers=0.13.2=py39he7d60b5_1
180
+ - torchtriton=2.0.0=py39
181
+ - transformers=4.28.1=pyhd8ed1ab_0
182
+ - typing_extensions=4.4.0=py39h06a4308_0
183
+ - utf8proc=2.6.1=h27cfd23_0
184
+ - wheel=0.38.4=py39h06a4308_0
185
+ - xz=5.2.10=h5eee18b_1
186
+ - yaml=0.2.5=h7f98852_2
187
+ - zlib=1.2.13=h5eee18b_0
188
+ - zstd=1.5.5=hc292b87_0
189
+ - pip:
190
+ - absl-py==2.0.0
191
+ - accelerate==0.16.0
192
+ - aiofiles==23.1.0
193
+ - aiohttp==3.8.4
194
+ - albumentations==1.3.1
195
+ - altair==4.2.2
196
+ - antlr4-python3-runtime==4.9.3
197
+ - anyio==3.6.2
198
+ - appdirs==1.4.4
199
+ - apptools==5.2.1
200
+ - argon2-cffi==21.3.0
201
+ - argon2-cffi-bindings==21.2.0
202
+ - argparse==1.4.0
203
+ - arrow==1.2.3
204
+ - attrs==22.2.0
205
+ - backcall==0.2.0
206
+ - batchgenerators==0.25
207
+ - beautifulsoup4==4.12.2
208
+ - bitsandbytes==0.37.0
209
+ - bitsandbytes-cuda117==0.26.0.post2
210
+ - bleach==6.0.0
211
+ - blis==0.7.9
212
+ - braceexpand==0.1.7
213
+ - brotli==1.1.0
214
+ - cachetools==5.3.1
215
+ - catalogue==2.0.8
216
+ - cchardet==2.1.7
217
+ - chardet==5.1.0
218
+ - charset-normalizer==3.1.0
219
+ - click==8.1.3
220
+ - cmake==3.26.3
221
+ - comm==0.1.3
222
+ - commonmark==0.9.1
223
+ - conda-pack==0.6.0
224
+ - confection==0.0.4
225
+ - configobj==5.0.8
226
+ - conllu==4.5.3
227
+ - contourpy==1.0.7
228
+ - cpufeature==0.2.1
229
+ - cycler==0.11.0
230
+ - cymem==2.0.7
231
+ - debugpy==1.6.7
232
+ - decorator==5.1.1
233
+ - decord==0.6.0
234
+ - defusedxml==0.7.1
235
+ - deprecated==1.2.14
236
+ - docker-pycreds==0.4.0
237
+ - efficientnet-pytorch==0.7.1
238
+ - einops==0.6.1
239
+ - einops-exts==0.0.4
240
+ - entrypoints==0.4
241
+ - envisage==7.0.3
242
+ - et-xmlfile==1.1.0
243
+ - exceptiongroup==1.2.0
244
+ - fairscale==0.4.13
245
+ - fastapi==0.95.1
246
+ - fastjsonschema==2.16.3
247
+ - ffmpy==0.3.0
248
+ - fonttools==4.38.0
249
+ - fqdn==1.5.1
250
+ - fschat==0.1.10
251
+ - fsspec==2023.4.0
252
+ - ftfy==6.1.1
253
+ - future==0.18.3
254
+ - gitdb==4.0.10
255
+ - gitpython==3.1.31
256
+ - google-auth==2.23.3
257
+ - google-auth-oauthlib==1.0.0
258
+ - gradio==3.23.0
259
+ - gradio-client==0.0.8
260
+ - grpcio==1.59.0
261
+ - h11==0.14.0
262
+ - h5py==3.9.0
263
+ - hjson==3.1.0
264
+ - httpcore==0.17.0
265
+ - httpx==0.24.0
266
+ - humanize==4.8.0
267
+ - hyperlink==21.0.0
268
+ - imageio==2.33.0
269
+ - importlib-metadata==6.6.0
270
+ - importlib-resources==5.12.0
271
+ - inflate64==1.0.0
272
+ - iniconfig==2.0.0
273
+ - iopath==0.1.10
274
+ - ipdb==0.13.13
275
+ - ipykernel==6.22.0
276
+ - ipython==8.12.0
277
+ - ipython-genutils==0.2.0
278
+ - isoduration==20.11.0
279
+ - jedi==0.18.2
280
+ - joblib==1.2.0
281
+ - jsonpointer==2.3
282
+ - jsonschema==4.17.3
283
+ - jupyter-client==8.2.0
284
+ - jupyter-core==5.3.0
285
+ - jupyter-events==0.6.3
286
+ - jupyter-server==2.5.0
287
+ - jupyter-server-terminals==0.4.4
288
+ - jupyterlab-pygments==0.2.2
289
+ - kiwisolver==1.4.4
290
+ - langcodes==3.3.0
291
+ - lazy-loader==0.3
292
+ - linecache2==1.0.0
293
+ - linkify-it-py==2.0.0
294
+ - lit==16.0.2
295
+ - llvmlite==0.39.1
296
+ - markdown==3.5
297
+ - markdown-it-py==2.2.0
298
+ - markdown2==2.4.8
299
+ - markupsafe==2.1.2
300
+ - matplotlib==3.7.0
301
+ - matplotlib-inline==0.1.6
302
+ - mdit-py-plugins==0.3.3
303
+ - mdurl==0.1.2
304
+ - mistune==2.0.5
305
+ - multidict==6.0.4
306
+ - multiprocess==0.70.15
307
+ - multivolumefile==0.2.3
308
+ - munch==4.0.0
309
+ - murmurhash==1.0.9
310
+ - mypy-extensions==1.0.0
311
+ - nbclassic==0.5.6
312
+ - nbclient==0.7.4
313
+ - nbconvert==7.3.1
314
+ - nbformat==5.8.0
315
+ - nest-asyncio==1.5.6
316
+ - nh3==0.2.11
317
+ - nibabel==5.1.0
318
+ - ninja==1.11.1
319
+ - nltk==3.8.1
320
+ - nmslib==2.1.1
321
+ - notebook==6.5.4
322
+ - notebook-shim==0.2.3
323
+ - numba==0.56.4
324
+ - numpy==1.23.5
325
+ - nvidia-cublas-cu11==11.10.3.66
326
+ - nvidia-cuda-cupti-cu11==11.7.101
327
+ - nvidia-cuda-nvrtc-cu11==11.7.99
328
+ - nvidia-cuda-runtime-cu11==11.7.99
329
+ - nvidia-cudnn-cu11==8.5.0.96
330
+ - nvidia-cufft-cu11==10.9.0.58
331
+ - nvidia-curand-cu11==10.2.10.91
332
+ - nvidia-cusolver-cu11==11.4.0.1
333
+ - nvidia-cusparse-cu11==11.7.4.91
334
+ - nvidia-nccl-cu11==2.14.3
335
+ - nvidia-nvtx-cu11==11.7.91
336
+ - oauthlib==3.2.2
337
+ - omegaconf==2.3.0
338
+ - open-clip-torch==2.20.0
339
+ - openai==0.27.0
340
+ - opencv-python==4.7.0.72
341
+ - opencv-python-headless==4.8.0.74
342
+ - openpyxl==3.1.2
343
+ - orjson==3.8.11
344
+ - packaging==23.0
345
+ - pandas==2.0.1
346
+ - pandocfilters==1.5.0
347
+ - parso==0.8.3
348
+ - pathtools==0.1.2
349
+ - pathy==0.10.1
350
+ - peft==0.2.0
351
+ - pexpect==4.8.0
352
+ - pickleshare==0.7.5
353
+ - pillow==9.5.0
354
+ - platformdirs==3.5.0
355
+ - pluggy==1.3.0
356
+ - portalocker==2.7.0
357
+ - preshed==3.0.8
358
+ - pretrainedmodels==0.7.4
359
+ - prometheus-client==0.16.0
360
+ - prompt-toolkit==3.0.38
361
+ - protobuf==3.20.3
362
+ - psutil==5.9.4
363
+ - ptyprocess==0.7.0
364
+ - pure-eval==0.2.2
365
+ - py-cpuinfo==9.0.0
366
+ - py-rsync==0.0.1a0.dev0
367
+ - py7zr==0.20.8
368
+ - pyarrow==11.0.0
369
+ - pyasn1==0.5.0
370
+ - pyasn1-modules==0.3.0
371
+ - pybcj==1.0.2
372
+ - pybind11==2.6.1
373
+ - pycocoevalcap==1.2
374
+ - pycocotools==2.0.6
375
+ - pycryptodomex==3.19.1
376
+ - pydantic==1.10.7
377
+ - pydub==0.25.1
378
+ - pyface==8.0.0
379
+ - pygments==2.15.1
380
+ - pynndescent==0.5.10
381
+ - pyparsing==3.0.9
382
+ - pyppmd==1.1.0
383
+ - pyqt5==5.15.10
384
+ - pyqt5-qt5==5.15.2
385
+ - pyqt5-sip==12.13.0
386
+ - pyrsistent==0.19.3
387
+ - pysbd==0.3.4
388
+ - pytest==7.4.3
389
+ - python-json-logger==2.0.7
390
+ - python-multipart==0.0.6
391
+ - python-polylabel==0.6
392
+ - python-rsync==0.1.0
393
+ - pyzmq==25.0.2
394
+ - pyzstd==0.15.9
395
+ - qudida==0.0.4
396
+ - regex==2022.10.31
397
+ - requests==2.29.0
398
+ - requests-oauthlib==1.3.1
399
+ - rfc3339-validator==0.1.4
400
+ - rfc3986-validator==0.1.1
401
+ - rich==12.6.0
402
+ - rsa==4.9
403
+ - safetensors==0.3.1
404
+ - scikit-image==0.22.0
405
+ - scipy==1.10.1
406
+ - scispacy==0.5.2
407
+ - segmentation-models-pytorch==0.3.3
408
+ - semantic-version==2.10.0
409
+ - send2trash==1.8.2
410
+ - sentence-transformers==2.2.2
411
+ - sentencepiece==0.1.98
412
+ - sentry-sdk==1.21.0
413
+ - setproctitle==1.3.2
414
+ - shapely==2.0.2
415
+ - shellingham==1.5.4
416
+ - shortuuid==1.0.11
417
+ - simpleitk==2.2.1
418
+ - smart-open==6.3.0
419
+ - smmap==5.0.0
420
+ - sniffio==1.3.0
421
+ - soupsieve==2.4.1
422
+ - spacy==3.4.4
423
+ - spacy-legacy==3.0.12
424
+ - spacy-loggers==1.0.4
425
+ - srsly==2.4.6
426
+ - stack-data==0.6.2
427
+ - starlette==0.26.1
428
+ - surface-distance-based-measures==0.1
429
+ - svgwrite==1.4.3
430
+ - swig==4.1.1
431
+ - tenacity==8.2.2
432
+ - tensorboard==2.14.1
433
+ - tensorboard-data-server==0.7.1
434
+ - terminado==0.17.1
435
+ - texttable==1.7.0
436
+ - thinc==8.1.9
437
+ - threadpoolctl==3.1.0
438
+ - tifffile==2023.9.26
439
+ - timm==0.9.2
440
+ - tinycss2==1.2.1
441
+ - tomli==2.0.1
442
+ - toolz==0.12.0
443
+ - torchio==0.19.2
444
+ - torchvision==0.15.2
445
+ - tornado==6.3.1
446
+ - tqdm==4.64.1
447
+ - traceback2==1.4.0
448
+ - traitlets==5.9.0
449
+ - traits==6.4.3
450
+ - traitsui==8.0.0
451
+ - triton==2.0.0
452
+ - typer==0.7.0
453
+ - typing-extensions==4.5.0
454
+ - typing-inspect==0.8.0
455
+ - tzdata==2023.3
456
+ - uc-micro-py==1.0.1
457
+ - umap-learn==0.5.3
458
+ - unittest2==1.1.0
459
+ - unzip==1.0.0
460
+ - uri-template==1.2.0
461
+ - urllib3==1.26.15
462
+ - uvicorn==0.22.0
463
+ - vtk==9.3.0
464
+ - wandb==0.15.0
465
+ - wasabi==0.10.1
466
+ - wavedrom==2.0.3.post3
467
+ - wcwidth==0.2.6
468
+ - webcolors==1.13
469
+ - webdataset==0.2.48
470
+ - webencodings==0.5.1
471
+ - websocket-client==1.5.1
472
+ - websockets==11.0.2
473
+ - werkzeug==3.0.0
474
+ - wrapt==1.16.0
475
+ - xxhash==3.3.0
476
+ - yarl==1.8.2
477
+ - zipp==3.14.0
478
+ prefix: /home/zhouhy/anaconda3/envs/medversa
479
+
inference.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+
3
+ # --- Launch Model ---
4
+ device = 'cuda:0'
5
+ model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
6
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
7
+ model.eval()
8
+
9
+ # --- Define examples ---
10
+ examples = [
11
+ [
12
+ ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
13
+ "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
14
+ "How would you characterize the findings from <img0>?",
15
+ "cxr",
16
+ "report generation",
17
+ ],
18
+ [
19
+ ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
20
+ "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
21
+ "How would you characterize the findings from <img0>?",
22
+ "cxr",
23
+ "report generation",
24
+ ],
25
+ [
26
+ ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
27
+ "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
28
+ "How would you characterize the findings from <img0><img1>?",
29
+ "cxr",
30
+ "report generation",
31
+ ],
32
+ [
33
+ ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
34
+ "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
35
+ "How would you characterize the findings from <img0>?",
36
+ "cxr",
37
+ "report generation",
38
+ ],
39
+ [
40
+ ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
41
+ "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
42
+ "How would you characterize the findings from <img0>?",
43
+ "cxr",
44
+ "report generation",
45
+ ],
46
+ [
47
+ ["./demo_ex/ISIC_0032258.jpg"],
48
+ "Age:70.\nGender:female.\nLocation:back.",
49
+ "What is primary diagnosis?",
50
+ "derm",
51
+ "classification",
52
+ ],
53
+ [
54
+ ["./demo_ex/ISIC_0032258.jpg"],
55
+ "Age:70.\nGender:female.\nLocation:back.",
56
+ "Segment the lesion.",
57
+ "derm",
58
+ "segmentation",
59
+ ],
60
+ [
61
+ ["./demo_ex/Case_01013_0000.nii.gz"],
62
+ "",
63
+ "Segment the liver.",
64
+ "ct",
65
+ "segmentation",
66
+ ],
67
+ [
68
+ ["./demo_ex/Case_00840_0000.nii.gz"],
69
+ "",
70
+ "Segment the liver.",
71
+ "ct",
72
+ "segmentation",
73
+ ],
74
+ ]
75
+ # --- Define hyperparams ---
76
+ num_beams = 1
77
+ do_sample = True
78
+ min_length = 1
79
+ top_p = 0.9
80
+ repetition_penalty = 1
81
+ length_penalty = 1
82
+ temperature = 0.1
83
+
84
+ # --- Generate a report for an chest X-ray image ---
85
+ index = 0
86
+ demo_ex = examples[index]
87
+ images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
88
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
89
+ print(output_text)
90
+
91
+ # --- Segment the lesion in the dermatology image ---
92
+ index = 6
93
+ demo_ex = examples[index]
94
+ images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
95
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
96
+ print(output_text)
97
+ print(seg_mask_2d[0].shape) # H, W
98
+
99
+ # --- Segment the liver in the abdomen CT scan ---
100
+ index = -2
101
+ demo_ex = examples[index]
102
+ images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
103
+ seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
104
+ print(output_text)
105
+ print(len(seg_mask_3d)) # Number of slices
106
+ print(seg_mask_3d[0].shape) # H, W
107
+
medomni/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from medomni.common.registry import registry
14
+
15
+ from medomni.datasets.builders import *
16
+ from medomni.models import *
17
+ from medomni.processors import *
18
+ from medomni.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
medomni/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.69 kB). View file
 
medomni/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.01 kB). View file
 
medomni/common/__init__.py ADDED
File without changes
medomni/common/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (147 Bytes). View file
 
medomni/common/__pycache__/config.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
medomni/common/__pycache__/dist_utils.cpython-39.pyc ADDED
Binary file (3.77 kB). View file
 
medomni/common/__pycache__/logger.cpython-39.pyc ADDED
Binary file (6.46 kB). View file
 
medomni/common/__pycache__/optims.cpython-39.pyc ADDED
Binary file (2.99 kB). View file
 
medomni/common/__pycache__/registry.cpython-39.pyc ADDED
Binary file (8.99 kB). View file
 
medomni/common/__pycache__/utils.cpython-39.pyc ADDED
Binary file (12.6 kB). View file
 
medomni/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from medomni.common.registry import registry
14
+ import ipdb
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hierarchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hierarchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
medomni/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
medomni/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
medomni/common/logger.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from medomni.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avgdata_time:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ # assert isinstance(v, (float, int))
92
+ if isinstance(v, (float, int)):
93
+ self.meters[k].update(v)
94
+ else:
95
+ self.meters[k] = v
96
+
97
+ def __getattr__(self, attr):
98
+ if attr in self.meters:
99
+ return self.meters[attr]
100
+ if attr in self.__dict__:
101
+ return self.__dict__[attr]
102
+ raise AttributeError(
103
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
104
+ )
105
+
106
+ def __str__(self):
107
+ loss_str = []
108
+ for name, meter in self.meters.items():
109
+ loss_str.append("{}: {}".format(name, str(meter)))
110
+ return self.delimiter.join(loss_str)
111
+
112
+ def global_avg(self):
113
+ loss_str = []
114
+ for name, meter in self.meters.items():
115
+ if not isinstance(meter, str):
116
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
117
+ return self.delimiter.join(loss_str)
118
+
119
+ def synchronize_between_processes(self):
120
+ for meter in self.meters.values():
121
+ if not isinstance(meter, str):
122
+ meter.synchronize_between_processes()
123
+
124
+ def add_meter(self, name, meter):
125
+ self.meters[name] = meter
126
+
127
+ def log_every(self, iterable, print_freq, header=None):
128
+ i = 0
129
+ if not header:
130
+ header = ""
131
+ start_time = time.time()
132
+ end = time.time()
133
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
134
+ data_time = SmoothedValue(fmt="{avg:.4f}")
135
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
136
+ log_msg = [
137
+ header,
138
+ "[{0" + space_fmt + "}/{1}]",
139
+ "eta: {eta}",
140
+ "{meters}",
141
+ "time: {time}",
142
+ "data: {data}",
143
+ ]
144
+ if torch.cuda.is_available():
145
+ log_msg.append("max mem: {memory:.0f}")
146
+ log_msg = self.delimiter.join(log_msg)
147
+ MB = 1024.0 * 1024.0
148
+ for obj in iterable:
149
+ data_time.update(time.time() - end)
150
+ yield obj
151
+ iter_time.update(time.time() - end)
152
+ if i % print_freq == 0 or i == len(iterable) - 1:
153
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
154
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
155
+ if torch.cuda.is_available():
156
+ print(
157
+ log_msg.format(
158
+ i,
159
+ len(iterable),
160
+ eta=eta_string,
161
+ meters=str(self),
162
+ time=str(iter_time),
163
+ data=str(data_time),
164
+ memory=torch.cuda.max_memory_allocated() / MB,
165
+ )
166
+ )
167
+ else:
168
+ print(
169
+ log_msg.format(
170
+ i,
171
+ len(iterable),
172
+ eta=eta_string,
173
+ meters=str(self),
174
+ time=str(iter_time),
175
+ data=str(data_time),
176
+ )
177
+ )
178
+ i += 1
179
+ end = time.time()
180
+ total_time = time.time() - start_time
181
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
182
+ print(
183
+ "{} Total time: {} ({:.4f} s / it)".format(
184
+ header, total_time_str, total_time / len(iterable)
185
+ )
186
+ )
187
+
188
+
189
+ class AttrDict(dict):
190
+ def __init__(self, *args, **kwargs):
191
+ super(AttrDict, self).__init__(*args, **kwargs)
192
+ self.__dict__ = self
193
+
194
+
195
+ def setup_logger():
196
+ logging.basicConfig(
197
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
198
+ format="%(asctime)s [%(levelname)s] %(message)s",
199
+ handlers=[logging.StreamHandler()],
200
+ )
medomni/common/optims.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from medomni.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
medomni/common/registry.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ class Registry:
9
+ mapping = {
10
+ "builder_name_mapping": {},
11
+ "task_name_mapping": {},
12
+ "processor_name_mapping": {},
13
+ "model_name_mapping": {},
14
+ "lr_scheduler_name_mapping": {},
15
+ "runner_name_mapping": {},
16
+ "state": {},
17
+ "paths": {},
18
+ }
19
+
20
+ @classmethod
21
+ def register_builder(cls, name):
22
+ r"""Register a dataset builder to registry with key 'name'
23
+
24
+ Args:
25
+ name: Key with which the builder will be registered.
26
+
27
+ Usage:
28
+
29
+ from medomni.common.registry import registry
30
+ from medomni.datasets.base_dataset_builder import BaseDatasetBuilder
31
+ """
32
+
33
+ def wrap(builder_cls):
34
+ from medomni.datasets.builders.base_dataset_builder import BaseDatasetBuilder
35
+
36
+ assert issubclass(
37
+ builder_cls, BaseDatasetBuilder
38
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
39
+ builder_cls
40
+ )
41
+ if name in cls.mapping["builder_name_mapping"]:
42
+ raise KeyError(
43
+ "Name '{}' already registered for {}.".format(
44
+ name, cls.mapping["builder_name_mapping"][name]
45
+ )
46
+ )
47
+ cls.mapping["builder_name_mapping"][name] = builder_cls
48
+ return builder_cls
49
+
50
+ return wrap
51
+
52
+ @classmethod
53
+ def register_task(cls, name):
54
+ r"""Register a task to registry with key 'name'
55
+
56
+ Args:
57
+ name: Key with which the task will be registered.
58
+
59
+ Usage:
60
+
61
+ from medomni.common.registry import registry
62
+ """
63
+
64
+ def wrap(task_cls):
65
+ from medomni.tasks.base_task import BaseTask
66
+
67
+ assert issubclass(
68
+ task_cls, BaseTask
69
+ ), "All tasks must inherit BaseTask class"
70
+ if name in cls.mapping["task_name_mapping"]:
71
+ raise KeyError(
72
+ "Name '{}' already registered for {}.".format(
73
+ name, cls.mapping["task_name_mapping"][name]
74
+ )
75
+ )
76
+ cls.mapping["task_name_mapping"][name] = task_cls
77
+ return task_cls
78
+
79
+ return wrap
80
+
81
+ @classmethod
82
+ def register_model(cls, name):
83
+ r"""Register a task to registry with key 'name'
84
+
85
+ Args:
86
+ name: Key with which the task will be registered.
87
+
88
+ Usage:
89
+
90
+ from medomni.common.registry import registry
91
+ """
92
+
93
+ def wrap(model_cls):
94
+ from medomni.models import BaseModel
95
+
96
+ assert issubclass(
97
+ model_cls, BaseModel
98
+ ), "All models must inherit BaseModel class"
99
+ if name in cls.mapping["model_name_mapping"]:
100
+ raise KeyError(
101
+ "Name '{}' already registered for {}.".format(
102
+ name, cls.mapping["model_name_mapping"][name]
103
+ )
104
+ )
105
+ cls.mapping["model_name_mapping"][name] = model_cls
106
+ return model_cls
107
+
108
+ return wrap
109
+
110
+ @classmethod
111
+ def register_processor(cls, name):
112
+ r"""Register a processor to registry with key 'name'
113
+
114
+ Args:
115
+ name: Key with which the task will be registered.
116
+
117
+ Usage:
118
+
119
+ from medomni.common.registry import registry
120
+ """
121
+
122
+ def wrap(processor_cls):
123
+ from medomni.processors import BaseProcessor
124
+
125
+ assert issubclass(
126
+ processor_cls, BaseProcessor
127
+ ), "All processors must inherit BaseProcessor class"
128
+ if name in cls.mapping["processor_name_mapping"]:
129
+ raise KeyError(
130
+ "Name '{}' already registered for {}.".format(
131
+ name, cls.mapping["processor_name_mapping"][name]
132
+ )
133
+ )
134
+ cls.mapping["processor_name_mapping"][name] = processor_cls
135
+ return processor_cls
136
+
137
+ return wrap
138
+
139
+ @classmethod
140
+ def register_lr_scheduler(cls, name):
141
+ r"""Register a model to registry with key 'name'
142
+
143
+ Args:
144
+ name: Key with which the task will be registered.
145
+
146
+ Usage:
147
+
148
+ from medomni.common.registry import registry
149
+ """
150
+
151
+ def wrap(lr_sched_cls):
152
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
153
+ raise KeyError(
154
+ "Name '{}' already registered for {}.".format(
155
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
156
+ )
157
+ )
158
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
159
+ return lr_sched_cls
160
+
161
+ return wrap
162
+
163
+ @classmethod
164
+ def register_runner(cls, name):
165
+ r"""Register a model to registry with key 'name'
166
+
167
+ Args:
168
+ name: Key with which the task will be registered.
169
+
170
+ Usage:
171
+
172
+ from medomni.common.registry import registry
173
+ """
174
+
175
+ def wrap(runner_cls):
176
+ if name in cls.mapping["runner_name_mapping"]:
177
+ raise KeyError(
178
+ "Name '{}' already registered for {}.".format(
179
+ name, cls.mapping["runner_name_mapping"][name]
180
+ )
181
+ )
182
+ cls.mapping["runner_name_mapping"][name] = runner_cls
183
+ return runner_cls
184
+
185
+ return wrap
186
+
187
+ @classmethod
188
+ def register_path(cls, name, path):
189
+ r"""Register a path to registry with key 'name'
190
+
191
+ Args:
192
+ name: Key with which the path will be registered.
193
+
194
+ Usage:
195
+
196
+ from medomni.common.registry import registry
197
+ """
198
+ assert isinstance(path, str), "All path must be str."
199
+ if name in cls.mapping["paths"]:
200
+ raise KeyError("Name '{}' already registered.".format(name))
201
+ cls.mapping["paths"][name] = path
202
+
203
+ @classmethod
204
+ def register(cls, name, obj):
205
+ r"""Register an item to registry with key 'name'
206
+
207
+ Args:
208
+ name: Key with which the item will be registered.
209
+
210
+ Usage::
211
+
212
+ from medomni.common.registry import registry
213
+
214
+ registry.register("config", {})
215
+ """
216
+ path = name.split(".")
217
+ current = cls.mapping["state"]
218
+
219
+ for part in path[:-1]:
220
+ if part not in current:
221
+ current[part] = {}
222
+ current = current[part]
223
+
224
+ current[path[-1]] = obj
225
+
226
+ # @classmethod
227
+ # def get_trainer_class(cls, name):
228
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
229
+
230
+ @classmethod
231
+ def get_builder_class(cls, name):
232
+ return cls.mapping["builder_name_mapping"].get(name, None)
233
+
234
+ @classmethod
235
+ def get_model_class(cls, name):
236
+ return cls.mapping["model_name_mapping"].get(name, None)
237
+
238
+ @classmethod
239
+ def get_task_class(cls, name):
240
+ return cls.mapping["task_name_mapping"].get(name, None)
241
+
242
+ @classmethod
243
+ def get_processor_class(cls, name):
244
+ return cls.mapping["processor_name_mapping"].get(name, None)
245
+
246
+ @classmethod
247
+ def get_lr_scheduler_class(cls, name):
248
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
249
+
250
+ @classmethod
251
+ def get_runner_class(cls, name):
252
+ return cls.mapping["runner_name_mapping"].get(name, None)
253
+
254
+ @classmethod
255
+ def list_runners(cls):
256
+ return sorted(cls.mapping["runner_name_mapping"].keys())
257
+
258
+ @classmethod
259
+ def list_models(cls):
260
+ return sorted(cls.mapping["model_name_mapping"].keys())
261
+
262
+ @classmethod
263
+ def list_tasks(cls):
264
+ return sorted(cls.mapping["task_name_mapping"].keys())
265
+
266
+ @classmethod
267
+ def list_processors(cls):
268
+ return sorted(cls.mapping["processor_name_mapping"].keys())
269
+
270
+ @classmethod
271
+ def list_lr_schedulers(cls):
272
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
273
+
274
+ @classmethod
275
+ def list_datasets(cls):
276
+ return sorted(cls.mapping["builder_name_mapping"].keys())
277
+
278
+ @classmethod
279
+ def get_path(cls, name):
280
+ return cls.mapping["paths"].get(name, None)
281
+
282
+ @classmethod
283
+ def get(cls, name, default=None, no_warning=False):
284
+ r"""Get an item from registry with key 'name'
285
+
286
+ Args:
287
+ name (string): Key whose value needs to be retrieved.
288
+ default: If passed and key is not in registry, default value will
289
+ be returned with a warning. Default: None
290
+ no_warning (bool): If passed as True, warning when key doesn't exist
291
+ will not be generated. Useful for MMF's
292
+ internal operations. Default: False
293
+ """
294
+ original_name = name
295
+ name = name.split(".")
296
+ value = cls.mapping["state"]
297
+ for subname in name:
298
+ value = value.get(subname, default)
299
+ if value is default:
300
+ break
301
+
302
+ if (
303
+ "writer" in cls.mapping["state"]
304
+ and value == default
305
+ and no_warning is False
306
+ ):
307
+ cls.mapping["state"]["writer"].warning(
308
+ "Key {} is not present in registry, returning default value "
309
+ "of {}".format(original_name, default)
310
+ )
311
+ return value
312
+
313
+ @classmethod
314
+ def unregister(cls, name):
315
+ r"""Remove an item from registry with key 'name'
316
+
317
+ Args:
318
+ name: Key which needs to be removed.
319
+ Usage::
320
+
321
+ from mmf.common.registry import registry
322
+
323
+ config = registry.unregister("config")
324
+ """
325
+ return cls.mapping["state"].pop(name, None)
326
+
327
+ registry = Registry()
medomni/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from medomni.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
medomni/configs/datasets/medinterp/align.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ med:
3
+ data_type: images
4
+ build_info:
5
+ storage: json_files/medinterp
medomni/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/export/home/.cache/medomni"
medomni/configs/models/medomni.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: medomni
3
+
4
+ # vision encoder
5
+ precision: "fp16"
6
+ freeze_vit: True
7
+
8
+ # Llama
9
+ llama_model: "meta-llama/Llama-2-7b-chat-hf"
10
+
11
+ # generation configs
12
+ prompt: ""
medomni/conversation/__init__.py ADDED
File without changes
medomni/conversation/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes). View file
 
medomni/conversation/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (7.3 kB). View file
 
medomni/conversation/conversation.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from medomni.common.registry import registry
14
+ import ipdb
15
+
16
+
17
+ class SeparatorStyle(Enum):
18
+ """Different separator style."""
19
+ SINGLE = auto()
20
+ TWO = auto()
21
+
22
+
23
+ @dataclasses.dataclass
24
+ class Conversation:
25
+ """A class that keeps all conversation history."""
26
+ system: str
27
+ roles: List[str]
28
+ messages: List[List[str]]
29
+ offset: int
30
+ # system_img: List[Image.Image] = []
31
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32
+ sep: str = "###"
33
+ sep2: str = None
34
+
35
+ skip_next: bool = False
36
+ conv_id: Any = None
37
+
38
+ def get_prompt(self):
39
+ if self.sep_style == SeparatorStyle.SINGLE:
40
+ ret = self.system + self.sep
41
+ for role, message in self.messages:
42
+ if message:
43
+ ret += role + ": " + message + self.sep
44
+ else:
45
+ ret += role + ":"
46
+ return ret
47
+ elif self.sep_style == SeparatorStyle.TWO:
48
+ seps = [self.sep, self.sep2]
49
+ ret = self.system + seps[0]
50
+ for i, (role, message) in enumerate(self.messages):
51
+ if message:
52
+ ret += role + ": " + message + seps[i % 2]
53
+ else:
54
+ ret += role + ":"
55
+ return ret
56
+ else:
57
+ raise ValueError(f"Invalid style: {self.sep_style}")
58
+
59
+ def append_message(self, role, message):
60
+ self.messages.append([role, message])
61
+
62
+ def to_gradio_chatbot(self):
63
+ ret = []
64
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
65
+ if i % 2 == 0:
66
+ ret.append([msg, None])
67
+ else:
68
+ ret[-1][-1] = msg
69
+ return ret
70
+
71
+ def copy(self):
72
+ return Conversation(
73
+ system=self.system,
74
+ roles=self.roles,
75
+ messages=[[x, y] for x, y in self.messages],
76
+ offset=self.offset,
77
+ sep_style=self.sep_style,
78
+ sep=self.sep,
79
+ sep2=self.sep2,
80
+ conv_id=self.conv_id)
81
+
82
+ def dict(self):
83
+ return {
84
+ "system": self.system,
85
+ # "system_img": self.system_img,
86
+ "roles": self.roles,
87
+ "messages": self.messages,
88
+ "offset": self.offset,
89
+ "sep": self.sep,
90
+ "sep2": self.sep2,
91
+ "conv_id": self.conv_id,
92
+ }
93
+
94
+
95
+ class StoppingCriteriaSub(StoppingCriteria):
96
+
97
+ def __init__(self, stops=[], encounters=1):
98
+ super().__init__()
99
+ self.stops = stops
100
+
101
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102
+ for stop in self.stops:
103
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
104
+ return True
105
+
106
+ return False
107
+
108
+
109
+ CONV_VISION = Conversation(
110
+ system="Give the following image: <Img>ImageContent</Img>. "
111
+ "You will be able to see the image once I provide it to you. Act as a clinician and answer my questions.",
112
+ # "You will be able to see the image once I provide it to you. Please answer my questions.",
113
+ #system="",
114
+ roles=("Human", "Assistant"),
115
+ messages=[],
116
+ offset=2,
117
+ sep_style=SeparatorStyle.SINGLE,
118
+ sep="###",
119
+ )
120
+
121
+ class Chat:
122
+ def __init__(self, model, vis_processor, device='cuda:0'):
123
+ self.device = device
124
+ self.model = model
125
+ self.vis_processor = vis_processor
126
+ stop_words_ids = [torch.tensor([835]).to(self.device),
127
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
128
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129
+
130
+ def ask(self, text, conv):
131
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
132
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
133
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
134
+ else:
135
+ conv.append_message(conv.roles[0], text) # commented by hy on 5.9
136
+
137
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
138
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
139
+ conv.append_message(conv.roles[1], None)
140
+ embs = self.get_context_emb(conv, img_list)
141
+
142
+ current_max_len = embs.shape[1] + max_new_tokens
143
+ if current_max_len - max_length > 0:
144
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
145
+ 'The model will not see the contexts outside the range.')
146
+ begin_idx = max(0, current_max_len - max_length)
147
+
148
+ embs = embs[:, begin_idx:]
149
+
150
+ with torch.autocast("cuda"):
151
+ outputs = self.model.llama_model.generate(
152
+ inputs_embeds=embs,
153
+ max_new_tokens=max_new_tokens,
154
+ stopping_criteria=self.stopping_criteria,
155
+ num_beams=num_beams,
156
+ do_sample=True,
157
+ min_length=min_length,
158
+ top_p=top_p,
159
+ repetition_penalty=repetition_penalty,
160
+ length_penalty=length_penalty,
161
+ temperature=temperature,
162
+ )
163
+ output_token = outputs[0]
164
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
165
+ output_token = output_token[1:]
166
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
167
+ output_token = output_token[1:]
168
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
169
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
170
+ output_text = output_text.split('Assistant:')[-1].strip()
171
+ conv.messages[-1][1] = output_text # commented by hy on 5.9
172
+ #---5.9.2023---
173
+ conv.messages = []
174
+ conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
175
+ return output_text, output_token.cpu().numpy()
176
+
177
+ def upload_img(self, image, conv, img_list):
178
+ if isinstance(image, str): # is a image path
179
+ raw_image = Image.open(image).convert('RGB')
180
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
181
+ elif isinstance(image, Image.Image):
182
+ raw_image = image
183
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
184
+ elif isinstance(image, torch.Tensor):
185
+ if len(image.shape) == 3:
186
+ image = image.unsqueeze(0)
187
+ image = image.to(self.device)
188
+
189
+ image_emb, _ = self.model.encode_img(image)
190
+ img_list.append(image_emb)
191
+ conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
192
+ msg = "Received."
193
+ return msg
194
+
195
+ def get_context_emb(self, conv, img_list):
196
+ prompt = conv.get_prompt()
197
+ prompt_segs = prompt.split('<ImageHere>')
198
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
199
+ #seg_tokens = []
200
+ #for i, seg in enumerate(prompt_segs):
201
+ # if i == 1:
202
+ # prompt_ids = self.model.llama_tokenizer(
203
+ # seg,
204
+ # return_tensors="pt",
205
+ # add_special_tokens=i == 0
206
+ # ).to(self.device).input_ids
207
+ # seg_tokens.append(prompt_ids)
208
+ # else:
209
+ # prompt_ids = self.model.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
210
+ # seg_tokens.append(prompt_ids)
211
+ seg_tokens = [
212
+ self.model.llama_tokenizer(
213
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
214
+ # only add bos to the first seg
215
+ for i, seg in enumerate(prompt_segs)
216
+ ]
217
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
218
+ # seg_embs = [self.model.llama_model.model.base_model.embed_tokens(seg_t) for seg_t in seg_tokens] # LoRA
219
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
220
+ mixed_embs = torch.cat(mixed_embs, dim=1)
221
+ return mixed_embs
222
+
medomni/datasets/__init__.py ADDED
File without changes
medomni/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (149 Bytes). View file
 
medomni/datasets/__pycache__/data_utils.cpython-39.pyc ADDED
Binary file (5.95 kB). View file
 
medomni/datasets/builders/__init__.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from medomni.datasets.builders.base_dataset_builder import load_dataset_config
9
+ from medomni.datasets.builders.image_text_pair_builder import (
10
+ CCSBUBuilder,
11
+ LaionBuilder,
12
+ CCSBUAlignBuilder
13
+ )
14
+ from medomni.common.registry import registry
15
+
16
+ __all__ = [
17
+ "CCSBUBuilder",
18
+ "LaionBuilder",
19
+ "CCSBUAlignBuilder"
20
+ ]
21
+
22
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
23
+ """
24
+ Example
25
+
26
+ >>> dataset = load_dataset("coco_caption", cfg=None)
27
+ >>> splits = dataset.keys()
28
+ >>> print([len(dataset[split]) for split in splits])
29
+
30
+ """
31
+ if cfg_path is None:
32
+ cfg = None
33
+ else:
34
+ cfg = load_dataset_config(cfg_path)
35
+
36
+ try:
37
+ builder = registry.get_builder_class(name)(cfg)
38
+ except TypeError:
39
+ print(
40
+ f"Dataset {name} not found. Available datasets:\n"
41
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
42
+ )
43
+ exit(1)
44
+
45
+ if vis_path is not None:
46
+ if data_type is None:
47
+ # use default data type in the config
48
+ data_type = builder.config.data_type
49
+
50
+ assert (
51
+ data_type in builder.config.build_info
52
+ ), f"Invalid data_type {data_type} for {name}."
53
+
54
+ builder.config.build_info.get(data_type).storage = vis_path
55
+
56
+ dataset = builder.build_datasets()
57
+ return dataset
58
+
59
+
60
+ class DatasetZoo:
61
+ def __init__(self) -> None:
62
+ self.dataset_zoo = {
63
+ k: list(v.DATASET_CONFIG_DICT.keys())
64
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
65
+ }
66
+
67
+ def get_names(self):
68
+ return list(self.dataset_zoo.keys())
69
+
70
+
71
+ dataset_zoo = DatasetZoo()
medomni/datasets/builders/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.35 kB). View file
 
medomni/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc ADDED
Binary file (6.06 kB). View file
 
medomni/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc ADDED
Binary file (3.82 kB). View file
 
medomni/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is from
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import warnings
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.utils import download_url
17
+
18
+ import medomni.common.utils as utils
19
+ from medomni.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
+ from medomni.common.registry import registry
21
+ from medomni.processors.base_processor import BaseProcessor
22
+
23
+ class BaseDatasetBuilder:
24
+ train_dataset_cls, eval_dataset_cls = None, None
25
+
26
+ def __init__(self, cfg=None):
27
+ super().__init__()
28
+
29
+ if cfg is None:
30
+ # help to create datasets from default config.
31
+ self.config = load_dataset_config(self.default_config_path())
32
+ elif isinstance(cfg, str):
33
+ self.config = load_dataset_config(cfg)
34
+ else:
35
+ # when called from task.build_dataset()
36
+ self.config = cfg
37
+
38
+ self.data_type = self.config.data_type
39
+
40
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
41
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
42
+
43
+ def build_datasets(self):
44
+ # download, split, etc...
45
+ # only called on 1 GPU/TPU in distributed
46
+
47
+ if is_main_process():
48
+ self._download_data()
49
+
50
+ if is_dist_avail_and_initialized():
51
+ dist.barrier()
52
+
53
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
54
+ logging.info("Building datasets...")
55
+ datasets = self.build() # dataset['train'/'val'/'test']
56
+
57
+ return datasets
58
+
59
+ def build_processors(self):
60
+ vis_proc_cfg = self.config.get("vis_processor")
61
+ txt_proc_cfg = self.config.get("text_processor")
62
+
63
+ if vis_proc_cfg is not None:
64
+ vis_train_cfg = vis_proc_cfg.get("train")
65
+ vis_eval_cfg = vis_proc_cfg.get("eval")
66
+
67
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
68
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
69
+
70
+ if txt_proc_cfg is not None:
71
+ txt_train_cfg = txt_proc_cfg.get("train")
72
+ txt_eval_cfg = txt_proc_cfg.get("eval")
73
+
74
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
75
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
76
+
77
+ @staticmethod
78
+ def _build_proc_from_cfg(cfg):
79
+ return (
80
+ registry.get_processor_class(cfg.name).from_config(cfg)
81
+ if cfg is not None
82
+ else None
83
+ )
84
+
85
+ @classmethod
86
+ def default_config_path(cls, type="default"):
87
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
88
+
89
+ def _download_data(self):
90
+ self._download_ann()
91
+ self._download_vis()
92
+
93
+ def _download_ann(self):
94
+ """
95
+ Download annotation files if necessary.
96
+ All the vision-language datasets should have annotations of unified format.
97
+
98
+ storage_path can be:
99
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
100
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
101
+
102
+ Local annotation paths should be relative.
103
+ """
104
+ anns = self.config.build_info.annotations
105
+
106
+ splits = anns.keys()
107
+
108
+ cache_root = registry.get_path("cache_root")
109
+
110
+ for split in splits:
111
+ info = anns[split]
112
+
113
+ urls, storage_paths = info.get("url", None), info.storage
114
+
115
+ if isinstance(urls, str):
116
+ urls = [urls]
117
+ if isinstance(storage_paths, str):
118
+ storage_paths = [storage_paths]
119
+
120
+ assert len(urls) == len(storage_paths)
121
+
122
+ for url_or_filename, storage_path in zip(urls, storage_paths):
123
+ # if storage_path is relative, make it full by prefixing with cache_root.
124
+ if not os.path.isabs(storage_path):
125
+ storage_path = os.path.join(cache_root, storage_path)
126
+
127
+ dirname = os.path.dirname(storage_path)
128
+ if not os.path.exists(dirname):
129
+ os.makedirs(dirname)
130
+
131
+ if os.path.isfile(url_or_filename):
132
+ src, dst = url_or_filename, storage_path
133
+ if not os.path.exists(dst):
134
+ shutil.copyfile(src=src, dst=dst)
135
+ else:
136
+ logging.info("Using existing file {}.".format(dst))
137
+ else:
138
+ if os.path.isdir(storage_path):
139
+ # if only dirname is provided, suffix with basename of URL.
140
+ raise ValueError(
141
+ "Expecting storage_path to be a file path, got directory {}".format(
142
+ storage_path
143
+ )
144
+ )
145
+ else:
146
+ filename = os.path.basename(storage_path)
147
+
148
+ download_url(url=url_or_filename, root=dirname, filename=filename)
149
+
150
+ def _download_vis(self):
151
+
152
+ storage_path = self.config.build_info.get(self.data_type).storage
153
+ storage_path = utils.get_cache_path(storage_path)
154
+
155
+ if not os.path.exists(storage_path):
156
+ warnings.warn(
157
+ f"""
158
+ The specified path {storage_path} for visual inputs does not exist.
159
+ Please provide a correct path to the visual inputs or
160
+ refer to datasets/download_scripts/README.md for downloading instructions.
161
+ """
162
+ )
163
+
164
+ def build(self):
165
+ """
166
+ Create by split datasets inheriting torch.utils.data.Datasets.
167
+
168
+ # build() can be dataset-specific. Overwrite to customize.
169
+ """
170
+ self.build_processors()
171
+
172
+ build_info = self.config.build_info
173
+
174
+ ann_info = build_info.annotations
175
+ vis_info = build_info.get(self.data_type)
176
+
177
+ datasets = dict()
178
+ for split in ann_info.keys():
179
+ if split not in ["train", "val", "test"]:
180
+ continue
181
+
182
+ is_train = split == "train"
183
+
184
+ # processors
185
+ vis_processor = (
186
+ self.vis_processors["train"]
187
+ if is_train
188
+ else self.vis_processors["eval"]
189
+ )
190
+ text_processor = (
191
+ self.text_processors["train"]
192
+ if is_train
193
+ else self.text_processors["eval"]
194
+ )
195
+
196
+ # annotation path
197
+ ann_paths = ann_info.get(split).storage
198
+ if isinstance(ann_paths, str):
199
+ ann_paths = [ann_paths]
200
+
201
+ abs_ann_paths = []
202
+ for ann_path in ann_paths:
203
+ if not os.path.isabs(ann_path):
204
+ ann_path = utils.get_cache_path(ann_path)
205
+ abs_ann_paths.append(ann_path)
206
+ ann_paths = abs_ann_paths
207
+
208
+ # visual data storage path
209
+ vis_path = os.path.join(vis_info.storage, split)
210
+
211
+ if not os.path.isabs(vis_path):
212
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
213
+ vis_path = utils.get_cache_path(vis_path)
214
+
215
+ if not os.path.exists(vis_path):
216
+ warnings.warn("storage path {} does not exist.".format(vis_path))
217
+
218
+ # create datasets
219
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
220
+ datasets[split] = dataset_cls(
221
+ vis_processor=vis_processor,
222
+ text_processor=text_processor,
223
+ ann_paths=ann_paths,
224
+ vis_root=vis_path,
225
+ )
226
+
227
+ return datasets
228
+
229
+
230
+ def load_dataset_config(cfg_path):
231
+ cfg = OmegaConf.load(cfg_path).datasets
232
+ cfg = cfg[list(cfg.keys())[0]]
233
+
234
+ return cfg