{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QS0v2bceN4Or"
      },
      "source": [
        "Builds a database of vector embeddings from list of abstracts"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l5RwcIG8OAjX"
      },
      "source": [
        "## Some Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sfwT5YW2JCnu"
      },
      "outputs": [],
      "source": [
        "!pip install transformers==4.28.0\n",
        "!pip install -U sentence-transformers\n",
        "!pip install datasets\n",
        "!pip install langchain\n",
        "!pip install torch\n",
        "!pip install faiss-cpu"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "psoTvOp4VkBE"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import shutil\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from tqdm.auto import tqdm\n",
        "import torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "arZiN8QRHS_a"
      },
      "outputs": [],
      "source": [
        "import locale\n",
        "locale.getpreferredencoding = lambda: \"UTF-8\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JwWs0-Uu6ohg"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, BertForSequenceClassification\n",
        "\n",
        "m_tokenizer = AutoTokenizer.from_pretrained(\"biodatlab/MIReAD-Neuro-Large\")\n",
        "m_model = BertForSequenceClassification.from_pretrained(\"biodatlab/MIReAD-Neuro-Large\")\n",
        "miread_bundle = (m_tokenizer,m_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BR-adEUUz9su"
      },
      "outputs": [],
      "source": [
        "def create_lbert_embed(sents,bundle):\n",
        "  tokenizer = bundle[0]\n",
        "  model = bundle[1]\n",
        "  model.cuda()\n",
        "  tokens = tokenizer(sents,padding=True,truncation=True,return_tensors='pt')\n",
        "  device = torch.device('cuda')\n",
        "  tokens = tokens.to(device)\n",
        "  with torch.no_grad():\n",
        "    embeds = model(**tokens, output_hidden_states=True,return_dict=True).pooler_output\n",
        "  return embeds.cpu()\n",
        "\n",
        "def create_miread_embed(sents,bundle):\n",
        "  tokenizer = bundle[0]\n",
        "  model = bundle[1]\n",
        "  model.cuda()\n",
        "  tokens = tokenizer(sents,\n",
        "                   max_length=512,\n",
        "                   padding=True,\n",
        "                   truncation=True,\n",
        "                   return_tensors=\"pt\"\n",
        "                  )\n",
        "  device = torch.device('cuda')\n",
        "  tokens = tokens.to(device)\n",
        "  with torch.no_grad():\n",
        "    out = model.bert(**tokens)\n",
        "    feature = out.last_hidden_state[:, 0, :]\n",
        "  return feature.cpu()"
      ]
    },
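    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check (a minimal sketch; the sample abstract below is made up): each abstract should map to a single vector with the model's hidden size."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Embed one made-up abstract; expect a (1, hidden_size) tensor\n",
        "sample = [\"We study synaptic plasticity in cortical circuits using two-photon imaging.\"]\n",
        "print(create_miread_embed(sample, miread_bundle).shape)"
      ]
    },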
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-wHpHmD3zNSR"
      },
      "outputs": [],
      "source": [
        "from langchain.vectorstores import FAISS\n",
        "from langchain.embeddings import HuggingFaceEmbeddings\n",
        "\n",
        "model_name = \"biodatlab/MIReAD-Neuro-Large\"\n",
        "model_kwargs = {'device': 'cuda'}\n",
        "encode_kwargs = {'normalize_embeddings': False}\n",
        "faiss_embedder = HuggingFaceEmbeddings(\n",
        "    model_name=model_name,\n",
        "    model_kwargs=model_kwargs,\n",
        "    encode_kwargs=encode_kwargs\n",
        ")\n",
        "\n",
        "def add_to_db(data,create_embed,bundle,name=''):\n",
        "  batch_size = 128\n",
        "  \"\"\"\n",
        "  data : list of rows with an 'abstract' and an 'identifier' field\n",
        "  index : pinecone Index object\n",
        "  create_embed : function that creates the embedding given an abstract\n",
        "  \"\"\"\n",
        "  res = []\n",
        "  vecdb = None\n",
        "  for i in tqdm(range(0, len(data), batch_size)):\n",
        "      # find end of batch\n",
        "      i_end = min(i+batch_size, len(data))\n",
        "      # create IDs batch\n",
        "      ids = [name + '-' + str(x) for x in range(i, i_end)]\n",
        "      # create metadata batch\n",
        "      metadatas = [{\n",
        "                    'journal':row.get('journal','None'),\n",
        "                    'title':row['title'],\n",
        "                    'abstract': row['abstract'],\n",
        "                    'authors':row.get('authors','None'),\n",
        "                    'link':row.get('link','None'),\n",
        "                    'date':row.get('date','None'),\n",
        "                    'submitter':row.get('submitter','None'),\n",
        "                    } for row in data[i:i_end]]\n",
        "      # create embeddings\n",
        "      em = [create_embed(row['abstract'],bundle).tolist()[0] for row in data[i:i_end]]\n",
        "      texts = [row['abstract'] for row in data[i:i_end]]\n",
        "      records = list(zip(texts, em))\n",
        "      if vecdb:\n",
        "        vecdb_batch = FAISS.from_embeddings(records,faiss_embedder,metadatas=metadatas,ids=ids)\n",
        "        vecdb.merge_from(vecdb_batch)\n",
        "      else:\n",
        "        vecdb = FAISS.from_embeddings(records,faiss_embedder,metadatas=metadatas,ids=ids)\n",
        "  return vecdb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PfsK3DE4MMou"
      },
      "outputs": [],
      "source": [
        "nbdt_data = pd.read_json('data_final.json')\n",
        "aliases = pd.read_csv('id_list.csv')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JrGJh5XgNPvU"
      },
      "outputs": [],
      "source": [
        "aliases = aliases.drop_duplicates('Full Name')\n",
        "aliases.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CShYwGwWMZh5"
      },
      "outputs": [],
      "source": [
        "nbdt_data.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SziJtbggMuyn"
      },
      "outputs": [],
      "source": [
        "def load_nbdt(data,aliases):\n",
        "  nbdt_records = []\n",
        "  urls = []\n",
        "  no_abst_count = 0\n",
        "  no_journal_count = 0\n",
        "  for row in aliases.itertuples():\n",
        "    name = row[1]\n",
        "    auth_ids = eval(row[2])\n",
        "    auth_ids = [int(x) for x in auth_ids]\n",
        "    papers = nbdt_data.loc[nbdt_data['authorId'].isin(auth_ids)]['papers']\n",
        "    all_papers = []\n",
        "    for paper_set in papers:\n",
        "      all_papers.extend(paper_set)\n",
        "    for paper in all_papers:\n",
        "      url = paper['url']\n",
        "      title = paper['title']\n",
        "      abst = paper['abstract']\n",
        "      year = paper['year']\n",
        "      journal = paper.get('journal')\n",
        "      if journal:\n",
        "        journal = journal.get('name')\n",
        "      else:\n",
        "        journal = 'None'\n",
        "        no_journal_count += 1\n",
        "      authors = [name]\n",
        "      if not(abst):\n",
        "        abst = ''\n",
        "        no_abst_count += 1\n",
        "      record = {'journal':journal,'title':title,'abstract':abst,'link':url,'date':year,'authors':authors,'submitter':'None'}\n",
        "      if url not in urls:\n",
        "        nbdt_records.append(record)\n",
        "        urls.append(url)\n",
        "  return nbdt_records, (no_abst_count,no_journal_count)\n",
        "nbdt_recs, no_counts = load_nbdt(nbdt_data,aliases)"
      ]
    },
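    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick look at what was loaded: the number of unique paper records, and how many of them were missing an abstract or a journal name."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(f\"{len(nbdt_recs)} unique records; \"\n",
        "      f\"{no_counts[0]} missing abstracts, {no_counts[1]} missing journals\")"
      ]
    },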
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IovTlDINc2Ds"
      },
      "outputs": [],
      "source": [
        "nbdt_db = add_to_db(nbdt_recs,create_miread_embed,miread_bundle,'nbdt')\n",
        "nbdt_db.save_local(\"nbdt_index\")"
      ]
    }
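    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To reuse the saved index later, reload it and query with a vector from the same embedding function (a minimal sketch: the query string is made up, and newer langchain versions may also require `allow_dangerous_deserialization=True` in `load_local`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Reload the saved index and search with a vector from the same embedding function\n",
        "new_db = FAISS.load_local(\"nbdt_index\", faiss_embedder)\n",
        "query = \"hippocampal replay during sleep\"  # made-up query\n",
        "q_vec = create_miread_embed([query], miread_bundle).tolist()[0]\n",
        "for doc in new_db.similarity_search_by_vector(q_vec, k=3):\n",
        "  print(doc.metadata['title'])"
      ]
    }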
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}