import gradio as gr
import os
import sys

# The UniMed-CLIP code (including its open_clip fork) lives under src/. Add it
# to the import path and chdir into it so the repository's relative asset
# paths (e.g. ../docs/sample_images) resolve.
current_dir = os.getcwd()
src_path = os.path.join(current_dir, 'src')
os.chdir(src_path)
sys.path.append(src_path)

from open_clip import create_model_and_transforms, HFTokenizer
from huggingface_hub import hf_hub_download
import torch
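# Note: with src/ on sys.path, the open_clip import above is expected to
# resolve to the fork bundled with UniMed-CLIP, assuming no other installed
# open_clip package shadows it.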


class UniMedCLIPModel:
    """Loads a UniMed-CLIP variant and runs zero-shot image classification."""

    def __init__(self, model_name):
        self.device = 'cpu'
        # OpenAI CLIP image normalization statistics.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
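        # Map the UI model name to the corresponding open_clip config and the
        # matching checkpoint on the Hugging Face Hub.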
        if model_name == "ViT/B-16":
            weights_path = hf_hub_download(
                repo_id="UzairK/unimed-clip-vit-b16",
                filename="unimed-clip-vit-b16.pt"
            )
            self.pretrained = weights_path
            self.text_encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
            self.model_name = "ViT-B-16-quickgelu"
        elif model_name == "ViT/L-14@336px-base-text":
            self.model_name = "ViT-L-14-336-quickgelu"
            weights_path = hf_hub_download(
                repo_id="UzairK/unimed_clip_vit_l14_base_text_encoder",
                filename="unimed_clip_vit_l14_base_text_encoder.pt"
            )
            self.pretrained = weights_path
            self.text_encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
        else:
            raise ValueError(f"Unsupported model name: {model_name}")
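        # Tokenizer for the BiomedBERT text encoder, with a 256-token context
        # window.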
        self.tokenizer = HFTokenizer(
            self.text_encoder_name,
            context_length=256,
        )
        self.model, _, self.processor = create_model_and_transforms(
            self.model_name,
            self.pretrained,
            precision='amp',
            device=self.device,
            force_quick_gelu=True,
            pretrained_image=False,
            mean=mean, std=std,
            inmem=True,
            text_encoder_name=self.text_encoder_name,
        )
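        # self.processor is the image preprocessing transform (resize, crop,
        # tensor conversion, CLIP normalization) returned alongside the model.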

    def __call__(self, input_image, candidate_labels, hypothesis_template):
        input_image = self.processor(input_image).unsqueeze(0).to(self.device)
        # Build one prompt per candidate label, prefixing the optional template.
        if hypothesis_template:
            prompts = [hypothesis_template + " " + cls_text for cls_text in candidate_labels]
        else:
            prompts = list(candidate_labels)
        texts = torch.cat([self.tokenizer(p).to(self.device) for p in prompts], dim=0)

        with torch.no_grad():
            text_features = self.model.encode_text(texts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Only the text features are normalized; because softmax is applied
            # per image, the image-feature scale changes only the sharpness of
            # the scores, not their ranking.
            image_features = self.model.encode_image(input_image)
            logits = (image_features @ text_features.t()).softmax(dim=-1).cpu().numpy()
        return {prompt: float(score) for prompt, score in zip(prompts, logits[0])}
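

# Instantiate both model variants once at startup so that switching models in
# the UI does not trigger a reload; checkpoints are downloaded from the
# Hugging Face Hub on first run and cached locally.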
pipes = {
    "ViT/B-16": UniMedCLIPModel(model_name="ViT/B-16"),
    "ViT/L-14@336px-base-text": UniMedCLIPModel(model_name="ViT/L-14@336px-base-text"),
}


def reset_all():
    # One default per output component: image, label input, model choice,
    # prompt template, accumulated labels, and prediction scores.
    return None, "", "ViT/B-16", "", "", {}


def add_label(label, current_labels):
    """Append a candidate label to the comma-separated list, skipping blanks
    and duplicates, and clear the input box."""
    label = label.strip()
    if not label:
        return current_labels, label
    labels_list = [item.strip() for item in current_labels.split(",")] if current_labels else []
    if label not in labels_list:
        labels_list.append(label)
    return ", ".join(labels_list), ""


def shot(image, labels_text, model_name, hypothesis_template):
    """Run zero-shot classification over the comma-separated candidate labels."""
    if image is None or not labels_text.strip():
        return {}
    labels = [label.strip() for label in labels_text.strip().split(",")]
    return pipes[model_name](
        input_image=image,
        candidate_labels=labels,
        hypothesis_template=hypothesis_template,
    )
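

# Gradio UI: inputs on the left (image, model, optional prompt template, label
# entry), the accumulated candidate labels on the right, predicted scores below.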
with gr.Blocks() as iface:
    gr.Markdown("""
# Zero-shot Medical Image Classification with UniMed-CLIP

Demo for UniMed-CLIP, a family of strong medical contrastive VLMs trained on the UniMed dataset. For more information about our project, refer to our paper and GitHub repository.

Paper: [https://arxiv.org/abs/2412.10372](https://arxiv.org/abs/2412.10372)
GitHub: [https://github.com/mbzuai-oryx/UniMed-CLIP](https://github.com/mbzuai-oryx/UniMed-CLIP)

**[DEMO USAGE]** To begin, provide an image (either upload one manually or select from the given examples) and add class labels one by one. Optionally, you can also add a prompt template as a prefix to the class labels.

**[NOTE]** This demo runs on CPU, so responses may be slow; running it on a machine with a GPU gives much faster predictions.
""")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image", width=300, height=300)
            model_choice = gr.Radio(
                choices=["ViT/B-16", "ViT/L-14@336px-base-text"],
                label="Model",
                value="ViT/B-16",
            )
            hypothesis_template = gr.Textbox(
                label="Prompt Template",
                placeholder="Optional prompt template as prefix",
                value="",
            )
            label_input = gr.Textbox(label="Candidate Label", placeholder="Add class labels one by one")
            add_btn = gr.Button("Add new Candidate Label")

        with gr.Column(scale=1):
            all_labels = gr.Textbox(label="Current Candidate Labels", interactive=False)

    with gr.Row():
        reset_btn = gr.Button("Reset All", variant="secondary")
        submit_btn = gr.Button("Submit", variant="primary")

    output = gr.Label(label="Predicted Scores")

    # Wire the buttons to their handlers.
    add_btn.click(
        fn=add_label,
        inputs=[label_input, all_labels],
        outputs=[all_labels, label_input],
    )

    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[image_input, label_input, model_choice, hypothesis_template, all_labels, output],
    )

    submit_btn.click(
        fn=shot,
        inputs=[image_input, all_labels, model_choice, hypothesis_template],
        outputs=[output],
    )

    # Example inputs; image paths are relative to src/ because of the os.chdir
    # above, and commas inside each string separate the candidate labels.
    examples = [
        ["../docs/sample_images/brain_MRI.jpg",
         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
         "ViT/B-16", ""],
        ["../docs/sample_images/ct_scan_right_kidney.jpg",
         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
         "ViT/B-16", ""],
        ["../docs/sample_images/tumor_histo_pathology.jpg",
         "benign tissue., malignant tumor., normal cells., inflammatory tissue.",
         "ViT/B-16",
         "The histopathology slide indicates"],
        ["../docs/sample_images/retina_glaucoma.jpg",
         "CT scan of the right kidney., pneumonia disease in this chest X-ray image., a brain MRI., glaucoma in fundus image., a histopathology slide showing Tumor, Cardiomegaly disease in X-ray image of the chest.",
         "ViT/B-16", "A photo of a"],
        ["../docs/sample_images/tumor_histo_pathology.jpg",
         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
         "ViT/B-16", ""],
        ["../docs/sample_images/xray_cardiomegaly.jpg",
         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
         "ViT/B-16", ""],
        ["../docs/sample_images/xray_pneumonia.png",
         "CT scan image displaying the anatomical structure of the right kidney., pneumonia is indicated in this chest X-ray image., this is a MRI photo of a brain., this fundus image shows optic nerve damage due to glaucoma., a histopathology slide showing Tumor, Cardiomegaly is evident in the X-ray image of the chest.",
         "ViT/B-16", ""],
    ]
    gr.Examples(examples=examples, inputs=[image_input, all_labels, model_choice, hypothesis_template])
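
# allowed_paths lets Gradio serve the sample images, which live outside src/
# (the working directory after the chdir above).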
iface.launch(allowed_paths=["/home/user/app/docs/sample_images"])