{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data set already exists in the local drive. Loading it.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "from datasets import load_dataset\n",
    "\n",
    "curr_dir = Path(os.getcwd())\n",
    "data_dir = curr_dir / 'data'\n",
    "if not os.path.exists(data_dir):\n",
    "    os.mkdir(data_dir)\n",
    "data_pickle_path = data_dir / 'data_set.pkl'\n",
    "\n",
    "if not os.path.exists(data_pickle_path):\n",
    "    print(f\"Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.\")\n",
    "    data_set = load_dataset(\"vipulmaheshwari/GTA-Image-Captioning-Dataset\")\n",
    "    with open(data_pickle_path, 'wb') as outfile:\n",
    "        pickle.dump(data_set, outfile)\n",
    "else:\n",
    "    print(f\"Data set already exists in the local drive. Loading it.\")\n",
    "    with open(data_pickle_path, 'rb') as infile:\n",
    "        data_set = pickle.load(infile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(data_set)\n",
    "# len(data_set['train']['image']), len(data_set['train']['text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source: https://huggingface.co/sentence-transformers/clip-ViT-L-14\n",
    "\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "# from PIL import Image\n",
    "\n",
    "#Load CLIP model\n",
    "model = SentenceTransformer(\"sentence-transformers/clip-ViT-L-14\") # SentenceTransformer('clip-ViT-L-14')\n",
    "\n",
    "#Encode an image:\n",
    "# img_emb = model.encode(image) # Image.open('two_dogs_in_snow.jpg')\n",
    "\n",
    "# #Encode text descriptions\n",
    "# text_emb = model.encode(text) # ['Two dogs in the snow', 'A cat on a table', 'A picture of London at night']\n",
    "\n",
    "# #Compute cosine similarities \n",
    "# cos_scores = util.cos_sim(img_emb, text_emb)\n",
    "# print(cos_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_embeddings = []\n",
    "for image in tqdm(data_set['train']['image'][:2]):\n",
    "    img_embedding = model.encode(image)\n",
    "    img_embeddings.append(img_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# try FAISS. Chroma, Pinecone (check the GAFS project)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyarrow as pa\n",
    "import lancedb\n",
    "\n",
    "db = lancedb.connect('./data/tables')\n",
    "schema = pa.schema(\n",
    "  [\n",
    "      pa.field(\"vector\", pa.list_(pa.float32())),\n",
    "      # pa.field(\"text\", pa.string()),\n",
    "      # pa.field(\"id\", pa.int32())\n",
    "  ])\n",
    "# tbl = db.create_table(\"gta_data\", schema=schema, mode=\"overwrite\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:15<00:00,  7.65s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "img_embeddings = []\n",
    "for image in tqdm(data_set['train']['image'][:2]):\n",
    "    img_embedding = model.encode(image)\n",
    "    img_embeddings.append(img_embedding)\n",
    "\n",
    "tbl_data = pa.Table.from_arrays([pa.array(img_embeddings)], [\"vector\"])\n",
    "tbl = db.create_table(\"gta_data\", tbl_data, schema=schema, mode=\"overwrite\")\n",
    "\n",
    "# tbl.add(img_embeddings)\n",
    "# tbl.to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Query column vector must be a vector. Got list<item: float>.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[63], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msearch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43ma road with a stop\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector_column_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvector\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlimit\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_pandas\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      2\u001b[0m res\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lancedb\\query.py:262\u001b[0m, in \u001b[0;36mLanceQueryBuilder.to_pandas\u001b[1;34m(self, flatten)\u001b[0m\n\u001b[0;32m    247\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mto_pandas\u001b[39m(\u001b[38;5;28mself\u001b[39m, flatten: Optional[Union[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;28mbool\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpd.DataFrame\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m    248\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    249\u001b[0m \u001b[38;5;124;03m    Execute the query and return the results as a pandas DataFrame.\u001b[39;00m\n\u001b[0;32m    250\u001b[0m \u001b[38;5;124;03m    In addition to the selected columns, LanceDB also returns a vector\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    260\u001b[0m \u001b[38;5;124;03m        If unspecified, do not flatten the nested columns.\u001b[39;00m\n\u001b[0;32m    261\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m--> 262\u001b[0m     tbl \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_arrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    263\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m flatten \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m    264\u001b[0m         \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lancedb\\query.py:527\u001b[0m, in \u001b[0;36mLanceVectorQueryBuilder.to_arrow\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    518\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mto_arrow\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m pa\u001b[38;5;241m.\u001b[39mTable:\n\u001b[0;32m    519\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    520\u001b[0m \u001b[38;5;124;03m    Execute the query and return the results as an\u001b[39;00m\n\u001b[0;32m    521\u001b[0m \u001b[38;5;124;03m    [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    525\u001b[0m \u001b[38;5;124;03m    vector and the returned vectors.\u001b[39;00m\n\u001b[0;32m    526\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m--> 527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_batches\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mread_all()\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lancedb\\query.py:557\u001b[0m, in \u001b[0;36mLanceVectorQueryBuilder.to_batches\u001b[1;34m(self, batch_size)\u001b[0m\n\u001b[0;32m    544\u001b[0m     vector \u001b[38;5;241m=\u001b[39m [v\u001b[38;5;241m.\u001b[39mtolist() \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m vector]\n\u001b[0;32m    545\u001b[0m query \u001b[38;5;241m=\u001b[39m Query(\n\u001b[0;32m    546\u001b[0m     vector\u001b[38;5;241m=\u001b[39mvector,\n\u001b[0;32m    547\u001b[0m     \u001b[38;5;28mfilter\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_where,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    555\u001b[0m     with_row_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_with_row_id,\n\u001b[0;32m    556\u001b[0m )\n\u001b[1;32m--> 557\u001b[0m result_set \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_table\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execute_query\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    558\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reranker \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    559\u001b[0m     rs_table \u001b[38;5;241m=\u001b[39m result_set\u001b[38;5;241m.\u001b[39mread_all()\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lancedb\\table.py:1616\u001b[0m, in \u001b[0;36mLanceTable._execute_query\u001b[1;34m(self, query, batch_size)\u001b[0m\n\u001b[0;32m   1612\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_execute_query\u001b[39m(\n\u001b[0;32m   1613\u001b[0m     \u001b[38;5;28mself\u001b[39m, query: Query, batch_size: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1614\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m pa\u001b[38;5;241m.\u001b[39mRecordBatchReader:\n\u001b[0;32m   1615\u001b[0m     ds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mto_lance()\n\u001b[1;32m-> 1616\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscanner\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1617\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1618\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mfilter\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1619\u001b[0m \u001b[43m        \u001b[49m\u001b[43mprefilter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprefilter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1620\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnearest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\n\u001b[0;32m   1621\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcolumn\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvector_column\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1622\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mq\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1623\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mk\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1624\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmetric\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1625\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnprobes\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnprobes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1626\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrefine_factor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrefine_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1627\u001b[0m \u001b[43m        \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1628\u001b[0m \u001b[43m        \u001b[49m\u001b[43mwith_row_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_row_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1629\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1630\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto_reader()\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lance\\dataset.py:321\u001b[0m, in \u001b[0;36mLanceDataset.scanner\u001b[1;34m(self, columns, filter, limit, offset, nearest, batch_size, batch_readahead, fragment_readahead, scan_in_order, fragments, prefilter, with_row_id, use_stats)\u001b[0m\n\u001b[0;32m    305\u001b[0m builder \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m    306\u001b[0m     ScannerBuilder(\u001b[38;5;28mself\u001b[39m)\n\u001b[0;32m    307\u001b[0m     \u001b[38;5;241m.\u001b[39mcolumns(columns)\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    318\u001b[0m     \u001b[38;5;241m.\u001b[39muse_stats(use_stats)\n\u001b[0;32m    319\u001b[0m )\n\u001b[0;32m    320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m nearest \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 321\u001b[0m     builder \u001b[38;5;241m=\u001b[39m \u001b[43mbuilder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnearest\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnearest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder\u001b[38;5;241m.\u001b[39mto_scanner()\n",
      "File \u001b[1;32mc:\\Users\\Admin\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\grandtheftauto-multimodal-rag-application-ufxwo2j--py3.11\\Lib\\site-packages\\lance\\dataset.py:2049\u001b[0m, in \u001b[0;36mScannerBuilder.nearest\u001b[1;34m(self, column, q, k, metric, nprobes, refine_factor, use_index)\u001b[0m\n\u001b[0;32m   2047\u001b[0m     column_type \u001b[38;5;241m=\u001b[39m column_type\u001b[38;5;241m.\u001b[39mstorage_type\n\u001b[0;32m   2048\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m pa\u001b[38;5;241m.\u001b[39mtypes\u001b[38;5;241m.\u001b[39mis_fixed_size_list(column_type):\n\u001b[1;32m-> 2049\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[0;32m   2050\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuery column \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcolumn\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be a vector. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcolumn_field\u001b[38;5;241m.\u001b[39mtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   2051\u001b[0m     )\n\u001b[0;32m   2052\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(q) \u001b[38;5;241m!=\u001b[39m column_type\u001b[38;5;241m.\u001b[39mlist_size:\n\u001b[0;32m   2053\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m   2054\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuery vector size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(q)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not match index column size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   2055\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcolumn_type\u001b[38;5;241m.\u001b[39mlist_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   2056\u001b[0m     )\n",
      "\u001b[1;31mTypeError\u001b[0m: Query column vector must be a vector. Got list<item: float>."
     ]
    }
   ],
   "source": [
    "res = tbl.search(model.encode(\"a road with a stop\"), vector_column_name=\"vector\").limit(3).to_pandas()\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://huggingface.co/openai/clip-vit-large-patch14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip\n",
    "import torch\n",
    "import os\n",
    "from datasets import load_dataset\n",
    "\n",
    "# ds = load_dataset(\"vipulmaheshwari/GTA-Image-Captioning-Dataset\")\n",
    "# device = torch.device(\"mps\")\n",
    "model, preprocess = clip.load(\"ViT-L/14\") # , device=device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def embed_txt(txt):\n",
    "    tokenized_text = clip.tokenize([txt])\n",
    "    embeddings = model.encode_text(tokenized_text)\n",
    "    \n",
    "    # Detach, move to CPU, convert to numpy array, and extract the first element as a list\n",
    "    result = embeddings.detach().numpy()[0].tolist()\n",
    "    return result\n",
    "\n",
    "len(embed_txt(\"a road with a stop\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1.172108769416809,\n",
       " 0.5741956830024719,\n",
       " -0.11420677602291107,\n",
       " -0.5107784271240234,\n",
       " -0.7742195725440979,\n",
       " 0.7895426750183105,\n",
       " 0.31811264157295227,\n",
       " 0.5389135479927063,\n",
       " 0.17074763774871826,\n",
       " -1.0352754592895508,\n",
       " -0.013449656777083874,\n",
       " -0.5795634388923645,\n",
       " -0.37020763754844666,\n",
       " -0.7534741163253784,\n",
       " 0.6788989901542664,\n",
       " -0.1245330423116684,\n",
       " 1.0375893115997314,\n",
       " -0.08196641504764557,\n",
       " 0.169560506939888,\n",
       " -0.3306411802768707,\n",
       " 0.6850194931030273,\n",
       " -0.4113234281539917,\n",
       " -0.3725243806838989,\n",
       " -0.8902166485786438,\n",
       " -0.2419223040342331,\n",
       " 0.33643779158592224,\n",
       " 0.18724264204502106,\n",
       " 0.6745221018791199,\n",
       " 0.00899740681052208,\n",
       " -0.29769381880760193,\n",
       " 0.6830898523330688,\n",
       " 0.7002785205841064,\n",
       " 0.5598942041397095,\n",
       " -0.27884775400161743,\n",
       " 0.29804039001464844,\n",
       " 0.4663200378417969,\n",
       " -0.40516427159309387,\n",
       " -0.2796509861946106,\n",
       " -0.3568377196788788,\n",
       " 0.7982958555221558,\n",
       " 1.0218019485473633,\n",
       " -0.3191905915737152,\n",
       " -0.8690600395202637,\n",
       " -0.5986450910568237,\n",
       " 0.6520456671714783,\n",
       " 0.8482719659805298,\n",
       " 0.45436325669288635,\n",
       " -0.24868743121623993,\n",
       " -0.22428922355175018,\n",
       " -0.3995105028152466,\n",
       " 0.1387435346841812,\n",
       " 0.030430370941758156,\n",
       " 0.1954972743988037,\n",
       " 0.36345618963241577,\n",
       " 0.23408269882202148,\n",
       " 0.030055442824959755,\n",
       " -0.13948054611682892,\n",
       " -0.6816356778144836,\n",
       " -0.2554306387901306,\n",
       " -0.8186500668525696,\n",
       " 0.0802079439163208,\n",
       " -0.28623825311660767,\n",
       " 0.889072060585022,\n",
       " 0.3205733895301819,\n",
       " 1.4578713178634644,\n",
       " 0.5289382934570312,\n",
       " -0.9107804894447327,\n",
       " -0.1899547427892685,\n",
       " -0.39814451336860657,\n",
       " 0.07741428166627884,\n",
       " 0.00696764700114727,\n",
       " 0.8374080657958984,\n",
       " 0.17547933757305145,\n",
       " -0.6835469007492065,\n",
       " 0.44190704822540283,\n",
       " -0.258558452129364,\n",
       " -0.16306370496749878,\n",
       " 0.17053553462028503,\n",
       " 0.8770076036453247,\n",
       " 0.2896091341972351,\n",
       " -0.2233574390411377,\n",
       " -0.30297425389289856,\n",
       " -0.7410178780555725,\n",
       " 0.010058385320007801,\n",
       " -0.7731197476387024,\n",
       " -0.2569619417190552,\n",
       " 0.05559535324573517,\n",
       " 0.6135262846946716,\n",
       " -0.5267459154129028,\n",
       " -0.14416567981243134,\n",
       " 0.3300650715827942,\n",
       " 0.3322101831436157,\n",
       " 0.260479211807251,\n",
       " -0.6002621054649353,\n",
       " 0.033296529203653336,\n",
       " 0.5030784010887146,\n",
       " -0.5291236042976379,\n",
       " 0.11839054524898529,\n",
       " -0.2279912680387497,\n",
       " -0.24884033203125,\n",
       " -0.27888786792755127,\n",
       " -0.1304142028093338,\n",
       " 0.1286783516407013,\n",
       " 0.15377336740493774,\n",
       " 0.5802848935127258,\n",
       " -0.3416184186935425,\n",
       " -0.41235557198524475,\n",
       " 0.04911366105079651,\n",
       " 0.28588297963142395,\n",
       " 1.097459316253662,\n",
       " 0.8836804628372192,\n",
       " -0.06680312007665634,\n",
       " 0.5119672417640686,\n",
       " 0.1433386206626892,\n",
       " 0.3975537121295929,\n",
       " 0.751021683216095,\n",
       " -0.5127158761024475,\n",
       " -1.0673898458480835,\n",
       " -0.810725212097168,\n",
       " -0.9325631260871887,\n",
       " 0.28165996074676514,\n",
       " -1.1700552701950073,\n",
       " -0.6979520916938782,\n",
       " 0.09645866602659225,\n",
       " -0.15432433784008026,\n",
       " -0.6545705199241638,\n",
       " -0.2297753095626831,\n",
       " 0.9147917628288269,\n",
       " -0.3901214897632599,\n",
       " -0.08340626955032349,\n",
       " -0.0342048779129982,\n",
       " 0.4271363615989685,\n",
       " 0.3410806655883789,\n",
       " -0.14932666718959808,\n",
       " 0.05415431410074234,\n",
       " -0.5995809435844421,\n",
       " -0.33829835057258606,\n",
       " -0.23623280227184296,\n",
       " -0.5740441679954529,\n",
       " 0.3325800895690918,\n",
       " -0.18519632518291473,\n",
       " -0.26904159784317017,\n",
       " 0.03128799423575401,\n",
       " 0.15838740766048431,\n",
       " -0.003409828059375286,\n",
       " -0.2664038836956024,\n",
       " -0.6785658597946167,\n",
       " 0.4431314170360565,\n",
       " -0.38189026713371277,\n",
       " 0.5427551865577698,\n",
       " 0.5074883103370667,\n",
       " -0.186558797955513,\n",
       " 0.08342668414115906,\n",
       " 0.04791847988963127,\n",
       " -0.1341174989938736,\n",
       " 0.8764032125473022,\n",
       " -0.10158982127904892,\n",
       " 0.9622796177864075,\n",
       " -0.058163080364465714,\n",
       " -1.0029855966567993,\n",
       " -0.22422465682029724,\n",
       " 1.2381765842437744,\n",
       " 0.17981192469596863,\n",
       " 0.034056372940540314,\n",
       " -0.2695963978767395,\n",
       " -0.21056877076625824,\n",
       " -0.3712306320667267,\n",
       " 0.17336499691009521,\n",
       " 0.5278773903846741,\n",
       " 0.7908108234405518,\n",
       " -1.034334659576416,\n",
       " -0.5650461912155151,\n",
       " -0.7466263175010681,\n",
       " -0.16805803775787354,\n",
       " 0.39045724272727966,\n",
       " -0.5074604749679565,\n",
       " 0.29658886790275574,\n",
       " -0.1186276450753212,\n",
       " 0.7888982892036438,\n",
       " -0.00017159162962343544,\n",
       " 0.9989897608757019,\n",
       " 0.21528062224388123,\n",
       " 0.3544112741947174,\n",
       " -0.18352235853672028,\n",
       " -0.5933219790458679,\n",
       " -0.4221193492412567,\n",
       " 0.20716431736946106,\n",
       " 0.026883812621235847,\n",
       " 1.2931787967681885,\n",
       " 0.3020362854003906,\n",
       " 0.26052647829055786,\n",
       " 0.056001197546720505,\n",
       " -0.5442985892295837,\n",
       " -0.24692402780056,\n",
       " -0.04342973232269287,\n",
       " 0.32930392026901245,\n",
       " -0.7617244124412537,\n",
       " 0.26960083842277527,\n",
       " 0.29244083166122437,\n",
       " -0.2099844217300415,\n",
       " 0.2785693407058716,\n",
       " 0.07669660449028015,\n",
       " -0.1421067714691162,\n",
       " 0.46162599325180054,\n",
       " 0.3855959475040436,\n",
       " 0.27650055289268494,\n",
       " -0.44994688034057617,\n",
       " -0.28603509068489075,\n",
       " -0.5041812062263489,\n",
       " -0.3805933892726898,\n",
       " 0.5895918011665344,\n",
       " 0.6383715867996216,\n",
       " -0.08397688716650009,\n",
       " 0.22880668938159943,\n",
       " -0.25133225321769714,\n",
       " 0.2853071093559265,\n",
       " -0.0931459441781044,\n",
       " 0.3020959496498108,\n",
       " 0.24055352807044983,\n",
       " 0.18953140079975128,\n",
       " -0.17559008300304413,\n",
       " 0.11638100445270538,\n",
       " 0.5736441612243652,\n",
       " 0.34651291370391846,\n",
       " 0.0011261674808338284,\n",
       " 0.6858928203582764,\n",
       " -0.3585776090621948,\n",
       " 0.21113723516464233,\n",
       " -0.451948344707489,\n",
       " -0.6812528371810913,\n",
       " -0.37171897292137146,\n",
       " -0.11487153172492981,\n",
       " -0.7819438576698303,\n",
       " 0.2523130476474762,\n",
       " -0.006692436058074236,\n",
       " 0.5665392279624939,\n",
       " -0.5619456768035889,\n",
       " 0.06306441873311996,\n",
       " 0.21295419335365295,\n",
       " 0.5865535140037537,\n",
       " 0.27423301339149475,\n",
       " 0.2840102016925812,\n",
       " -0.37136274576187134,\n",
       " 0.016866570338606834,\n",
       " 0.2263607531785965,\n",
       " 0.43608683347702026,\n",
       " -0.4567808508872986,\n",
       " 0.9201197028160095,\n",
       " -0.28868433833122253,\n",
       " 0.2835354208946228,\n",
       " 0.5691022276878357,\n",
       " -0.24377702176570892,\n",
       " 0.5043097138404846,\n",
       " -0.41853949427604675,\n",
       " 0.03636287525296211,\n",
       " -0.07350795716047287,\n",
       " -0.06902104616165161,\n",
       " 0.32698169350624084,\n",
       " -0.24132660031318665,\n",
       " 0.0912783071398735,\n",
       " -1.047544002532959,\n",
       " -0.8717364072799683,\n",
       " -0.8879557847976685,\n",
       " 0.301925927400589,\n",
       " -1.2747677564620972,\n",
       " 0.10643213242292404,\n",
       " 0.050040390342473984,\n",
       " -0.6990651488304138,\n",
       " 0.4598444104194641,\n",
       " -0.2630557417869568,\n",
       " 0.3260715901851654,\n",
       " 0.15428033471107483,\n",
       " 0.10122397541999817,\n",
       " 0.07699556648731232,\n",
       " 0.06605273485183716,\n",
       " -0.2160506695508957,\n",
       " -0.1665394902229309,\n",
       " -0.5145867466926575,\n",
       " -0.8410879373550415,\n",
       " -0.3635564148426056,\n",
       " -0.14213085174560547,\n",
       " -0.3718281686306,\n",
       " -0.2025422751903534,\n",
       " -0.45895904302597046,\n",
       " 0.16690057516098022,\n",
       " -0.29905644059181213,\n",
       " 0.03865504637360573,\n",
       " 0.23067855834960938,\n",
       " 0.23403894901275635,\n",
       " -0.3748420774936676,\n",
       " -0.4377340078353882,\n",
       " -0.6237973570823669,\n",
       " -0.5650405287742615,\n",
       " -0.12215842306613922,\n",
       " -0.23550915718078613,\n",
       " -0.030611969530582428,\n",
       " 0.1457085907459259,\n",
       " 0.39134201407432556,\n",
       " 0.7538257241249084,\n",
       " -0.5013869404792786,\n",
       " -0.22639918327331543,\n",
       " 0.324470579624176,\n",
       " 0.2524488568305969,\n",
       " -0.6817197799682617,\n",
       " -0.1683609038591385,\n",
       " 0.09771472215652466,\n",
       " -0.324865460395813,\n",
       " 0.38337022066116333,\n",
       " -0.148436039686203,\n",
       " 0.7256155610084534,\n",
       " -0.9280087947845459,\n",
       " -0.6846877336502075,\n",
       " -0.37772396206855774,\n",
       " 0.03854738548398018,\n",
       " -0.5223367214202881,\n",
       " 0.04659451171755791,\n",
       " -1.2525877952575684,\n",
       " 0.15308304131031036,\n",
       " -0.2739616334438324,\n",
       " 0.07301849126815796,\n",
       " 0.7795864939689636,\n",
       " -0.2228480577468872,\n",
       " -0.35411256551742554,\n",
       " -0.6261951923370361,\n",
       " 0.20154286921024323,\n",
       " -0.02966398000717163,\n",
       " -0.7075097560882568,\n",
       " -0.45100030303001404,\n",
       " -0.5318045020103455,\n",
       " 0.22182771563529968,\n",
       " 0.08000355958938599,\n",
       " 0.16378679871559143,\n",
       " 0.33453676104545593,\n",
       " -0.20498014986515045,\n",
       " -0.5192173719406128,\n",
       " 0.3957352936267853,\n",
       " -0.21540209650993347,\n",
       " -0.26865679025650024,\n",
       " -0.9579092264175415,\n",
       " 0.29295825958251953,\n",
       " 0.07182762026786804,\n",
       " 0.2812371850013733,\n",
       " 0.5159787535667419,\n",
       " -0.1598782241344452,\n",
       " -0.02911016158759594,\n",
       " 0.10978005081415176,\n",
       " -1.152063012123108,\n",
       " -1.075944423675537,\n",
       " -0.19859834015369415,\n",
       " 0.48424282670021057,\n",
       " -0.3020830452442169,\n",
       " 0.0681198462843895,\n",
       " -0.03712642937898636,\n",
       " -0.26295045018196106,\n",
       " 0.23075002431869507,\n",
       " 0.03392830863595009,\n",
       " 0.5592344999313354,\n",
       " 0.27158620953559875,\n",
       " 0.08701741695404053,\n",
       " -0.2469501793384552,\n",
       " 0.7389507293701172,\n",
       " 0.3184473216533661,\n",
       " -0.5283591151237488,\n",
       " -0.35726648569107056,\n",
       " 0.2647046446800232,\n",
       " 0.06684468686580658,\n",
       " -0.4558630883693695,\n",
       " -0.3814390301704407,\n",
       " 0.6464404463768005,\n",
       " -0.3603093922138214,\n",
       " -0.7406730651855469,\n",
       " -0.06739675253629684,\n",
       " 0.3286390006542206,\n",
       " 0.07030770927667618,\n",
       " 0.20259763300418854,\n",
       " -0.18537510931491852,\n",
       " 0.39111021161079407,\n",
       " -0.1252942532300949,\n",
       " 0.1268956959247589,\n",
       " -0.10496045649051666,\n",
       " 1.1690759658813477,\n",
       " 0.23655962944030762,\n",
       " 0.2556387782096863,\n",
       " -0.30134761333465576,\n",
       " -0.3626421391963959,\n",
       " -0.35505855083465576,\n",
       " -0.22458982467651367,\n",
       " -0.40729954838752747,\n",
       " -0.40974897146224976,\n",
       " 0.028972748667001724,\n",
       " 0.6284871101379395,\n",
       " 0.3097871243953705,\n",
       " -0.1652112752199173,\n",
       " 1.0627437829971313,\n",
       " -0.6887637376785278,\n",
       " -0.031500522047281265,\n",
       " -0.0873744785785675,\n",
       " -0.9616701006889343,\n",
       " 0.3587159216403961,\n",
       " 0.1391131579875946,\n",
       " -0.19815994799137115,\n",
       " 0.7807681560516357,\n",
       " 0.2649019658565521,\n",
       " -0.48934823274612427,\n",
       " -0.7037213444709778,\n",
       " -0.39783185720443726,\n",
       " -0.36193808913230896,\n",
       " -0.6811600923538208,\n",
       " -0.18488575518131256,\n",
       " 0.6047443151473999,\n",
       " -0.17012985050678253,\n",
       " -0.11221067607402802,\n",
       " -0.11349140107631683,\n",
       " -7.79653263092041,\n",
       " -0.03174687176942825,\n",
       " -0.5907049179077148,\n",
       " -0.0845143049955368,\n",
       " 0.6719594597816467,\n",
       " -0.6047013998031616,\n",
       " -0.4621417820453644,\n",
       " 0.4189649224281311,\n",
       " 0.2606521546840668,\n",
       " -0.5251185894012451,\n",
       " 0.656951904296875,\n",
       " -0.14103704690933228,\n",
       " -0.724404513835907,\n",
       " 0.032266344875097275,\n",
       " -0.38332653045654297,\n",
       " 0.2214561551809311,\n",
       " -0.11025898903608322,\n",
       " 0.2219904512166977,\n",
       " -0.16805943846702576,\n",
       " -0.22911910712718964,\n",
       " 0.40065279603004456,\n",
       " 0.8264251947402954,\n",
       " -0.25879043340682983,\n",
       " -0.4252917170524597,\n",
       " -0.1860014647245407,\n",
       " 0.21712413430213928,\n",
       " 0.852258026599884,\n",
       " 1.1114447116851807,\n",
       " 0.03458324819803238,\n",
       " -0.42567503452301025,\n",
       " -0.4035224914550781,\n",
       " 0.5391470789909363,\n",
       " 0.6653061509132385,\n",
       " -0.15112830698490143,\n",
       " 0.20673374831676483,\n",
       " 0.5916152596473694,\n",
       " 0.10783706605434418,\n",
       " 0.06303859502077103,\n",
       " -0.6804474592208862,\n",
       " 0.46267828345298767,\n",
       " -0.8944555521011353,\n",
       " -0.20007365942001343,\n",
       " -0.18524183332920074,\n",
       " -0.25279444456100464,\n",
       " 0.013942774385213852,\n",
       " -0.227418452501297,\n",
       " -0.5019238591194153,\n",
       " -0.259070485830307,\n",
       " -0.4195726811885834,\n",
       " -0.2565968334674835,\n",
       " 0.08592142164707184,\n",
       " -0.4816386103630066,\n",
       " -0.7389425039291382,\n",
       " 0.384757936000824,\n",
       " 1.148498773574829,\n",
       " -0.08795226365327835,\n",
       " -0.7781391143798828,\n",
       " -0.18237966299057007,\n",
       " 0.27100449800491333,\n",
       " 0.7376315593719482,\n",
       " -0.2066810131072998,\n",
       " -0.042161568999290466,\n",
       " 0.14717990159988403,\n",
       " -0.25498059391975403,\n",
       " 0.33164745569229126,\n",
       " -0.3789907693862915,\n",
       " -0.702992856502533,\n",
       " -0.46402469277381897,\n",
       " -0.47181829810142517,\n",
       " -0.530529260635376,\n",
       " 0.08136516064405441,\n",
       " 0.3396340608596802,\n",
       " -0.21239398419857025,\n",
       " 0.38136026263237,\n",
       " -0.9020550847053528,\n",
       " -0.41401106119155884,\n",
       " -0.47626185417175293,\n",
       " -0.34683799743652344,\n",
       " -0.3377147912979126,\n",
       " -0.6628923416137695,\n",
       " 0.2143520712852478,\n",
       " 0.31117284297943115,\n",
       " 0.43092554807662964,\n",
       " 0.12191533297300339,\n",
       " -0.017828848212957382,\n",
       " -0.12583602964878082,\n",
       " 0.33957740664482117,\n",
       " -0.09169825166463852,\n",
       " 0.24532632529735565,\n",
       " 0.5283830165863037,\n",
       " 0.7038718461990356,\n",
       " 0.6268500089645386,\n",
       " 0.00923143420368433,\n",
       " 0.8284425139427185,\n",
       " 0.6025779247283936,\n",
       " 0.5495515465736389,\n",
       " -0.34349843859672546,\n",
       " 0.3288527727127075,\n",
       " 0.1823807954788208,\n",
       " 0.2601393759250641,\n",
       " -0.01894410327076912,\n",
       " 0.535849928855896,\n",
       " -0.07729293406009674,\n",
       " -0.05701117962598801,\n",
       " -0.5398024320602417,\n",
       " -0.2532539665699005,\n",
       " -0.02206384763121605,\n",
       " -0.5667169690132141,\n",
       " -0.1217791885137558,\n",
       " 0.37247171998023987,\n",
       " -0.11095214635133743,\n",
       " -0.615912914276123,\n",
       " 0.32324957847595215,\n",
       " 0.45441827178001404,\n",
       " 0.23056231439113617,\n",
       " -2.3405637741088867,\n",
       " -0.3898467421531677,\n",
       " -0.03767596557736397,\n",
       " -0.17562665045261383,\n",
       " 0.40651726722717285,\n",
       " -0.45753777027130127,\n",
       " 1.0350662469863892,\n",
       " -0.45301544666290283,\n",
       " 0.5571080446243286,\n",
       " -0.7762919068336487,\n",
       " -0.2582171857357025,\n",
       " -0.8123776316642761,\n",
       " 0.027839435264468193,\n",
       " 0.021091900765895844,\n",
       " -0.3034447133541107,\n",
       " 0.34992972016334534,\n",
       " -0.6623353958129883,\n",
       " -0.2909213602542877,\n",
       " -0.18953290581703186,\n",
       " -0.5997650623321533,\n",
       " 0.8640273213386536,\n",
       " -0.24815954267978668,\n",
       " -0.29709047079086304,\n",
       " 0.8860780000686646,\n",
       " 0.04529644176363945,\n",
       " 1.1951236724853516,\n",
       " -1.1161422729492188,\n",
       " -0.04289549961686134,\n",
       " -1.6880977153778076,\n",
       " -0.16583313047885895,\n",
       " -0.4640212059020996,\n",
       " 0.03880169615149498,\n",
       " -0.4149312973022461,\n",
       " 0.5659136772155762,\n",
       " -0.07184366881847382,\n",
       " 0.6438769102096558,\n",
       " -1.1572128534317017,\n",
       " 0.32702523469924927,\n",
       " 0.19401556253433228,\n",
       " -0.36513882875442505,\n",
       " -0.1496993601322174,\n",
       " 0.5544662475585938,\n",
       " -0.10601028800010681,\n",
       " 0.2943094074726105,\n",
       " -0.9837754368782043,\n",
       " -0.14144904911518097,\n",
       " 0.7259737253189087,\n",
       " 0.05785682797431946,\n",
       " 0.8584915995597839,\n",
       " -0.27259302139282227,\n",
       " -0.6073381900787354,\n",
       " -0.22768571972846985,\n",
       " 0.7255773544311523,\n",
       " 0.1539279967546463,\n",
       " -0.6805699467658997,\n",
       " -1.0378549098968506,\n",
       " -0.597703754901886,\n",
       " -0.6462168097496033,\n",
       " 1.1171226501464844,\n",
       " -0.21000456809997559,\n",
       " -0.7443035244941711,\n",
       " -0.16614656150341034,\n",
       " 0.03670107200741768,\n",
       " 0.23261283338069916,\n",
       " -0.5053027272224426,\n",
       " -1.0062577724456787,\n",
       " 0.028607431799173355,\n",
       " 0.6196390986442566,\n",
       " 0.11939772218465805,\n",
       " 0.16041713953018188,\n",
       " 0.012548833154141903,\n",
       " -0.6940840482711792,\n",
       " -1.0390965938568115,\n",
       " 0.3209550082683563,\n",
       " -0.5268062353134155,\n",
       " 0.5799688696861267,\n",
       " -0.3353428542613983,\n",
       " -0.3517853319644928,\n",
       " -0.38189470767974854,\n",
       " 0.23297882080078125,\n",
       " 0.045969072729349136,\n",
       " 0.6408992409706116,\n",
       " -0.23498287796974182,\n",
       " -0.2744370400905609,\n",
       " -0.3386567234992981,\n",
       " 0.16898459196090698,\n",
       " 0.4274075925350189,\n",
       " -0.4734047055244446,\n",
       " -0.02491043135523796,\n",
       " -0.5023868680000305,\n",
       " -0.1599859893321991,\n",
       " -0.28793132305145264,\n",
       " 0.45987895131111145,\n",
       " 0.12111934274435043,\n",
       " 0.695939838886261,\n",
       " 0.18703705072402954,\n",
       " 0.11010603606700897,\n",
       " -0.0493675135076046,\n",
       " 0.2681659758090973,\n",
       " 0.6883248090744019,\n",
       " 0.14249111711978912,\n",
       " -0.3902900516986847,\n",
       " 0.02434423565864563,\n",
       " 0.8115938305854797,\n",
       " 0.31366243958473206,\n",
       " 0.1475793719291687,\n",
       " 0.8607581853866577,\n",
       " 1.106387972831726,\n",
       " -0.12984894216060638,\n",
       " 0.6475292444229126,\n",
       " 0.4389672875404358,\n",
       " -0.14565706253051758,\n",
       " -0.29327720403671265,\n",
       " 0.19903028011322021,\n",
       " 0.44643306732177734,\n",
       " -0.055179595947265625,\n",
       " 8.315621376037598,\n",
       " -0.08598960936069489,\n",
       " 0.7728097438812256,\n",
       " 0.1960563361644745,\n",
       " 0.7582479119300842,\n",
       " -0.6882674098014832,\n",
       " -0.22637659311294556,\n",
       " 0.5025527477264404,\n",
       " -0.07177169620990753,\n",
       " -0.03814778849482536,\n",
       " 1.0206265449523926,\n",
       " -0.4750046730041504,\n",
       " 0.015179314650595188,\n",
       " -0.6247814297676086,\n",
       " 0.4034382998943329,\n",
       " 1.700039029121399,\n",
       " -0.30730658769607544,\n",
       " 0.28762733936309814,\n",
       " 0.63616544008255,\n",
       " -0.23646242916584015,\n",
       " 0.2806755304336548,\n",
       " 0.4410918056964874,\n",
       " 0.14614292979240417,\n",
       " 0.4948270916938782,\n",
       " 0.43732860684394836,\n",
       " 1.0119167566299438,\n",
       " 0.9210423827171326,\n",
       " -0.35212814807891846,\n",
       " 0.32403385639190674,\n",
       " -0.44126105308532715,\n",
       " -0.18103229999542236,\n",
       " -0.31492364406585693,\n",
       " -0.503863513469696,\n",
       " -0.26293063163757324,\n",
       " 0.21797089278697968,\n",
       " -0.9694619178771973,\n",
       " 0.021304313093423843,\n",
       " 0.44222936034202576,\n",
       " -0.36141523718833923,\n",
       " -0.463960736989975,\n",
       " -0.24528658390045166,\n",
       " 0.11174631118774414,\n",
       " 0.09441330283880234,\n",
       " 0.18713852763175964,\n",
       " 0.36507827043533325,\n",
       " 0.7508949041366577,\n",
       " -0.15697608888149261,\n",
       " 0.4001035690307617,\n",
       " 1.323508620262146,\n",
       " -0.20196901261806488,\n",
       " 0.292355477809906,\n",
       " 0.34666717052459717,\n",
       " -0.11999291181564331,\n",
       " -0.6510916352272034,\n",
       " 0.4462094306945801,\n",
       " -0.45647361874580383,\n",
       " -0.14198175072669983,\n",
       " -0.4045391082763672,\n",
       " 0.7035051584243774,\n",
       " 0.3213372826576233,\n",
       " 0.5096818804740906,\n",
       " 0.6800979971885681,\n",
       " -0.008764655329287052,\n",
       " -0.19463925063610077,\n",
       " -0.7179383635520935,\n",
       " 0.2567158043384552,\n",
       " 0.07364790141582489,\n",
       " -0.222466841340065,\n",
       " 0.022669780999422073,\n",
       " 0.8473037481307983,\n",
       " -0.034888043999671936,\n",
       " -0.07169658690690994,\n",
       " -0.05516548082232475,\n",
       " -0.06913617253303528,\n",
       " -0.530577540397644,\n",
       " -0.6640213131904602,\n",
       " -0.34023773670196533,\n",
       " -0.5658687949180603,\n",
       " -0.4476564824581146,\n",
       " -2.571279287338257,\n",
       " -0.12790530920028687,\n",
       " 0.9560791850090027,\n",
       " -0.6428014039993286,\n",
       " -0.4189566671848297,\n",
       " -0.20985344052314758,\n",
       " 0.47335946559906006,\n",
       " -0.11219882220029831,\n",
       " -0.10753587633371353,\n",
       " 0.14247222244739532,\n",
       " 1.059354305267334,\n",
       " 0.3302377462387085,\n",
       " -0.3935352563858032,\n",
       " -0.058758582919836044,\n",
       " 0.648691713809967,\n",
       " 0.30499130487442017,\n",
       " -0.27360308170318604,\n",
       " -0.25764214992523193,\n",
       " 0.015458552166819572,\n",
       " 0.6662879586219788,\n",
       " 0.3119010329246521,\n",
       " -0.15479373931884766,\n",
       " 0.028574924916028976,\n",
       " -0.1503346860408783,\n",
       " 0.06127818673849106,\n",
       " -0.0910576581954956,\n",
       " 0.0481022410094738,\n",
       " 0.9771047234535217,\n",
       " 0.7927762866020203,\n",
       " 0.023048892617225647,\n",
       " 0.30974704027175903,\n",
       " 0.33901262283325195,\n",
       " -0.07123278081417084,\n",
       " 0.34432730078697205,\n",
       " -0.12369780987501144,\n",
       " 0.2354590892791748,\n",
       " 0.38229313492774963,\n",
       " -0.8465576767921448,\n",
       " -0.2445705085992813,\n",
       " -0.16847288608551025,\n",
       " 0.5078030824661255,\n",
       " -0.4897501766681671,\n",
       " 0.07203903794288635,\n",
       " 0.6503809690475464,\n",
       " -0.08006825298070908]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# https://vipul-maheshwari.github.io/2024/03/03/multimodal-rag-application\n",
    "\n",
    "def embed_image(img):\n",
    "    processed_image = preprocess(img)\n",
    "    unsqueezed_image = processed_image.unsqueeze(0)\n",
    "    embeddings = model.encode_image(unsqueezed_image)\n",
    "    \n",
    "    # Detach, move to CPU, convert to numpy array, and extract the first element as a list\n",
    "    result = embeddings.detach().numpy()[0].tolist()\n",
    "    return result\n",
    "\n",
    "len(embed_image(image))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_txt(txt):\n",
    "    tokenized_text = clip.tokenize([txt]).to(device)\n",
    "    embeddings = model.encode_text(tokenized_text)\n",
    "    \n",
    "    # Detach, move to CPU, convert to numpy array, and extract the first element as a list\n",
    "    result = embeddings.detach().cpu().numpy()[0].tolist()\n",
    "    return result\n",
    "\n",
    "res = tbl.search(embed_txt(\"a road with a stop\")).limit(3).to_pandas()\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "https://blog.lancedb.com/lancedb-polars-2d5eb32a8aa3/\n",
    "\n",
    "https://github.com/lancedb/lancedb"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}