"""Gradio demo: zero-shot medical image classification with UniMed-CLIP."""
import os
import sys

import gradio as gr

# The UniMed-CLIP fork of open_clip lives in ./src. Run from there so the
# package import below and the relative example-image paths both resolve.
current_dir = os.getcwd()
src_path = os.path.join(current_dir, 'src')
os.chdir(src_path)
sys.path.append(src_path)

from open_clip import create_model_and_transforms  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from open_clip import HFTokenizer  # noqa: E402
import torch  # noqa: E402

_OPENAI_MEAN = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
_OPENAI_STD = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

# UI model label -> checkpoint configuration.
_MODEL_CONFIGS = {
    "ViT/B-16": {
        "model_name": "ViT-B-16-quickgelu",
        "repo_id": "UzairK/unimed-clip-vit-b16",
        "filename": "unimed-clip-vit-b16.pt",
        "text_encoder_name": "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
    },
    "ViT/L-14@336px-base-text": {
        "model_name": "ViT-L-14-336-quickgelu",
        "repo_id": "UzairK/unimed_clip_vit_l14_base_text_encoder",
        "filename": "unimed_clip_vit_l14_base_text_encoder.pt",
        "text_encoder_name": "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
    },
}
DEFAULT_MODEL = "ViT/B-16"


class create_unimed_clip_model:
    """Zero-shot image classifier backed by a pretrained UniMed-CLIP checkpoint.

    Args:
        model_name: one of the keys of ``_MODEL_CONFIGS`` (the labels shown in the UI).

    Raises:
        ValueError: if ``model_name`` is not a known checkpoint.
    """

    def __init__(self, model_name):
        # Pinned to CPU (the hosted demo has no GPU); use
        # torch.device("cuda" if torch.cuda.is_available() else "cpu") to enable one.
        self.device = 'cpu'
        try:
            config = _MODEL_CONFIGS[model_name]
        except KeyError:
            raise ValueError(
                f"Unknown model {model_name!r}; expected one of {list(_MODEL_CONFIGS)}"
            ) from None

        self.model_name = config["model_name"]
        self.text_encoder_name = config["text_encoder_name"]
        # Path to the pretrained weights, downloaded via the Hugging Face Hub cache.
        self.pretrained = hf_hub_download(
            repo_id=config["repo_id"], filename=config["filename"]
        )
        self.tokenizer = HFTokenizer(self.text_encoder_name, context_length=256)
        self.model, _, self.processor = create_model_and_transforms(
            self.model_name,
            self.pretrained,
            precision='amp',
            device=self.device,
            force_quick_gelu=True,
            pretrained_image=False,
            mean=_OPENAI_MEAN,
            std=_OPENAI_STD,
            inmem=True,
            text_encoder_name=self.text_encoder_name,
        )
        # open_clip models are returned in train mode; without eval() the text
        # encoder's dropout stays active and predictions become non-deterministic.
        self.model.eval()

    def __call__(self, input_image, candidate_labels, hypothesis_template):
        """Score ``input_image`` against each candidate label.

        Args:
            input_image: PIL image to classify.
            candidate_labels: label strings; each becomes one text prompt.
            hypothesis_template: optional prefix joined to every label with a single
                space; an empty string means the labels are used as-is.

        Returns:
            dict mapping each prompt (exactly the text that was scored) to its
            softmax probability across all prompts.
        """
        prompts = [
            f"{hypothesis_template} {label}" if hypothesis_template else label
            for label in candidate_labels
        ]
        image = self.processor(input_image).unsqueeze(0).to(self.device)
        texts = torch.cat(
            [self.tokenizer(prompt).to(self.device) for prompt in prompts], dim=0
        )

        with torch.no_grad():
            text_features = self.model.encode_text(texts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # NOTE(review): image features appear not to be L2-normalized here and no
            # logit scale is applied, unlike the usual CLIP recipe. The image-feature
            # norm therefore acts as the softmax temperature. Kept as-is to preserve
            # the demo's scores; confirm against the UniMed-CLIP eval code.
            image_features = self.model.encode_image(image)
            probs = (image_features @ text_features.t()).softmax(dim=-1).cpu().numpy()

        return {prompt: float(score) for prompt, score in zip(prompts, probs[0])}


pipes = {name: create_unimed_clip_model(model_name=name) for name in _MODEL_CONFIGS}


def reset_all():
    """Return cleared values for every input/output component (see reset_btn wiring)."""
    return None, "", DEFAULT_MODEL, "", "", {}


def _split_labels(labels_text):
    """Split a comma-separated label string into stripped, non-empty labels."""
    if not labels_text:
        return []
    return [item.strip() for item in labels_text.split(",") if item.strip()]


def add_label(label, current_labels):
    """Append ``label`` to the comma-separated label list, skipping duplicates.

    Returns:
        (updated label list, new value for the label input box). The input box is
        cleared after a non-blank submission and left untouched when ``label`` is blank.
    """
    new_label = label.strip()
    if not new_label:
        return current_labels, label
    labels_list = _split_labels(current_labels)
    if new_label not in labels_list:
        labels_list.append(new_label)
    return ", ".join(labels_list), ""


def shot(image, labels_text, model_name, hypothesis_template):
    """Run zero-shot classification for the Submit button.

    Returns an empty dict (clearing the output) when there is no image or no labels.
    """
    labels = _split_labels(labels_text)
    if image is None or not labels:
        return {}
    return pipes[model_name](
        input_image=image,
        candidate_labels=labels,
        hypothesis_template=hypothesis_template,
    )


# Example images, relative to the src/ working directory set above.
SAMPLE_IMAGES_DIR = os.path.join("..", "docs", "sample_images")
# Absolute form for Gradio's allowed_paths (== /home/user/app/docs/sample_images on Spaces).
SAMPLE_IMAGES_ABS = os.path.join(current_dir, "docs", "sample_images")
GENERIC_LABELS = (
    "CT scan image displaying the anatomical structure of the right kidney., "
    "pneumonia is indicated in this chest X-ray image., "
    "this is a MRI photo of a brain., "
    "this fundus image shows optic nerve damage due to glaucoma., "
    "a histopathology slide showing Tumor, "
    "Cardiomegaly is evident in the X-ray image of the chest."
)

with gr.Blocks() as iface:
    gr.Markdown("""
# Zero-shot Medical Image Classification with UniMed-CLIP

Demo for UniMed CLIP, a family of strong Medical Contrastive VLMs trained on UniMed-dataset. For more information about our project, refer to our paper and github repository.

Paper: [https://arxiv.org/abs/2412.10372](https://arxiv.org/abs/2412.10372)

Github: [https://github.com/mbzuai-oryx/UniMed-CLIP](https://github.com/mbzuai-oryx/UniMed-CLIP)

**[DEMO USAGE]** To begin with the demo, provide a picture (either upload manually, or select from the given examples) and add class labels one by one. Optionally you can also add template as a prefix to the class labels.

**[NOTE]** This demo is running on CPU and thus the response time might be a bit slower. Running it on a machine with a GPU will result in much faster predictions.
""")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image", width=300, height=300)
            model_choice = gr.Radio(
                choices=list(pipes),
                label="Model",
                value=DEFAULT_MODEL,
            )
            hypothesis_template = gr.Textbox(
                label="Prompt Template",
                placeholder="Optional prompt template as prefix",
                value="",
            )
            # Label management section
            label_input = gr.Textbox(
                label="Candidate Label",
                placeholder="Add a class label, one by one",
            )
            add_btn = gr.Button("Add new Candidate Label")

        with gr.Column(scale=1):
            # Read-only textbox holding the accumulated, comma-separated labels.
            all_labels = gr.Textbox(label="Current Candidate Labels", interactive=False)
            # Submit and Reset buttons side by side
            with gr.Row():
                reset_btn = gr.Button("Reset All", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")
            # Output section
            output = gr.Label(label="Predicted Scores")

    # Event handlers
    add_btn.click(
        fn=add_label,
        inputs=[label_input, all_labels],
        outputs=[all_labels, label_input],  # also clears the label input
    )
    # Reset all inputs
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[image_input, label_input, model_choice, hypothesis_template, all_labels, output],
    )
    # Only trigger classification on submit
    submit_btn.click(
        fn=shot,
        inputs=[image_input, all_labels, model_choice, hypothesis_template],
        outputs=[output],
    )

    examples = [
        [os.path.join(SAMPLE_IMAGES_DIR, "brain_MRI.jpg"), GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(SAMPLE_IMAGES_DIR, "ct_scan_right_kidney.jpg"), GENERIC_LABELS, "ViT/B-16", ""],
        [
            os.path.join(SAMPLE_IMAGES_DIR, "tumor_histo_pathology.jpg"),
            "benign tissue., malignant tumor., normal cells., inflammatory tissue.",
            "ViT/B-16",
            "The histopathology slide indicates",
        ],
        [
            os.path.join(SAMPLE_IMAGES_DIR, "retina_glaucoma.jpg"),
            "CT scan of the right kidney., pneumonia disease in this chest X-ray image., "
            "a brain MRI., glaucoma in fundus image., a histopathology slide showing Tumor, "
            "Cardiomegaly disease in X-ray image of the chest.",
            "ViT/B-16",
            "A photo of a",
        ],
        [os.path.join(SAMPLE_IMAGES_DIR, "tumor_histo_pathology.jpg"), GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(SAMPLE_IMAGES_DIR, "xray_cardiomegaly.jpg"), GENERIC_LABELS, "ViT/B-16", ""],
        [os.path.join(SAMPLE_IMAGES_DIR, "xray_pneumonia.png"), GENERIC_LABELS, "ViT/B-16", ""],
    ]
    gr.Examples(examples=examples, inputs=[image_input, all_labels, model_choice, hypothesis_template])

iface.launch(allowed_paths=[SAMPLE_IMAGES_ABS])