ChayanM committed on
Commit
8173f8d
·
verified ·
1 Parent(s): d728dac

Update Chest_Xray_Report_Generator-V2.py

Browse files
Files changed (1) hide show
  1. Chest_Xray_Report_Generator-V2.py +306 -306
Chest_Xray_Report_Generator-V2.py CHANGED
@@ -1,307 +1,307 @@
1
- import os
2
- import transformers
3
- from transformers import pipeline
4
- import gradio as gr
5
- import cv2
6
- import numpy as np
7
- import pydicom
8
-
9
- ##### Libraries For Grad-Cam-View
10
- import os
11
- import cv2
12
- import numpy as np
13
- import torch
14
- from functools import partial
15
- from torchvision import transforms
16
- from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad
17
- from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
18
- from pytorch_grad_cam.ablation_layer import AblationLayerVit
19
- from transformers import VisionEncoderDecoderModel
20
-
21
- def generate_gradcam(image_path, model_path, output_path, method='gradcam', use_cuda=True, aug_smooth=False, eigen_smooth=False):
22
- methods = {
23
- "gradcam": GradCAM,
24
- "scorecam": ScoreCAM,
25
- "gradcam++": GradCAMPlusPlus,
26
- "ablationcam": AblationCAM,
27
- "xgradcam": XGradCAM,
28
- "eigencam": EigenCAM,
29
- "eigengradcam": EigenGradCAM,
30
- "layercam": LayerCAM,
31
- "fullgrad": FullGrad
32
- }
33
-
34
- if method not in methods:
35
- raise ValueError(f"Method should be one of {list(methods.keys())}")
36
-
37
- model = VisionEncoderDecoderModel.from_pretrained(model_path)
38
- model.encoder.eval()
39
-
40
- if use_cuda and torch.cuda.is_available():
41
- model.encoder = model.encoder.cuda()
42
- else:
43
- use_cuda = False
44
-
45
- #target_layers = [model.blocks[-1].norm1] ## For ViT model
46
- #target_layers = model.blocks[-1].norm1 ## For EfficientNet-B7 model
47
- target_layers = [model.encoder.encoder.layer[-1].layernorm_before] ## For ViT-based VisionEncoderDecoder model
48
- #target_layers = [model.encoder.encoder.layers[-1].blocks[-1].layernorm_before, model.encoder.encoder.layers[-1].blocks[0].layernorm_before] ## For Swin-based VisionEncoderDecoder mode
49
-
50
-
51
- if method == "ablationcam":
52
- cam = methods[method](model=model.encoder,
53
- target_layers=target_layers,
54
- use_cuda=use_cuda,
55
- reshape_transform=reshape_transform,
56
- ablation_layer=AblationLayerVit())
57
- else:
58
- cam = methods[method](model=model.encoder,
59
- target_layers=target_layers,
60
- use_cuda=use_cuda,
61
- reshape_transform=reshape_transform)
62
-
63
- rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
64
- rgb_img = cv2.resize(rgb_img, (224, 224)) ## (224, 224)
65
- rgb_img = np.float32(rgb_img) / 255
66
- input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
67
-
68
- targets = None
69
- cam.batch_size = 16
70
-
71
- grayscale_cam = cam(input_tensor=input_tensor, targets=targets, eigen_smooth=eigen_smooth, aug_smooth=aug_smooth)
72
- grayscale_cam = grayscale_cam[0, :]
73
-
74
- cam_image = show_cam_on_image(rgb_img, grayscale_cam)
75
- output_file = os.path.join(output_path, 'gradcam_result.png')
76
- cv2.imwrite(output_file, cam_image)
77
-
78
-
79
- def reshape_transform(tensor, height=14, width=14): ### height=14, width=14 for ViT-based Model
80
- batch_size, token_number, embed_dim = tensor.size()
81
- if token_number < height * width:
82
- pad = torch.zeros(batch_size, height * width - token_number, embed_dim, device=tensor.device)
83
- tensor = torch.cat([tensor, pad], dim=1)
84
- elif token_number > height * width:
85
- tensor = tensor[:, :height * width, :]
86
-
87
- result = tensor.reshape(batch_size, height, width, embed_dim)
88
- result = result.transpose(2, 3).transpose(1, 2)
89
- return result
90
-
91
-
92
-
93
-
94
- # Example usage:
95
- #image_path = "/home/chayan/CGI_Net/images/images/CXR1353_IM-0230-1001.png"
96
- model_path = "/home/chayan/ViT-GPT2/Mimic_test/"
97
- output_path = "/home/chayan/ViT-GPT2/CAM-Result/"
98
-
99
-
100
-
101
- def sentence_case(paragraph):
102
- sentences = paragraph.split('. ')
103
- formatted_sentences = [sentence.capitalize() for sentence in sentences if sentence]
104
- formatted_paragraph = '. '.join(formatted_sentences)
105
- return formatted_paragraph
106
-
107
- def dicom_to_png(dicom_file, png_file):
108
- # Load DICOM file
109
- dicom_data = pydicom.dcmread(dicom_file)
110
- dicom_data.PhotometricInterpretation = 'MONOCHROME1'
111
-
112
- # Normalize pixel values to 0-255
113
- img = dicom_data.pixel_array
114
- img = img.astype(np.float32)
115
-
116
- img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
117
- img = img.astype(np.uint8)
118
-
119
- # Save as PNG
120
- cv2.imwrite(png_file, img)
121
- return img
122
-
123
-
124
- Image_Captioner = pipeline("image-to-text", model = "/home/chayan/ViT-GPT2/Mimic_test/")
125
-
126
- data_dir = '/home/chayan/ViT-GPT2/'
127
-
128
- def xray_report_generator(Image_file):
129
- if Image_file[-4:] =='.dcm':
130
- png_file = 'DCM2PNG.png'
131
- dicom_to_png(Image_file, png_file)
132
- Image_file = os.path.join(data_dir, png_file)
133
- output = Image_Captioner(Image_file, max_new_tokens=512)
134
-
135
- else:
136
- output = Image_Captioner(Image_file, max_new_tokens=512)
137
-
138
- result = output[0]['generated_text']
139
- output_paragraph = sentence_case(result)
140
-
141
- generate_gradcam(Image_file, model_path, output_path, method='gradcam', use_cuda=True)
142
-
143
- grad_cam_image = output_path + 'gradcam_result.png'
144
-
145
- return Image_file,grad_cam_image, output_paragraph
146
-
147
-
148
-
149
- def save_feedback(feedback):
150
- feedback_dir = "/home/chayan/ViT-GPT2/Feedback/" # Update this to your desired directory
151
- if not os.path.exists(feedback_dir):
152
- os.makedirs(feedback_dir)
153
- feedback_file = os.path.join(feedback_dir, "feedback.txt")
154
- with open(feedback_file, "a") as f:
155
- f.write(feedback + "\n")
156
- return "Feedback submitted successfully!"
157
-
158
-
159
-
160
-
161
- # Custom CSS styles
162
- custom_css = """
163
- <style>
164
-
165
- #title {
166
- color: green;
167
- font-size: 36px;
168
- font-weight: bold;
169
- }
170
- #description {
171
- color: green;
172
- font-size: 22px;
173
- }
174
-
175
-
176
- #submit-btn {
177
- background-color: #1E90FF; /* DodgerBlue */
178
- color: green;
179
- padding: 15px 32px;
180
- text-align: center;
181
- text-decoration: none;
182
- display: inline-block;
183
- font-size: 20px;
184
- margin: 4px 2px;
185
- cursor: pointer;
186
- }
187
- #submit-btn:hover {
188
- background-color: #00FFFF;
189
- }
190
-
191
- .intext textarea {
192
- color: green;
193
- font-size: 20px;
194
- font-weight: bold;
195
- }
196
-
197
-
198
- .small-button {
199
- color: green;
200
- padding: 5px 10px;
201
- font-size: 20px;
202
- }
203
-
204
- </style>
205
- """
206
-
207
- # Sample image paths
208
- sample_images = [
209
- "/mnt/data/chayan/MIMIC-CXR-JPG/2.0.0/files/p19565388/s54621108/a9510716-02da91b0-61532c26-a65b2efc-c9dfa6f1.jpg",
210
- "/mnt/data/chayan/MIMIC-CXR-JPG/2.0.0/files/p19454978/s52312858/93681764-ec39480e-0518b12c-199850c2-f15118ab.jpg",
211
- "/mnt/data/chayan/MIMIC-CXR-JPG/2.0.0/files/p17340686/s55469953/6ff741e9-6ea01eef-1bf10153-d1b6beba-590b6620.jpg"
212
- #"sample4.png",
213
- #"sample5.png"
214
- ]
215
-
216
- def set_input_image(image_path):
217
- return gr.update(value=image_path)
218
-
219
-
220
- with gr.Blocks(css = custom_css) as demo:
221
-
222
- #gr.HTML(custom_css) # Inject custom CSS
223
-
224
- gr.Markdown(
225
- """
226
- <h1 style="color:blue; font-size: 36px; font-weight: bold">Chest X-ray Report Generator</h1>
227
- <p id="description">Upload an X-ray image and get its report with heat-map visualization.</p>
228
- """
229
- )
230
-
231
- with gr.Row():
232
- inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")
233
-
234
- with gr.Row():
235
- with gr.Column(scale=1, min_width=300):
236
- outputs1 = gr.Image(label="Image Viewer")
237
- with gr.Column(scale=1, min_width=300):
238
- outputs2 = gr.Image(label="Grad_CAM-Visualization")
239
- with gr.Column(scale=1, min_width=300):
240
- outputs3 = gr.Textbox(label="Generated Report", elem_classes = "intext")
241
-
242
-
243
- submit_btn = gr.Button("Generate Report", elem_id="submit-btn")
244
- submit_btn.click(
245
- fn=xray_report_generator,
246
- inputs=inputs,
247
- outputs=[outputs1, outputs2, outputs3])
248
-
249
-
250
- gr.Markdown(
251
- """
252
- <h2 style="color:green; font-size: 24px;">Or choose a sample image:</h2>
253
- """
254
- )
255
-
256
- with gr.Row():
257
- for idx, sample_image in enumerate(sample_images):
258
- with gr.Column(scale=1):
259
- #sample_image_component = gr.Image(value=sample_image, interactive=False)
260
- select_button = gr.Button(f"Select Sample Image {idx+1}")
261
- select_button.click(
262
- fn=set_input_image,
263
- inputs=gr.State(value=sample_image),
264
- outputs=inputs
265
- )
266
-
267
-
268
-
269
- # Feedback section
270
- gr.Markdown(
271
- """
272
- <h2 style="color:green; font-size: 24px;">Provide Your Valuable Feedback:</h2>
273
- """
274
- )
275
-
276
- with gr.Row():
277
- feedback_input = gr.Textbox(label="Your Feedback", lines=4, placeholder="Enter your feedback here...")
278
- feedback_submit_btn = gr.Button("Submit Feedback", elem_classes="small-button")
279
- feedback_output = gr.Textbox(label="Feedback Status", interactive=False)
280
-
281
- feedback_submit_btn.click(
282
- fn=save_feedback,
283
- inputs=feedback_input,
284
- outputs=feedback_output
285
- )
286
-
287
-
288
-
289
- demo.launch(share=True)
290
-
291
-
292
- # inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")
293
- # outputs1 =gr.Image(label="Image Viewer")
294
- # outputs2 =gr.Image(label="Grad_CAM-Visualization")
295
- # outputs3 = gr.Textbox(label="Generated Report")
296
-
297
-
298
- # interface = gr.Interface(
299
- # fn=xray_report_generator,
300
- # inputs=inputs,
301
- # outputs=[outputs1, outputs2, outputs3],
302
- # title="Chest X-ray Report Generator",
303
- # description="Upload an X-ray image and get its report.",
304
- # )
305
-
306
-
307
  # interface.launch(share=True)
 
1
+ import os
2
+ import transformers
3
+ from transformers import pipeline
4
+ import gradio as gr
5
+ import cv2
6
+ import numpy as np
7
+ import pydicom
8
+
9
+ ##### Libraries For Grad-Cam-View
10
+ import os
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ from functools import partial
15
+ from torchvision import transforms
16
+ from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
18
+ from pytorch_grad_cam.ablation_layer import AblationLayerVit
19
+ from transformers import VisionEncoderDecoderModel
20
+
21
def generate_gradcam(image_path, model_path, output_path, method='gradcam', use_cuda=True, aug_smooth=False, eigen_smooth=False):
    """Render a class-activation heat-map for one image and save it as a PNG.

    The encoder of the ``VisionEncoderDecoderModel`` loaded from
    ``model_path`` is wrapped in the selected pytorch-grad-cam method. The
    image is resized to 224x224, scaled to [0, 1], normalised with mean/std
    0.5, and the heat-map overlay is written to
    ``<output_path>/gradcam_result.png``.

    Args:
        image_path: Path of an image file readable by ``cv2.imread``.
        model_path: Location handed to
            ``VisionEncoderDecoderModel.from_pretrained``.
        output_path: Directory that receives ``gradcam_result.png``; it is
            created when it does not exist yet.
        method: Key selecting the CAM variant (see ``methods`` below).
        use_cuda: Use the GPU when ``torch.cuda.is_available()``; otherwise
            the flag is switched off and the CPU is used.
        aug_smooth: Forwarded unchanged to the CAM call.
        eigen_smooth: Forwarded unchanged to the CAM call.

    Returns:
        The path of the PNG that was written.

    Raises:
        ValueError: ``method`` is not a key of ``methods``.
        FileNotFoundError: ``image_path`` could not be read as an image.
        OSError: The result image could not be written.
    """
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad
    }

    if method not in methods:
        raise ValueError(f"Method should be one of {list(methods.keys())}")

    # cv2.imread reports an unreadable file by returning None instead of
    # raising, so check before slicing and before the expensive model load.
    bgr_img = cv2.imread(image_path, 1)
    if bgr_img is None:
        raise FileNotFoundError(f"Could not read image file: {image_path}")

    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    model.encoder.eval()

    if use_cuda and torch.cuda.is_available():
        model.encoder = model.encoder.cuda()
    else:
        use_cuda = False

    #target_layers = [model.blocks[-1].norm1] ## For ViT model
    #target_layers = model.blocks[-1].norm1 ## For EfficientNet-B7 model
    target_layers = [model.encoder.encoder.layer[-1].layernorm_before] ## For ViT-based VisionEncoderDecoder model
    #target_layers = [model.encoder.encoder.layers[-1].blocks[-1].layernorm_before, model.encoder.encoder.layers[-1].blocks[0].layernorm_before] ## For Swin-based VisionEncoderDecoder mode

    # Every variant takes the same constructor arguments; only AblationCAM
    # additionally needs the ViT-specific ablation layer.
    cam_kwargs = {
        "model": model.encoder,
        "target_layers": target_layers,
        "use_cuda": use_cuda,
        "reshape_transform": reshape_transform,
    }
    if method == "ablationcam":
        cam_kwargs["ablation_layer"] = AblationLayerVit()
    cam = methods[method](**cam_kwargs)

    rgb_img = bgr_img[:, :, ::-1]  # BGR (OpenCV order) -> RGB
    rgb_img = cv2.resize(rgb_img, (224, 224)) ## (224, 224)
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    targets = None
    cam.batch_size = 16

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, eigen_smooth=eigen_smooth, aug_smooth=aug_smooth)
    grayscale_cam = grayscale_cam[0, :]  # single input image -> first map of the batch

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)

    # cv2.imwrite only returns False when it cannot write (e.g. the target
    # directory is missing), so create the directory and check the result.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, 'gradcam_result.png')
    if not cv2.imwrite(output_file, cam_image):
        raise OSError(f"Could not write Grad-CAM image to: {output_file}")
    return output_file
77
+
78
+
79
def reshape_transform(tensor, height=14, width=14): ### height=14, width=14 for ViT-based Model
    """Reshape transformer token activations into a 2-D feature map.

    Args:
        tensor: Activations of shape ``(batch, tokens, embed_dim)``.
        height: Number of patch rows in the output grid.
        width: Number of patch columns in the output grid.

    Returns:
        A tensor of shape ``(batch, embed_dim, height, width)``.

    When there are more tokens than grid cells, the surplus leading tokens
    are dropped. ViT-style encoders put their class token in front of the
    patch tokens, so keeping the first ``height * width`` tokens would place
    the class token in the grid and shift every patch by one position.
    When there are fewer tokens than grid cells, zero tokens are appended.
    """
    batch_size, token_number, embed_dim = tensor.size()
    grid_tokens = height * width

    if token_number < grid_tokens:
        # new_zeros keeps the dtype and device of the activations, so the
        # concatenation below does not silently change precision.
        pad = tensor.new_zeros(batch_size, grid_tokens - token_number, embed_dim)
        tensor = torch.cat([tensor, pad], dim=1)
    elif token_number > grid_tokens:
        # Keep the trailing patch tokens; discard the prefix (class) tokens.
        tensor = tensor[:, token_number - grid_tokens:, :]

    result = tensor.reshape(batch_size, height, width, embed_dim)
    # (batch, height, width, embed_dim) -> (batch, embed_dim, height, width)
    return result.permute(0, 3, 1, 2)
90
+
91
+
92
+
93
+
94
# Example usage:
#image_path = "/home/chayan/CGI_Net/images/images/CXR1353_IM-0230-1001.png"
# Model location that xray_report_generator passes to generate_gradcam.
model_path = "./Mimic_test/"
# Directory in which generate_gradcam writes gradcam_result.png.
output_path = "./CAM-Result/"
98
+
99
+
100
+
101
def sentence_case(paragraph):
    """Upper-case the first character of every sentence in ``paragraph``.

    Sentences are the pieces obtained by splitting on ``'. '``; empty pieces
    are dropped and the rest are joined with ``'. '`` again.

    Only the first character is changed. ``str.capitalize`` would also
    lower-case the remainder of each sentence, turning acronyms such as
    "ET" or "PICC" into "et" / "picc".

    Args:
        paragraph: Text to format.

    Returns:
        The re-joined paragraph; an empty string for empty input.
    """
    sentences = paragraph.split('. ')
    formatted_sentences = [
        sentence[:1].upper() + sentence[1:]
        for sentence in sentences
        if sentence
    ]
    return '. '.join(formatted_sentences)
106
+
107
def dicom_to_png(dicom_file, png_file):
    """Convert a DICOM image to an 8-bit PNG.

    The pixel data is min-max normalised to 0-255. Images stored as
    ``MONOCHROME1`` (lowest value displayed as white) are inverted so that
    every output uses the same polarity.

    Args:
        dicom_file: Path of the DICOM file to read.
        png_file: Path of the PNG file to write.

    Returns:
        The normalised ``uint8`` pixel array that was written.

    Raises:
        OSError: The PNG could not be written.
    """
    # Load DICOM file
    dicom_data = pydicom.dcmread(dicom_file)

    # Remember the stored polarity instead of overwriting it: relabelling the
    # tag does not invert the pixel values.
    photometric = getattr(dicom_data, 'PhotometricInterpretation', None)
    if photometric is None:
        # Supply a tag when the file has none, as the earlier unconditional
        # assignment did. NOTE(review): presumably needed for pixel decoding -
        # confirm against pydicom's requirements.
        dicom_data.PhotometricInterpretation = 'MONOCHROME2'

    # Normalize pixel values to 0-255
    img = dicom_data.pixel_array
    img = img.astype(np.float32)

    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    img = img.astype(np.uint8)

    if photometric == 'MONOCHROME1':
        img = 255 - img

    # Save as PNG; cv2.imwrite returns False instead of raising on failure
    if not cv2.imwrite(png_file, img):
        raise OSError(f"Could not write PNG file: {png_file}")
    return img
122
+
123
+
124
# Image-to-text pipeline used by xray_report_generator to produce the report
# text; it is created once, when this module is executed.
# NOTE(review): the model location repeats the value of model_path - presumably
# both are meant to stay in sync; confirm.
Image_Captioner = pipeline("image-to-text", model = "./Mimic_test/")

# Directory in which xray_report_generator expects the PNG converted from an
# uploaded DICOM file.
data_dir = output_path
127
+
128
def xray_report_generator(Image_file):
    """Generate the report text and Grad-CAM overlay for an uploaded image.

    A ``.dcm`` upload is first converted to ``<data_dir>/DCM2PNG.png``; the
    captioning pipeline and ``generate_gradcam`` then both work on that PNG.

    Args:
        Image_file: Path of the uploaded image or DICOM file.

    Returns:
        A tuple ``(image_path, grad_cam_image_path, report_text)``.

    Raises:
        gr.Error: No file was supplied.
    """
    if not Image_file:
        raise gr.Error("Please upload a chest X-ray image first.")

    # Accept ".dcm" in any letter case.
    if Image_file.lower().endswith('.dcm'):
        # Write the PNG to the same path it is read from afterwards. Writing
        # a bare file name puts it in the working directory, while the
        # pipeline looks for it under data_dir.
        os.makedirs(data_dir, exist_ok=True)
        png_path = os.path.join(data_dir, 'DCM2PNG.png')
        dicom_to_png(Image_file, png_path)
        Image_file = png_path

    output = Image_Captioner(Image_file, max_new_tokens=512)
    result = output[0]['generated_text']
    output_paragraph = sentence_case(result)

    generate_gradcam(Image_file, model_path, output_path, method='gradcam', use_cuda=True)

    # generate_gradcam writes this fixed file name into output_path; joining
    # also works when output_path has no trailing separator.
    grad_cam_image = os.path.join(output_path, 'gradcam_result.png')

    return Image_file, grad_cam_image, output_paragraph
146
+
147
+
148
+
149
def save_feedback(feedback, feedback_dir="./Feedback/"):
    """Append one feedback entry to ``feedback.txt`` inside ``feedback_dir``.

    Args:
        feedback: Text entered by the user; ``None`` is stored as an empty
            line.
        feedback_dir: Directory holding ``feedback.txt``; created (including
            parents) when missing.

    Returns:
        The confirmation message shown to the user.
    """
    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs(feedback_dir, exist_ok=True)
    feedback_file = os.path.join(feedback_dir, "feedback.txt")
    entry = "" if feedback is None else str(feedback)
    # Explicit UTF-8 so non-ASCII feedback does not depend on the host locale.
    with open(feedback_file, "a", encoding="utf-8") as f:
        f.write(entry + "\n")
    return "Feedback submitted successfully!"
157
+
158
+
159
+
160
+
161
# Custom CSS styles
# This string is handed to gr.Blocks(css=...), which takes plain CSS text.
# It therefore must not be wrapped in <style>...</style>: inside a
# stylesheet that tag text is not valid CSS and invalidates the first rule.
custom_css = """
#title {
    color: green;
    font-size: 36px;
    font-weight: bold;
}
#description {
    color: green;
    font-size: 22px;
}


#submit-btn {
    background-color: #1E90FF; /* DodgerBlue */
    color: green;
    padding: 15px 32px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 20px;
    margin: 4px 2px;
    cursor: pointer;
}
#submit-btn:hover {
    background-color: #00FFFF;
}

.intext textarea {
    color: green;
    font-size: 20px;
    font-weight: bold;
}


.small-button {
    color: green;
    padding: 5px 10px;
    font-size: 20px;
}
"""
206
+
207
# Sample image paths
# Each entry gets its own "Select Sample Image N" button in the UI below.
# NOTE(review): the first entry still contains patient/study sub-folders
# (p19565388/s54621108) while the other two sit directly under ./Test-Images/ -
# confirm that this matches the files actually shipped.
sample_images = [
    "./Test-Images/p19565388/s54621108/a9510716-02da91b0-61532c26-a65b2efc-c9dfa6f1.jpg",
    "./Test-Images/93681764-ec39480e-0518b12c-199850c2-f15118ab.jpg",
    "./Test-Images/6ff741e9-6ea01eef-1bf10153-d1b6beba-590b6620.jpg"
    #"sample4.png",
    #"sample5.png"
]
215
+
216
def set_input_image(image_path):
    """Build the Gradio update that sets a component's value to ``image_path``."""
    component_update = gr.update(value=image_path)
    return component_update
218
+
219
+
220
# Build the Gradio page: upload widget, three result panes, sample-image
# shortcuts and a feedback form.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation - confirm which components belong to which Row/Column.
with gr.Blocks(css = custom_css) as demo:

    #gr.HTML(custom_css) # Inject custom CSS

    # Page heading and one-line description
    gr.Markdown(
        """
        <h1 style="color:blue; font-size: 36px; font-weight: bold">Chest X-ray Report Generator</h1>
        <p id="description">Upload an X-ray image and get its report with heat-map visualization.</p>
        """
    )

    # Upload widget; its value is the single input of xray_report_generator
    with gr.Row():
        inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")

    # Result panes, in the order of xray_report_generator's return tuple:
    # image, Grad-CAM overlay, report text
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            outputs1 = gr.Image(label="Image Viewer")
        with gr.Column(scale=1, min_width=300):
            outputs2 = gr.Image(label="Grad_CAM-Visualization")
        with gr.Column(scale=1, min_width=300):
            outputs3 = gr.Textbox(label="Generated Report", elem_classes = "intext")


    # elem_id ties this button to the #submit-btn rules in custom_css
    submit_btn = gr.Button("Generate Report", elem_id="submit-btn")
    submit_btn.click(
        fn=xray_report_generator,
        inputs=inputs,
        outputs=[outputs1, outputs2, outputs3])


    gr.Markdown(
        """
        <h2 style="color:green; font-size: 24px;">Or choose a sample image:</h2>
        """
    )

    # One button per entry of sample_images; clicking it writes that entry's
    # path into the upload widget via set_input_image
    with gr.Row():
        for idx, sample_image in enumerate(sample_images):
            with gr.Column(scale=1):
                #sample_image_component = gr.Image(value=sample_image, interactive=False)
                select_button = gr.Button(f"Select Sample Image {idx+1}")
                select_button.click(
                    fn=set_input_image,
                    # gr.State carries this iteration's path as the callback input
                    inputs=gr.State(value=sample_image),
                    outputs=inputs
                )



    # Feedback section
    gr.Markdown(
        """
        <h2 style="color:green; font-size: 24px;">Provide Your Valuable Feedback:</h2>
        """
    )

    with gr.Row():
        feedback_input = gr.Textbox(label="Your Feedback", lines=4, placeholder="Enter your feedback here...")
        feedback_submit_btn = gr.Button("Submit Feedback", elem_classes="small-button")
        feedback_output = gr.Textbox(label="Feedback Status", interactive=False)

    # save_feedback's return string is displayed in the status box
    feedback_submit_btn.click(
        fn=save_feedback,
        inputs=feedback_input,
        outputs=feedback_output
    )



# Runs whenever this file is executed or imported; share=True asks Gradio
# for a public share link.
demo.launch(share=True)
290
+
291
+
292
+ # inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")
293
+ # outputs1 =gr.Image(label="Image Viewer")
294
+ # outputs2 =gr.Image(label="Grad_CAM-Visualization")
295
+ # outputs3 = gr.Textbox(label="Generated Report")
296
+
297
+
298
+ # interface = gr.Interface(
299
+ # fn=xray_report_generator,
300
+ # inputs=inputs,
301
+ # outputs=[outputs1, outputs2, outputs3],
302
+ # title="Chest X-ray Report Generator",
303
+ # description="Upload an X-ray image and get its report.",
304
+ # )
305
+
306
+
307
  # interface.launch(share=True)