muzairkhattak committed on
Commit 1eb3061
1 Parent(s): a0660ee

new interface

Files changed (1)
  1. app.py +118 -53
app.py CHANGED
@@ -1,19 +1,18 @@
-
 import gradio as gr
-# Switch path to root of project
 import os
 import sys
-# Get the current working directory
+
 current_dir = os.getcwd()
 src_path = os.path.join(current_dir, 'src')
 os.chdir(src_path)
-# Add src directory to sys.path
 sys.path.append(src_path)
 from open_clip import create_model_and_transforms
 from huggingface_hub import hf_hub_download
 from open_clip import HFTokenizer
 import torch
 
+
+# Your existing create_unimed_clip_model class remains the same
 class create_unimed_clip_model:
     def __init__(self, model_name):
         # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -75,62 +74,128 @@ class create_unimed_clip_model:
         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
         image_features = self.model.encode_image(input_image)
         logits = (image_features @ text_features.t()).softmax(dim=-1).cpu().numpy()
-        return {cls_text: float(score) for cls_text, score in zip(candidate_labels, logits[0])}
+        return {hypothesis_template + " " + cls_text: float(score) for cls_text, score in zip(candidate_labels, logits[0])}
+
+
 
 pipes = {
     "ViT/B-16": create_unimed_clip_model(model_name="ViT/B-16"),
     "ViT/L-14@336px-base-text": create_unimed_clip_model(model_name='ViT/L-14@336px-base-text'),
 }
-# Define Gradio inputs and outputs
-inputs = [
-    gr.Image(type="pil", label="Image", width=300, height=300),
-    gr.Textbox(label="Candidate Labels (comma-separated)"),
-    gr.Radio(
-        choices=["ViT/B-16", "ViT/L-14@336px-base-text"],
-        label="Model",
-        value="ViT/B-16",
-    ),
-    gr.Textbox(label="Prompt Template", placeholder="Optional prompt template as prefix",
-               value=""),
-]
-outputs = gr.Label(label="Predicted Scores")
+
+
+
+def reset_all():
+    return None, "", "ViT/B-16", "", "", {}
+
+
+def add_label(label, current_labels):
+    if not label.strip():
+        return current_labels, label
+    labels_list = current_labels.split(",") if current_labels else []
+    if label not in labels_list:
+        labels_list.append(label.strip())
+    return ", ".join(labels_list), ""  # Return updated labels and empty string for input
+
 
 def shot(image, labels_text, model_name, hypothesis_template):
-    labels = [label.strip(" ") for label in labels_text.strip(" ").split(",")]
-    res = pipes[model_name](input_image=image,
-                            candidate_labels=labels,
-                            hypothesis_template=hypothesis_template)
+    if not labels_text.strip() or not image:
+        return {}
+    labels = [label.strip() for label in labels_text.strip().split(",")]
+    res = pipes[model_name](
+        input_image=image,
+        candidate_labels=labels,
+        hypothesis_template=hypothesis_template
+    )
     return {single_key: res[single_key] for single_key in res.keys()}
-# Define examples
-
-examples = [
-    ["../docs/sample_images/brain_MRI.jpg", "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.", "ViT/B-16", ""],
-    ["../docs/sample_images/ct_scan_right_kidney.jpg",
-     "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
-     "ViT/B-16", ""],
-    ["../docs/sample_images/retina_glaucoma.jpg",
-     "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
-     "ViT/B-16", ""],
-    ["../docs/sample_images/tumor_histo_pathology.jpg",
-     "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
-     "ViT/B-16", ""],
-    ["../docs/sample_images/xray_cardiomegaly.jpg",
-     "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
-     "ViT/B-16", ""],
-    ["../docs/sample_images//xray_pneumonia.png",
-     "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
-     "ViT/B-16", ""],
-]
-
-iface = gr.Interface(shot,
-                     inputs,
-                     outputs,
-                     examples=examples,
-                     description="""<p>Demo for UniMed CLIP, a family of strong Medical Contrastive VLMs trained on UniMed-dataset. For more information about our project, refer to our paper and github repository. <br>
-                     Paper: <a href='https://arxiv.org/abs/2412.10372'>https://arxiv.org/abs/2412.10372</a> <br>
-                     Github: <a href='https://github.com/mbzuai-oryx/UniMed-CLIP'>https://github.com/mbzuai-oryx/UniMed-CLIP</a> <br><br>
-                     <b>[DEMO USAGE]</b> To begin with the demo, provide a picture (either upload manually, or select from the given examples) and class labels. Optionally you can also add template as an prefix to the class labels. <br> <b>[NOTE]</b> This demo is running on CPU and thus the response time might be a bit slower. Running it on a machine with a GPU will result in much faster predictions. </p>""",
-
-                     title="Zero-shot Medical Image Classification with UniMed-CLIP")
+
+
+with gr.Blocks() as iface:
+    gr.Markdown("""
+    # Zero-shot Medical Image Classification with UniMed-CLIP
+
+    Demo for UniMed CLIP, a family of strong Medical Contrastive VLMs trained on UniMed-dataset. For more information about our project, refer to our paper and github repository.
+
+    Paper: [https://arxiv.org/abs/2412.10372](https://arxiv.org/abs/2412.10372)
+    Github: [https://github.com/mbzuai-oryx/UniMed-CLIP](https://github.com/mbzuai-oryx/UniMed-CLIP)
+
+    **[DEMO USAGE]** To begin with the demo, provide a picture (either upload manually, or select from the given examples) and add class labels one by one. Optionally you can also add template as a prefix to the class labels.
+    **[NOTE]** This demo is running on CPU and thus the response time might be a bit slower. Running it on a machine with a GPU will result in much faster predictions.
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Image", width=300, height=300)
+            model_choice = gr.Radio(
+                choices=["ViT/B-16", "ViT/L-14@336px-base-text"],
+                label="Model",
+                value="ViT/B-16",
+            )
+            hypothesis_template = gr.Textbox(
+                label="Prompt Template",
+                placeholder="Optional prompt template as prefix",
+                value=""
+            )
+            # Label management section
+            label_input = gr.Textbox(label="Candidate Label", placeholder="Add a class label, one by one",)
+            add_btn = gr.Button("Add new Candidate Label")
+
+        with gr.Column(scale=1):
+            # Hidden textbox to store all labels
+            all_labels = gr.Textbox(label="Current Candidate Labels", interactive=False)
+
+            # Submit and Reset buttons side by side
+            with gr.Row():
+                reset_btn = gr.Button("Reset All", variant="secondary")
+                submit_btn = gr.Button("Submit", variant="primary")
+            # Output section
+            output = gr.Label(label="Predicted Scores")
+
+    # Event handlers
+    add_btn.click(
+        fn=add_label,
+        inputs=[label_input, all_labels],
+        outputs=[all_labels, label_input]  # Now also clearing the input
+    )
+
+    # Reset all inputs
+    reset_btn.click(
+        fn=reset_all,
+        inputs=[],
+        outputs=[image_input, label_input, model_choice, hypothesis_template, all_labels, output]
+    )
+    # Only trigger classification on submit
+    submit_btn.click(
+        fn=shot,
+        inputs=[image_input, all_labels, model_choice, hypothesis_template],
+        outputs=[output]
+    )
+
+    # Add the examples
+    examples = [
+        ["../docs/sample_images/brain_MRI.jpg",
+         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
+         "ViT/B-16", ""],
+        ["../docs/sample_images/ct_scan_right_kidney.jpg",
+         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
+         "ViT/B-16", ""],
+        ["../docs/sample_images/tumor_histo_pathology.jpg",
+         "benign tissue., malignant tumor., normal cells., inflammatory tissue.",
+         "ViT/B-16",
+         "The histopathology slide indicates"],
+        ["../docs/sample_images/retina_glaucoma.jpg",
+         "CT scan of the right kidney., pneumonia disease in this chest X-ray image., a brain MRI., glaucoma in fundus image., a histopathology slide showing Tumor, Cardiomegaly disease in X-ray image of the chest.",
+         "ViT/B-16", "A photo of a"],
+        ["../docs/sample_images/tumor_histo_pathology.jpg",
+         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
+         "ViT/B-16", ""],
+        ["../docs/sample_images/xray_cardiomegaly.jpg",
+         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
+         "ViT/B-16", ""],
+        ["../docs/sample_images//xray_pneumonia.png",
+         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
+         "ViT/B-16", ""],
+    ]
+    gr.Examples(examples=examples, inputs=[image_input, all_labels, model_choice, hypothesis_template])
 
 iface.launch(allowed_paths=["/home/user/app/docs/sample_images"])
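
For reference, a minimal sketch of how the pieces introduced in this commit fit together when exercised outside the Gradio UI: it performs the same zero-shot call the Submit button triggers. It assumes app.py has been imported (so `pipes` and `shot` are in scope) and that the sample images ship with the repository under docs/sample_images; the label string and prompt prefix below are illustrative values taken from the examples, not part of the commit itself.

# Minimal usage sketch (assumes app.py has been imported, providing `pipes` and `shot`)
from PIL import Image

image = Image.open("../docs/sample_images/brain_MRI.jpg")   # any of the bundled sample images
labels = "this is a MRI photo of a brain., pneumonia is indicated in this chest X-ray image., a histopathology slide showing Tumor"
scores = shot(
    image=image,
    labels_text=labels,               # comma-separated labels, the same format add_label() builds
    model_name="ViT/B-16",            # key into the `pipes` dict
    hypothesis_template="",           # optional prefix prepended to every label
)
print(max(scores, key=scores.get))    # label with the highest softmax score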