In [None]:
# Switch the working directory to the project's `src` folder: the checkpoint
# path configured later ("./unimed_clip_vit_l14_base_text_encoder.pt") is
# relative to it (presumably `open_clip` is also meant to be imported from there).
import os

# Expose only GPU 0 to this process. It is set here, before `torch` is imported
# in the next cell, so CUDA sees it when it initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Get the current working directory. Only descend into `src` when we are not
# already inside it, so re-running this cell does not look for `src/src`.
current_dir = os.getcwd()
if os.path.basename(os.path.normpath(current_dir)) == 'src':
    src_path = current_dir
else:
    src_path = os.path.join(current_dir, 'src')
    os.chdir(src_path)

In [None]:
from open_clip import create_model_and_transforms, get_mean_std
from open_clip import HFTokenizer
from PIL import Image
import torch
from urllib.request import urlopen

In [None]:
# Define main parameters
model = 'ViT-L-14-336-quickgelu' # available pretrained weights ['ViT-L-14-336-quickgelu', 'ViT-B-16-quickgelu']
pretrained = "./unimed_clip_vit_l14_base_text_encoder.pt" # Path to pretrained weights
text_encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract" # available pretrained weights ["microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract", "microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract"]
mean, std = get_mean_std()
device='cuda'

In [None]:
model, _, preprocess = create_model_and_transforms(
 model,
 pretrained,
 precision='amp',
 device=device,
 force_quick_gelu=True,
 pretrained_image=False,
 mean=mean, std=std,
 inmem=True,
 text_encoder_name=text_encoder_name,
)

In [None]:
# Tokenizer for the Hugging Face text encoder configured above
# (`text_encoder_name`); prompts are tokenised with a context length of 256.
tokenizer = HFTokenizer(text_encoder_name, context_length=256)

In [None]:
# Zero-shot Inference

In [None]:
# Zero-shot image classification.
# Each class name below is appended to this prompt prefix before tokenization.
template = 'this is a photo of '
# Candidate classes. Their order fixes the column order of the image-vs-label
# score matrix computed during inference.
labels = list((
    'adenocarcinoma histopathology',
    'brain MRI',
    'covid line chart',
    'squamous cell carcinoma histopathology',
    'immunohistochemistry histopathology',
    'bone X-ray',
    'chest X-ray',
    'pie chart',
    'hematoxylin and eosin histopathology',
))

In [None]:
# Base URL of the example images hosted with BiomedCLIP on the Hugging Face Hub;
# each file name is appended to it when downloading.
dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'
# Example files, in the order they are downloaded and reported. The file names
# hint at each image's expected label.
test_imgs = """
    squamous_cell_carcinoma_histopathology.jpeg
    H_and_E_histopathology.jpg
    bone_X-ray.jpg
    adenocarcinoma_histopathology.jpg
    covid_line_chart.png
    IHC_histopathology.jpg
    chest_X-ray.jpg
    brain_MRI.jpg
    pie_chart.png
""".split()

In [None]:
# Zero-shot prediction: embed every image and every prompt, L2-normalise both,
# and turn the temperature-scaled cosine similarities into per-image label
# probabilities.

# Download and preprocess each example image. The `with` block makes sure the
# HTTP response is closed once the image has been read.
image_tensors = []
for img in test_imgs:
    with urlopen(dataset_url + img) as response:
        image_tensors.append(preprocess(Image.open(response)))
images = torch.stack(image_tensors).to(device)

# Tokenize "<template><label>" for every label on the model's device and
# concatenate along dim 0 (presumably one row per prompt, so row k <-> labels[k]).
model_device = next(model.parameters()).device
texts = torch.cat(
    [tokenizer(template + cls_text).to(model_device, non_blocking=True) for cls_text in labels],
    dim=0,
)

with torch.no_grad():
    text_features = model.encode_text(texts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image_features = model.encode_image(images)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    # Cosine similarities lie in [-1, 1]; a softmax applied to them directly is
    # nearly flat. CLIP-style models learn a temperature `logit_scale` (stored in
    # log space in open_clip) that must scale the similarities first.
    # NOTE(review): assumes this model exposes `logit_scale` like open_clip's CLIP.
    logit_scale = model.logit_scale.exp()
    logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)

# `logits` now holds softmax probabilities: row i = test_imgs[i], column k = labels[k].
logits = logits.cpu().numpy()
sorted_indices = sorted_indices.cpu().numpy()

top_k = -1  # number of labels to print per image; -1 prints all of them
top_k = len(labels) if top_k == -1 else top_k

for i, img in enumerate(test_imgs):
    print(img.split('/')[-1] + ':')
    # Slicing (instead of range(top_k)) tolerates top_k > len(labels).
    for jth_index in sorted_indices[i][:top_k]:
        print(f'{labels[jth_index]}: {logits[i][jth_index]}')
    print('\n')