{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import transformers\n", "import torch" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModel, AutoTokenizer, BertModel, DistilBertModel, RobertaModel, GPT2Model" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [], "source": [ "mname = 'bert-base-uncased'\n", "sentence = 'The count went forward with his original plan'\n", "t_class = BertModel\n", "\n", "def test_model(t_class, mname, sentence):\n", " m = t_class.from_pretrained(mname, output_hidden_states=True, output_past=False, output_attentions=True, output_additional_info=True)\n", " t = AutoTokenizer.from_pretrained(mname)\n", " input_ids = t.encode(sentence)\n", " outputs = m(torch.tensor(input_ids).unsqueeze(0))\n", " return outputs\n" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "mname = 'bert-base-uncased'\n", "sentence = 'The count went forward with his original plan'\n", "t_class = BertModel\n", "out = test_model(t_class, mname, sentence)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "mname = 'distilbert-base-uncased'\n", "sentence = 'The count went forward with his original plan'\n", "t_class = DistilBertModel\n", "out = test_model(t_class, mname, sentence)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "mname = 'roberta-base'\n", "sentence = 'The count went forward with his original plan'\n", "t_class = RobertaModel\n", "out = test_model(t_class, mname, sentence)" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n", "CONTEXTS: torch.Size([1, 16, 12, 64])\n" ] } ], "source": [ "mname = 'gpt2'\n", "sentence = 'The count went forward with his original plan to take over the mighty world of Disney'\n", "t_class = GPT2Model\n", "out = test_model(t_class, mname, sentence)" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4" ] }, "execution_count": 123, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(out)" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 12, 64])" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out[-1][0].shape" ] }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 768])" ] }, "execution_count": 120, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out[1][0].shape" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12" ] }, 
"execution_count": 107, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = t_class.from_pretrained(mname)\n", "\n", "m.config.n_head" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2Config {\n", " \"attn_pdrop\": 0.1,\n", " \"bos_token_id\": 0,\n", " \"do_sample\": false,\n", " \"embd_pdrop\": 0.1,\n", " \"eos_token_ids\": 0,\n", " \"finetuning_task\": null,\n", " \"id2label\": {\n", " \"0\": \"LABEL_0\",\n", " \"1\": \"LABEL_1\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"is_decoder\": false,\n", " \"label2id\": {\n", " \"LABEL_0\": 0,\n", " \"LABEL_1\": 1\n", " },\n", " \"layer_norm_epsilon\": 1e-05,\n", " \"length_penalty\": 1.0,\n", " \"max_length\": 20,\n", " \"model_type\": \"gpt2\",\n", " \"n_ctx\": 1024,\n", " \"n_embd\": 768,\n", " \"n_head\": 12,\n", " \"n_layer\": 12,\n", " \"n_positions\": 1024,\n", " \"num_beams\": 1,\n", " \"num_labels\": 2,\n", " \"num_return_sequences\": 1,\n", " \"output_additional_info\": false,\n", " \"output_attentions\": false,\n", " \"output_hidden_states\": false,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"pruned_heads\": {},\n", " \"repetition_penalty\": 1.0,\n", " \"resid_pdrop\": 0.1,\n", " \"summary_activation\": null,\n", " \"summary_first_dropout\": 0.1,\n", " \"summary_proj_to_labels\": true,\n", " \"summary_type\": \"cls_index\",\n", " \"summary_use_proj\": true,\n", " \"temperature\": 1.0,\n", " \"top_k\": 50,\n", " \"top_p\": 1.0,\n", " \"torchscript\": false,\n", " \"use_bfloat16\": false,\n", " \"vocab_size\": 50257\n", "}" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.config" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 12, 8, 8])" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out[-1][1].shape" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 12, 8, 64])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out[1][0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tokenizing smiles" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['C', 'O', '.', 'C', 'O', 'C', '(', '=', 'O', ')', 'C', '(', 'C', ')', '(', 'C', ')', 'c', '1', 'c', 'c', 'c', '(', 'C', '(', '=', 'O', ')', 'C', 'C', 'C', 'N', '2', 'C', 'C', 'C', '(', 'C', '(', 'O', ')', '(', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'C', 'C', '2', ')', 'c', 'c', '1', '.', 'Cl', '.', 'O', '[Na]', '>>', 'C', 'C', '(', 'C', ')', '(', 'C', '(', '=', 'O', ')', 'O', ')', 'c', '1', 'c', 'c', 'c', '(', 'C', '(', '=', 'O', ')', 'C', 'C', 'C', 'N', '2', 'C', 'C', 'C', '(', 'C', '(', 'O', ')', '(', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'C', 'C', '2', ')', 'c', 'c', '1']\n" ] } ], "source": [ "import re\n", "import regex\n", "\n", "def tokenize_smiles(smiles: str) -> str:\n", " \"\"\"\n", " Tokenize a SMILES molecule or reaction\n", " \"\"\"\n", " pattern = r\"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\|\\/|:|~|@|\\?|>>?|\\*|\\$|\\%[0-9]{2}|\\%\\([0-9]{3}\\)|[0-9])\"\n", " regex = re.compile(pattern)\n", " tokens = [token for token in regex.findall(smiles)]\n", " if 
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Tokenizing SMILES" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['C', 'O', '.', 'C', 'O', 'C', '(', '=', 'O', ')', 'C', '(', 'C', ')', '(', 'C', ')', 'c', '1', 'c', 'c', 'c', '(', 'C', '(', '=', 'O', ')', 'C', 'C', 'C', 'N', '2', 'C', 'C', 'C', '(', 'C', '(', 'O', ')', '(', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'C', 'C', '2', ')', 'c', 'c', '1', '.', 'Cl', '.', 'O', '[Na]', '>>', 'C', 'C', '(', 'C', ')', '(', 'C', '(', '=', 'O', ')', 'O', ')', 'c', '1', 'c', 'c', 'c', '(', 'C', '(', '=', 'O', ')', 'C', 'C', 'C', 'N', '2', 'C', 'C', 'C', '(', 'C', '(', 'O', ')', '(', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'c', '3', 'c', 'c', 'c', 'c', 'c', '3', ')', 'C', 'C', '2', ')', 'c', 'c', '1']\n" ] } ], "source": [ "import re\n", "\n", "def tokenize_smiles(smiles: str) -> list:\n", "    \"\"\"\n", "    Tokenize a SMILES molecule or reaction into a list of tokens.\n", "    \"\"\"\n", "    pattern = r\"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\|\\/|:|~|@|\\?|>>?|\\*|\\$|\\%[0-9]{2}|\\%\\([0-9]{3}\\)|[0-9])\"\n", "    token_regex = re.compile(pattern)\n", "    tokens = token_regex.findall(smiles)\n", "    # Fail loudly if the pattern did not account for every character.\n", "    if smiles != ''.join(tokens):\n", "        raise ValueError(f'tokenization does not cover the input: {smiles}')\n", "    return tokens\n", "\n", "\n", "rxn = 'CO.COC(=O)C(C)(C)c1ccc(C(=O)CCCN2CCC(C(O)(c3ccccc3)c3ccccc3)CC2)cc1.Cl.O[Na]>>CC(C)(C(=O)O)c1ccc(C(=O)CCCN2CCC(C(O)(c3ccccc3)c3ccccc3)CC2)cc1'\n", "tokenized_rxn = tokenize_smiles(rxn)\n", "print(tokenized_rxn)" ] },
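 { "cell_type": "markdown", "metadata": {}, "source": [ "To feed these tokens to a model, they need to be mapped to integer ids, the way a trained tokenizer's vocabulary would. A minimal sketch follows; the special tokens `[PAD]` and `[UNK]` and their placement are illustrative assumptions, not a fixed convention." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: build a toy vocabulary from this one reaction and encode it.\n", "# A real vocabulary would be built over the whole training corpus.\n", "specials = ['[PAD]', '[UNK]']\n", "vocab = {tok: i for i, tok in enumerate(specials + sorted(set(tokenized_rxn)))}\n", "ids = [vocab.get(tok, vocab['[UNK]']) for tok in tokenized_rxn]\n", "print(len(vocab), 'token types; first ids:', ids[:10])" ] }
 ], "metadata": { "kernelspec": { "display_name": "Python [conda env:tformers] *", "language": "python", "name": "conda-env-tformers-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }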