# muzairkhattak
# new interface
# 1eb3061
import gradio as gr
import os
import sys
# Resolve the vendored ``src/`` directory relative to this script, so the app
# also starts when launched from a different working directory. Fall back to
# the cwd when ``__file__`` is undefined (e.g. interactive sessions).
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    current_dir = os.getcwd()
src_path = os.path.join(current_dir, 'src')
# The example image paths used by the UI are relative to src/ ("../docs/..."),
# so the process works from inside it.
os.chdir(src_path)
# Put the vendored open_clip fork first so it wins over any pip-installed
# ``open_clip``. The model factory call below passes keyword arguments
# (text_encoder_name, inmem, mean/std) that look fork-specific.
if src_path not in sys.path:
    sys.path.insert(0, src_path)
from open_clip import create_model_and_transforms
from huggingface_hub import hf_hub_download
from open_clip import HFTokenizer
import torch
# Your existing create_unimed_clip_model class remains the same
class create_unimed_clip_model:
def __init__(self, model_name):
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = 'cpu'
mean = (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean
std = (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std
if model_name == "ViT/B-16":
# Download the weights
weights_path = hf_hub_download(
repo_id="UzairK/unimed-clip-vit-b16",
filename="unimed-clip-vit-b16.pt"
)
self.pretrained = weights_path # Path to pretrained weights
self.text_encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
self.model_name = "ViT-B-16-quickgelu"
elif model_name == 'ViT/L-14@336px-base-text':
# Download the weights
self.model_name = "ViT-L-14-336-quickgelu"
weights_path = hf_hub_download(
repo_id="UzairK/unimed_clip_vit_l14_base_text_encoder",
filename="unimed_clip_vit_l14_base_text_encoder.pt"
)
self.pretrained = weights_path # Path to pretrained weights
self.text_encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
self.tokenizer = HFTokenizer(
self.text_encoder_name,
context_length=256,
**{},
)
self.model, _, self.processor = create_model_and_transforms(
self.model_name,
self.pretrained,
precision='amp',
device=self.device,
force_quick_gelu=True,
pretrained_image=False,
mean=mean, std=std,
inmem=True,
text_encoder_name=self.text_encoder_name,
)
def __call__(self, input_image, candidate_labels, hypothesis_template):
# Preprocess input
input_image = self.processor(input_image).unsqueeze(0).to(self.device)
if hypothesis_template == "":
texts = [
self.tokenizer(cls_text).to(self.device)
for cls_text in candidate_labels
]
else:
texts = [
self.tokenizer(hypothesis_template + " " + cls_text).to(self.device)
for cls_text in candidate_labels
]
texts = torch.cat(texts, dim=0)
# Perform inference
with torch.no_grad():
text_features = self.model.encode_text(texts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features = self.model.encode_image(input_image)
logits = (image_features @ text_features.t()).softmax(dim=-1).cpu().numpy()
return {hypothesis_template + " " + cls_text: float(score) for cls_text, score in zip(candidate_labels, logits[0])}
# One classifier per model offered in the UI, keyed by the same label the
# radio button uses. Each constructor downloads and loads its checkpoint,
# so both models are ready before the interface starts.
pipes = {
    key: create_unimed_clip_model(model_name=key)
    for key in ("ViT/B-16", "ViT/L-14@336px-base-text")
}
def reset_all():
    """Return the cleared value for every component wired to "Reset All".

    Positions follow the click handler's outputs: image, candidate-label
    input, model choice, prompt template, accumulated labels, and scores.
    """
    default_model = "ViT/B-16"
    empty_text = ""
    return None, empty_text, default_model, empty_text, empty_text, {}
def add_label(label, current_labels):
    """Append ``label`` to the comma-separated candidate-label list.

    Args:
        label: text typed into the "Candidate Label" box. Commas split it into
            several labels, matching how the list is parsed at submit time.
        current_labels: comma-separated labels accumulated so far (may be
            empty or ``None``).

    Returns:
        ``(updated_labels, new_input_text)``. When ``label`` is blank, both
        inputs are returned unchanged. Otherwise the input box is cleared and
        the new labels are appended, skipping any already present.
    """
    if not (label or "").strip():
        return current_labels, label
    # Strip existing entries: the list is joined with ", ", so raw split parts
    # carry a leading space and would defeat the duplicate check.
    labels_list = [part.strip() for part in (current_labels or "").split(",")
                   if part.strip()]
    for new_label in (part.strip() for part in label.split(",")):
        if new_label and new_label not in labels_list:
            labels_list.append(new_label)
    return ", ".join(labels_list), ""  # Return updated labels and empty string for input
def shot(image, labels_text, model_name, hypothesis_template):
if not labels_text.strip() or not image:
return {}
labels = [label.strip() for label in labels_text.strip().split(",")]
res = pipes[model_name](
input_image=image,
candidate_labels=labels,
hypothesis_template=hypothesis_template
)
return {single_key: res[single_key] for single_key in res.keys()}
# Bundled example images. The process works from inside src/ (it chdirs there
# at start-up), so the images are addressed relative to it. The same directory
# is whitelisted for Gradio file serving at launch.
_SAMPLE_IMAGES_DIR = os.path.join(os.pardir, "docs", "sample_images")
# Generic multi-modality label set shared by most examples.
_GENERIC_LABELS = (
    "CT scan image displaying the anatomical structure of the right kidney., "
    "pneumonia is indicated in this chest X-ray image., "
    "this is a MRI photo of a brain., "
    "this fundus image shows optic nerve damage due to glaucoma., "
    "a histopathology slide showing Tumor, "
    "Cardiomegaly is evident in the X-ray image of the chest."
)

with gr.Blocks() as iface:
    gr.Markdown("""
# Zero-shot Medical Image Classification with UniMed-CLIP
Demo for UniMed CLIP, a family of strong Medical Contrastive VLMs trained on UniMed-dataset. For more information about our project, refer to our paper and github repository.
Paper: [https://arxiv.org/abs/2412.10372](https://arxiv.org/abs/2412.10372)
Github: [https://github.com/mbzuai-oryx/UniMed-CLIP](https://github.com/mbzuai-oryx/UniMed-CLIP)
**[DEMO USAGE]** To begin with the demo, provide a picture (either upload manually, or select from the given examples) and add class labels one by one. Optionally you can also add template as a prefix to the class labels.
**[NOTE]** This demo is running on CPU and thus the response time might be a bit slower. Running it on a machine with a GPU will result in much faster predictions.
""")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image", width=300, height=300)
            model_choice = gr.Radio(
                choices=["ViT/B-16", "ViT/L-14@336px-base-text"],
                label="Model",
                value="ViT/B-16",
            )
            hypothesis_template = gr.Textbox(
                label="Prompt Template",
                placeholder="Optional prompt template as prefix",
                value=""
            )
            # Label management section
            label_input = gr.Textbox(label="Candidate Label", placeholder="Add a class label, one by one",)
            add_btn = gr.Button("Add new Candidate Label")
        with gr.Column(scale=1):
            # Read-only textbox showing the accumulated comma-separated labels
            all_labels = gr.Textbox(label="Current Candidate Labels", interactive=False)
            # Submit and Reset buttons side by side
            with gr.Row():
                reset_btn = gr.Button("Reset All", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")
            # Output section
            output = gr.Label(label="Predicted Scores")

    # Event handlers
    add_btn.click(
        fn=add_label,
        inputs=[label_input, all_labels],
        outputs=[all_labels, label_input]  # Also clears the input box
    )
    # Reset all inputs
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[image_input, label_input, model_choice, hypothesis_template, all_labels, output]
    )
    # Only trigger classification on submit
    submit_btn.click(
        fn=shot,
        inputs=[image_input, all_labels, model_choice, hypothesis_template],
        outputs=[output]
    )

    # Example rows: [image path, comma-separated labels, model, prompt template]
    examples = [
        [os.path.join(_SAMPLE_IMAGES_DIR, "brain_MRI.jpg"),
         _GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(_SAMPLE_IMAGES_DIR, "ct_scan_right_kidney.jpg"),
         _GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(_SAMPLE_IMAGES_DIR, "tumor_histo_pathology.jpg"),
         "benign tissue., malignant tumor., normal cells., inflammatory tissue.",
         "ViT/B-16",
         "The histopathology slide indicates"],
        [os.path.join(_SAMPLE_IMAGES_DIR, "retina_glaucoma.jpg"),
         "CT scan of the right kidney., pneumonia disease in this chest X-ray image., a brain MRI., glaucoma in fundus image., a histopathology slide showing Tumor, Cardiomegaly disease in X-ray image of the chest.",
         "ViT/B-16", "A photo of a"],
        [os.path.join(_SAMPLE_IMAGES_DIR, "tumor_histo_pathology.jpg"),
         _GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(_SAMPLE_IMAGES_DIR, "xray_cardiomegaly.jpg"),
         _GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(_SAMPLE_IMAGES_DIR, "xray_pneumonia.png"),
         _GENERIC_LABELS, "ViT/B-16", ""],
    ]
    gr.Examples(examples=examples, inputs=[image_input, all_labels, model_choice, hypothesis_template])

# Whitelist the example-image directory as an absolute path derived from the
# same relative location the examples use. This matches
# /home/user/app/docs/sample_images on HF Spaces and also works elsewhere.
iface.launch(allowed_paths=[os.path.abspath(_SAMPLE_IMAGES_DIR)])