{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot biomedical image classification\n",
    "\n",
    "Loads the CLIP checkpoint configured below, downloads a few example images and ranks a fixed list of text labels for each of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Switch the working directory to the project's `src` folder\n",
    "# (presumably needed so `open_clip` and the relative checkpoint path below resolve - confirm).\n",
    "import os\n",
    "\n",
    "# Expose only GPU 0 to this process; this has to happen before CUDA is initialised.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "# Only descend into `src` when we are not already there, so that re-running this\n",
    "# cell does not try to enter a non-existent `src/src`.\n",
    "current_dir = os.getcwd()\n",
    "if os.path.basename(current_dir) != 'src':\n",
    "    src_path = os.path.join(current_dir, 'src')\n",
    "    os.chdir(src_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from urllib.request import urlopen\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from open_clip import create_model_and_transforms, get_mean_std\n",
    "from open_clip import HFTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Define main parameters\n",
    "# `model_name` is kept separate from the `model` object created further down so that\n",
    "# the model-creation cell can be re-run without this cell having to be re-run first.\n",
    "model_name = 'ViT-L-14-336-quickgelu' # available pretrained weights ['ViT-L-14-336-quickgelu', 'ViT-B-16-quickgelu']\n",
    "pretrained = \"./unimed_clip_vit_l14_base_text_encoder.pt\" # Path to pretrained weights\n",
    "text_encoder_name = \"microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract\" # available pretrained weights [\"microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract\", \"microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract\"]\n",
    "mean, std = get_mean_std()\n",
    "# Fall back to the CPU when no CUDA device is visible instead of failing at model creation.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model, _, preprocess = create_model_and_transforms(\n",
    "    model_name,\n",
    "    pretrained,\n",
    "    precision='amp',\n",
    "    device=device,\n",
    "    force_quick_gelu=True,\n",
    "    pretrained_image=False,\n",
    "    mean=mean, std=std,\n",
    "    inmem=True,\n",
    "    text_encoder_name=text_encoder_name,\n",
    ")\n",
    "# Inference only: switch to eval mode so that train-time layers such as dropout\n",
    "# (if the model has any) do not perturb the features.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "tokenizer = HFTokenizer(\n",
    "    text_encoder_name,\n",
    "    context_length=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Zero-shot image classification\n",
    "# Every label is appended to `template` to form the text prompt for that class.\n",
    "template = 'this is a photo of '\n",
    "labels = [\n",
    "    'adenocarcinoma histopathology',\n",
    "    'brain MRI',\n",
    "    'covid line chart',\n",
    "    'squamous cell carcinoma histopathology',\n",
    "    'immunohistochemistry histopathology',\n",
    "    'bone X-ray',\n",
    "    'chest X-ray',\n",
    "    'pie chart',\n",
    "    'hematoxylin and eosin histopathology'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Example images; each file name is appended to `dataset_url` to build the download URL.\n",
    "dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'\n",
    "test_imgs = [\n",
    "    'squamous_cell_carcinoma_histopathology.jpeg',\n",
    "    'H_and_E_histopathology.jpg',\n",
    "    'bone_X-ray.jpg',\n",
    "    'adenocarcinoma_histopathology.jpg',\n",
    "    'covid_line_chart.png',\n",
    "    'IHC_histopathology.jpg',\n",
    "    'chest_X-ray.jpg',\n",
    "    'brain_MRI.jpg',\n",
    "    'pie_chart.png'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def load_image(url):\n",
    "    \"\"\"Download the image at `url` and return it as a fully loaded PIL image.\"\"\"\n",
    "    # The `with` block closes the HTTP response deterministically; `load()` forces\n",
    "    # PIL to read the pixel data while the response is still open.\n",
    "    with urlopen(url) as response:\n",
    "        image = Image.open(response)\n",
    "        image.load()\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# One preprocessed tensor per image, stacked into a single batch on `device`.\n",
    "images = torch.stack([preprocess(load_image(dataset_url + img)) for img in test_imgs]).to(device)\n",
    "\n",
    "# One tokenized prompt per label, concatenated along dim 0 and moved to `device` once.\n",
    "texts = torch.cat([tokenizer(template + cls_text) for cls_text in labels], dim=0)\n",
    "texts = texts.to(device, non_blocking=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # L2-normalise both embeddings so that their dot product is a cosine similarity.\n",
    "    text_features = model.encode_text(texts)\n",
    "    text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n",
    "    image_features = model.encode_image(images)\n",
    "    image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
    "\n",
    "    # NOTE(review): the similarities are fed to softmax without any logit scale /\n",
    "    # temperature, so the printed values are a softmax over raw cosine similarities.\n",
    "    # The ranking is unaffected - confirm whether `model.logit_scale.exp()` should be applied.\n",
    "    probs = (image_features @ text_features.t()).softmax(dim=-1)\n",
    "    sorted_indices = torch.argsort(probs, dim=-1, descending=True)\n",
    "\n",
    "probs_np = probs.cpu().numpy()\n",
    "sorted_indices = sorted_indices.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "top_k = -1  # number of labels to print per image; -1 prints all of them\n",
    "\n",
    "# Resolve the sentinel once, and never ask for more entries than there are labels.\n",
    "num_shown = len(labels) if top_k == -1 else min(top_k, len(labels))\n",
    "\n",
    "for i, img in enumerate(test_imgs):\n",
    "    print(img.split('/')[-1] + ':')\n",
    "    for j in range(num_shown):\n",
    "        jth_index = sorted_indices[i][j]\n",
    "        print(f'{labels[jth_index]}: {probs_np[i][jth_index]}')\n",
    "    print('\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}