{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# General information\n",
"\n",
"Example colab for SigLIP models described in [the SigLIP paper](https://arxiv.org/abs/2303.15343).\n",
"\n",
"**These models are not official Google products and were trained and released for research purposes.**\n",
"\n",
"If you find our model(s) useful for your research, consider citing\n",
"\n",
"```\n",
"@article{zhai2023sigmoid,\n",
" title={Sigmoid loss for language image pre-training},\n",
" author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},\n",
" journal={International Conference on Computer Vision ({ICCV})},\n",
" year={2023}\n",
"}\n",
"```\n",
"\n",
"If you use our released models in your products, we will appreciate any direct feedback. We are reachable by xzhai@google.com, basilm@google.com, akolesnikov@google.com and lbeyer@google.com.\n",
"\n",
"\n",
"Only the models explicitly marked with `i18n` in the name are expected to perform reasonably well on non-english data."
],
"metadata": {
"id": "wR53lePHuiP-"
}
},
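{
"cell_type": "markdown",
"source": [
"For reference, below is a minimal NumPy sketch of the pairwise sigmoid loss from the paper. It is an illustration only, not the training code used for these models; `z_img`, `z_txt`, `t` and `b` stand in for a batch of L2-normalized embeddings and the learned temperature and bias."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def sigmoid_loss(z_img, z_txt, t, b):\n",
"  # Pairwise logits for every (image, text) combination in the batch.\n",
"  logits = z_img @ z_txt.T * t + b\n",
"  # +1 on the diagonal (matching pairs), -1 everywhere else.\n",
"  labels = 2 * np.eye(len(z_img)) - 1\n",
"  # -log sigmoid(label * logit), computed stably as logaddexp(0, -x).\n",
"  return np.mean(np.sum(np.logaddexp(0.0, -labels * logits), axis=-1))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},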
{
"cell_type": "code",
"source": [
"#@markdown # Environment setup\n",
"#@markdown **IMPORTANT NOTE**: Modern jax (>0.4) does not support the Colab TPU\n",
"#@markdown anymore, so don't select TPU runtime here. CPU and GPU work and are both fast enough.\n",
"\n",
"# Install the right jax version for TPU/GPU/CPU\n",
"import os\n",
"if 'COLAB_TPU_ADDR' in os.environ:\n",
" raise \"TPU colab not supported.\"\n",
"elif 'NVIDIA_PRODUCT_NAME' in os.environ:\n",
" !nvidia-smi\n",
"import jax\n",
"jax.devices()\n",
"\n",
"\n",
"# Get latest version of big_vision codebase.\n",
"!git clone --quiet --branch=main --depth=1 https://github.com/google-research/big_vision\n",
"!cd big_vision && git pull --rebase --quiet\n",
"!pip -q install -r big_vision/big_vision/requirements.txt\n",
"# Gives us ~2x faster gsutil cp to get the model checkpoints.\n",
"!pip3 -q install --no-cache-dir -U crcmod\n",
"\n",
"%cd big_vision\n",
"\n",
"\n",
"import numpy as np\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import ml_collections\n",
"\n",
"from google.colab.output import _publish as publish"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kXSdSXVg2PAI",
"outputId": "ba908946-0cd3-4468-9034-cd108529986f",
"cellView": "form"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Thu Sep 28 09:08:47 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 75C P8 14W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n",
"fatal: destination path 'big_vision' already exists and is not an empty directory.\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"/content/big_vision\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Choose and load model, perform inference"
],
"metadata": {
"id": "byHpmgAO6inM"
}
},
{
"cell_type": "code",
"source": [
"# Pick your hero: (WHEN CHANGING THIS, RERUN IMAGE/TEXT EMBEDDING CELLS)\n",
"# Give this cell 1-3mins.\n",
"\n",
"# VARIANT, RES = 'B/16', 224\n",
"# VARIANT, RES = 'B/16', 256\n",
"# VARIANT, RES = 'B/16', 384\n",
"# VARIANT, RES = 'B/16', 512\n",
"# VARIANT, RES = 'L/16', 256\n",
"VARIANT, RES = 'L/16', 384\n",
"# VARIANT, RES = 'So400m/14', 224\n",
"# VARIANT, RES = 'So400m/14', 384\n",
"# VARIANT, RES = 'B/16-i18n', 256\n",
"\n",
"CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {\n",
" ('B/16', 224): ('webli_en_b16_224_63724782.npz', 'B', 768, 64, 32_000),\n",
" ('B/16', 256): ('webli_en_b16_256_60500360.npz', 'B', 768, 64, 32_000),\n",
" ('B/16', 384): ('webli_en_b16_384_68578854.npz', 'B', 768, 64, 32_000),\n",
" ('B/16', 512): ('webli_en_b16_512_68580893.npz', 'B', 768, 64, 32_000),\n",
" ('L/16', 256): ('webli_en_l16_256_60552751.npz', 'L', 1024, 64, 32_000),\n",
" ('L/16', 384): ('webli_en_l16_384_63634585.npz', 'L', 1024, 64, 32_000),\n",
" ('So400m/14', 224): ('webli_en_so400m_224_57633886.npz', 'So400m', 1152, 16, 32_000),\n",
" ('So400m/14', 384): ('webli_en_so400m_384_58765454.npz', 'So400m', 1152, 64, 32_000),\n",
" ('B/16-i18n', 256): ('webli_i18n_b16_256_66117334.npz', 'B', 768, 64, 250_000),\n",
" ('So400m/16', 256): ('webli_i18n_so400m_16_256_78061115.npz', 'So400m', 1152, 64, 250_000),\n",
"}[VARIANT, RES]\n",
"\n",
"# It is significantly faster to first copy the checkpoint (30s vs 8m30 for B and 1m vs ??? for L)\n",
"!test -f /tmp/{CKPT} || gsutil cp gs://big_vision/siglip/{CKPT} /tmp/\n",
"\n",
"if VARIANT.endswith('-i18n'):\n",
" VARIANT = VARIANT[:-len('-i18n')]\n",
"\n",
"import big_vision.models.proj.image_text.two_towers as model_mod\n",
"\n",
"model_cfg = ml_collections.ConfigDict()\n",
"model_cfg.image_model = 'vit' # TODO(lbeyer): remove later, default\n",
"model_cfg.text_model = 'proj.image_text.text_transformer' # TODO(lbeyer): remove later, default\n",
"model_cfg.image = dict(variant=VARIANT, pool_type='map')\n",
"model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)\n",
"model_cfg.out_dim = (None, EMBDIM) # (image_out_dim, text_out_dim)\n",
"model_cfg.bias_init = -10.0\n",
"model_cfg.temperature_init = 10.0\n",
"\n",
"model = model_mod.Model(**model_cfg)\n",
"\n",
"# Using `init_params` is slower but will lead to `load` below performing sanity-checks.\n",
"# init_params = jax.jit(model.init, backend=\"cpu\")(jax.random.PRNGKey(42), jnp.zeros([1, RES, RES, 3], jnp.float32), jnp.zeros([1, SEQLEN], jnp.int32))['params']\n",
"init_params = None # Faster but bypasses loading sanity-checks.\n",
"\n",
"params = model_mod.load(init_params, f'/tmp/{CKPT}', model_cfg)"
],
"metadata": {
"id": "0DsOabGD7MRG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5afc9f52-7eb4-4a0d-b681-3ab5945ce9b4"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Copying gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz...\n",
"- [1 files][ 1.3 GiB/ 1.3 GiB] 45.3 MiB/s \n",
"Operation completed over 1 objects/1.3 GiB. \n"
]
}
]
},
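{
"cell_type": "code",
"source": [
"# Optional sanity check (a sketch, not part of the original demo): count the\n",
"# loaded parameters to confirm the checkpoint matches the chosen variant.\n",
"num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))\n",
"print(f'{num_params:,} parameters')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},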
{
"cell_type": "code",
"source": [
"#@title Load and embed images\n",
"\n",
"import big_vision.pp.builder as pp_builder\n",
"import big_vision.pp.ops_general\n",
"import big_vision.pp.ops_image\n",
"import big_vision.pp.ops_text\n",
"import PIL\n",
"\n",
"!wget -q https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg\n",
"!wget -q https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg\n",
"!wget -q 'https://images.unsplash.com/photo-1566467021888-b03548769dd1?ixlib=rb-4.0.3&q=85&fm=jpg&crop=entropy&cs=srgb&dl=svetlana-gumerova-hQHm2D1fH70-unsplash.jpg&w=640' -O cold_drink.jpg\n",
"!wget -q 'https://images.rawpixel.com/image_1300/czNmcy1wcml2YXRlL3Jhd3BpeGVsX2ltYWdlcy93ZWJzaXRlX2NvbnRlbnQvbHIvdXB3azU4ODU5NzY1LXdpa2ltZWRpYS1pbWFnZS1rb3diMmhkeC5qcGc.jpg' -O hot_drink.jpg\n",
"!wget -q https://storage.googleapis.com/big_vision/siglip/authors.jpg\n",
"!wget -q https://storage.googleapis.com/big_vision/siglip/siglip.jpg\n",
"!wget -q https://storage.googleapis.com/big_vision/siglip/caffeine.jpg\n",
"!wget -q https://storage.googleapis.com/big_vision/siglip/robosign.jpg\n",
"!wget -q https://storage.googleapis.com/big_vision/siglip/fried_fish.jpeg\n",
"!wget -q 'https://pbs.twimg.com/media/FTyEyxyXsAAyKPc?format=jpg&name=small' -O cow_beach.jpg\n",
"!wget -q 'https://storage.googleapis.com/big_vision/siglip/cow_beach2.jpg' -O cow_beach2.jpg\n",
"!wget -q 'https://pbs.twimg.com/media/Frb6NIEXwAA8-fI?format=jpg&name=medium' -O mountain_view.jpg\n",
"\n",
"\n",
"images = [PIL.Image.open(fname) for fname in [\n",
" 'apple-ipod.jpg',\n",
" 'apple-blank.jpg',\n",
" 'cold_drink.jpg',\n",
" 'hot_drink.jpg',\n",
" 'caffeine.jpg',\n",
" 'siglip.jpg',\n",
" 'authors.jpg',\n",
" 'robosign.jpg',\n",
" 'cow_beach.jpg',\n",
" 'cow_beach2.jpg',\n",
" 'mountain_view.jpg',\n",
"]]\n",
"\n",
"pp_img = pp_builder.get_preprocess_fn(f'resize({RES})|value_range(-1, 1)')\n",
"imgs = np.array([pp_img({'image': np.array(image)})['image'] for image in images])\n",
"zimg, _, out = model.apply({'params': params}, imgs, None)\n",
"\n",
"print(imgs.shape, zimg.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xmuXfCfBjgeF",
"outputId": "3627819b-007e-4107-e1f4-06b7ad3ac03a"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(11, 384, 384, 3) (11, 1024)\n"
]
}
]
},
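{
"cell_type": "code",
"source": [
"# Optional check (not part of the original demo): the image embeddings are\n",
"# expected to be L2-normalized, so every row norm should be close to 1.\n",
"print(np.linalg.norm(zimg, axis=-1))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},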
{
"cell_type": "code",
"source": [
"#@title Tokenize and embed texts\n",
"\n",
"texts = [\n",
" 'an apple',\n",
" 'a picture of an apple',\n",
" 'an ipod',\n",
" 'granny smith',\n",
" 'an apple with a note saying \"ipod\"',\n",
" 'a cold drink on a hot day',\n",
" 'a hot drink on a cold day',\n",
" 'a photo of a cold drink on a hot day',\n",
" 'a photo of a hot drink on a cold day',\n",
" #\n",
" 'a photo of two guys in need of caffeine',\n",
" 'a photo of two guys in need of water',\n",
" 'a photo of the SigLIP authors',\n",
" 'a photo of a rock band',\n",
" 'a photo of researchers at Google Brain',\n",
" 'a photo of researchers at OpenAI',\n",
" #\n",
" 'a robot on a sign',\n",
" 'a photo of a robot on a sign',\n",
" 'an empty street',\n",
" 'autumn in Toronto',\n",
" 'a photo of autumn in Toronto',\n",
" 'a photo of Toronto in autumn',\n",
" 'a photo of Toronto in summer',\n",
" 'autumn in Singapore',\n",
" #\n",
" 'cow',\n",
" 'a cow in a tuxedo',\n",
" 'a cow on the beach',\n",
" 'a cow in the prairie',\n",
" #\n",
" 'the real mountain view',\n",
" 'Zürich',\n",
" 'San Francisco',\n",
" 'a picture of a laptop with the lockscreen on, a cup of cappucino, salt and pepper grinders. The view through the window reveals lake Zürich and the Alps in the background of the city.',\n",
"]\n",
"\n",
"TOKENIZERS = {\n",
" 32_000: 'c4_en',\n",
" 250_000: 'mc4',\n",
"}\n",
"pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model=\"{TOKENIZERS[VOCAB]}\", eos=\"sticky\", pad_value=1, inkey=\"text\")')\n",
"txts = np.array([pp_txt({'text': text})['labels'] for text in texts])\n",
"_, ztxt, out = model.apply({'params': params}, None, txts)\n",
"\n",
"print(txts.shape, ztxt.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KGrpkRTtjU-L",
"outputId": "7c43b56e-cd53-4801-b1e3-66774368a1d2"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(31, 64) (31, 1024)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# This is how to get all probabilities:\n",
"print(f\"Learned temperature {out['t'].item():.1f}, learned bias: {out['b'].item():.1f}\")\n",
"probs = jax.nn.sigmoid(zimg @ ztxt.T * out['t'] + out['b'])\n",
"print(f\"{probs[0][0]:.1%} that image 0 is '{texts[0]}'\")\n",
"print(f\"{probs[0][1]:.1%} that image 0 is '{texts[1]}'\")"
],
"metadata": {
"id": "TIdAVw9VGEAw",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "22fc0d9a-8986-4679-ca89-6e4330a55c6e"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Learned temperature 118.2, learned bias: -12.7\n",
"10.4% that image 0 is 'an apple'\n",
"42.8% that image 0 is 'a picture of an apple'\n"
]
}
]
},
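{
"cell_type": "code",
"source": [
"# A small usage sketch (not part of the original demo): for each image, show\n",
"# the candidate text with the highest sigmoid probability.\n",
"probs_np = np.asarray(probs)\n",
"for iimg in range(len(imgs)):\n",
"  itxt = probs_np[iimg].argmax()\n",
"  print(f\"image {iimg}: {probs_np[iimg, itxt]:.1%} '{texts[itxt]}'\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},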
{
"cell_type": "code",
"source": [
"# @title Pretty demo (code)\n",
"from IPython.display import Javascript\n",
"\n",
"DEMO_IMG_SIZE = 96\n",
"\n",
"import base64\n",
"import io\n",
"\n",
"def bv2rgb(bv_img):\n",
" return (bv_img * 127.5 + 127.5).astype(np.uint8)\n",
"\n",
"def html_img(*, enc_img=None, pixels=None, id=None, size=100, max_size=None, max_height=None, style=\"\"):\n",
" if enc_img is None and pixels is not None:\n",
" with io.BytesIO() as buf:\n",
" PIL.Image.fromarray(np.asarray(pixels)).save(buf, format=\"JPEG\")\n",
" enc_img = buf.getvalue()\n",
"\n",
" img_data = base64.b64encode(np.ascontiguousarray(enc_img)).decode('ascii')\n",
"\n",
" id_spec = f'id={id}' if id else ''\n",
" if size is not None:\n",
" style_spec = f'style=\"{style}; width: {size}px; height: {size}px\"'\n",
" elif max_size is not None:\n",
" style_spec = f'style=\"{style}; width: auto; height: auto; max-width: {max_size}px; max-height: {max_size}px;\"'\n",
" elif max_height is not None:\n",
" style_spec = f'style=\"{style}; object-fit: cover; width: auto; height: {max_height}px;\"'\n",
" else: style_spec = ''\n",
"\n",
" return f''\n",
"\n",
"\n",
"def make_table(zimg, ztxt, out):\n",
" # The default learnable bias is a little conservative. Play around with it!\n",
" t, b = out['t'].item(), out['b'].item()\n",
" tempered_logits = zimg @ ztxt.T * t\n",
" probs = 1 / (1 + np.exp(-tempered_logits - b))\n",
" publish.javascript(f\"var logits = {tempered_logits.tolist()};\")\n",
"\n",
" def color(p):\n",
" return mpl.colors.rgb2hex(mpl.cm.Greens(p / 2)) if p >= 0.01 else \"transparent\"\n",
"\n",
" publish.javascript(f\"var cmap = {[color(x) for x in np.linspace(0, 1, 50)]};\")\n",
" def cell(x, iimg, itxt):\n",
" return f\"
{x * 100:>4.0f}%\"\n", "\n", " html = f'''\n", "
\n", " \n", " \n", " \n", "
\n", " '''\n", "\n", " html += \"\" + html_img(pixels=bv2rgb(img), size=DEMO_IMG_SIZE) for img in imgs])\n", " html += \" | \"\n", " for itxt, txt in enumerate(texts):\n", " html += f\" | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{txt}\"\n",
"\n",
" publish.css(r\"\"\"\n",
" table {\n",
" border-collapse: collapse;\n",
" }\n",
"\n",
" tr {\n",
" border: 1px transparent;\n",
" }\n",
"\n",
" tr:nth-child(odd) {\n",
" background-color: #F5F5F5;\n",
" }\n",
"\n",
" tr:hover {\n",
" background-color: lightyellow;\n",
" border: 1px solid black;\n",
" }\n",
"\n",
" td.pct {\n",
" text-align: center;\n",
" }\n",
" \"\"\")\n",
" publish.html(html)\n",
"\n",
" # JS code to compute and write all probs from the logits.\n",
" display(Javascript('''\n",
" function update(b) {\n",
" for(var iimg = 0; iimg < logits.length; iimg++) {\n",
" for(var itxt = 0; itxt < logits[iimg].length; itxt++) {\n",
" const el = document.getElementById(`p_${iimg}_${itxt}`);\n",
" const p = Math.round(100 / (1 + Math.exp(-logits[iimg][itxt] - b)));\n",
" const pad = p < 10.0 ? ' ' : p < 100.0 ? ' ' : ''\n",
" el.innerHTML = pad + (p).toFixed(0) + '%';\n",
"\n",
" const td = document.getElementById(`td_${iimg}_${itxt}`);\n",
" const c = cmap[Math.round(p / 100 * (cmap.length - 1))];\n",
" td.style.backgroundColor = c;\n",
" }\n",
" }\n",
" }\n",
" '''))\n",
"\n",
" # JS code to connect the bias value slider\n",
" display(Javascript('''\n",
" const value = document.querySelector(\"#value\");\n",
" const input = document.querySelector(\"#b\");\n",
" value.textContent = input.value;\n",
" input.addEventListener(\"input\", (event) => {\n",
" value.textContent = event.target.value;\n",
" update(event.target.value);\n",
" });\n",
" '''))\n",
"\n",
" # Make the cell output as large as the table to avoid annoying scrollbars.\n",
" display(Javascript(f'update({b})'))\n",
" display(Javascript('google.colab.output.resizeIframeToContent()'))"
],
"metadata": {
"cellView": "form",
"id": "eolOc7vd_ZSj"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"make_table(zimg, ztxt, out)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 767
},
"id": "mt5BIywzzA6c",
"outputId": "3b06cfb9-a3da-42d7-8caf-d5366d058f8b"
},
"execution_count": 14,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
" \n", " \n", " \n", " \n", " \n", "
|