File size: 19,628 Bytes
c5d2283
1
2
{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{},"source":["### Kaggle link: https://www.kaggle.com/code/noobhocai/train-stage-2"]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:15:26.396463Z","iopub.status.busy":"2023-06-26T16:15:26.396153Z","iopub.status.idle":"2023-06-26T16:15:44.091436Z","shell.execute_reply":"2023-06-26T16:15:44.090085Z","shell.execute_reply.started":"2023-06-26T16:15:26.396437Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n","\u001b[0m"]}],"source":["!pip install rank_bm25 pandarallel gensim --q"]},{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:15:46.131348Z","iopub.status.busy":"2023-06-26T16:15:46.130925Z","iopub.status.idle":"2023-06-26T16:15:47.760579Z","shell.execute_reply":"2023-06-26T16:15:47.759408Z","shell.execute_reply.started":"2023-06-26T16:15:46.131315Z"},"trusted":true},"outputs":[],"source":["import os\n","import json\n","import pandas as pd\n","import numpy as np\n","import json, pickle\n","from rank_bm25 import BM25Okapi\n","import argparse\n","import gc\n","from tqdm.auto import tqdm\n","tqdm.pandas()\n","from glob import glob \n","import re \n","from nltk import word_tokenize as lib_tokenizer \n","import string\n","from gensim.corpora import Dictionary\n","from gensim.models import TfidfModel, OkapiBM25Model\n","from gensim.similarities import SparseMatrixSimilarity"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:16:45.111948Z","iopub.status.busy":"2023-06-26T16:16:45.111271Z","iopub.status.idle":"2023-06-26T16:16:45.117881Z","shell.execute_reply":"2023-06-26T16:16:45.116573Z","shell.execute_reply.started":"2023-06-26T16:16:45.111915Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["INFO: Pandarallel will run on 10 workers.\n","INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n"]}],"source":["from pandarallel import pandarallel\n","\n","pandarallel.initialize(progress_bar=True, nb_workers=10)"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:16:47.210933Z","iopub.status.busy":"2023-06-26T16:16:47.210499Z","iopub.status.idle":"2023-06-26T16:17:49.317252Z","shell.execute_reply":"2023-06-26T16:17:49.316132Z","shell.execute_reply.started":"2023-06-26T16:16:47.210900Z"},"trusted":true},"outputs":[],"source":["df_wiki = pd.read_json(\"/kaggle/input/e2eqa-wiki-zalo-ai/wikipedia_20220620_cleaned/wikipedia_20220620_cleaned.jsonl\", lines=True)"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:18:42.926014Z","iopub.status.busy":"2023-06-26T16:18:42.925307Z","iopub.status.idle":"2023-06-26T16:18:42.961174Z","shell.execute_reply":"2023-06-26T16:18:42.959896Z","shell.execute_reply.started":"2023-06-26T16:18:42.925974Z"},"trusted":true},"outputs":[{"data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>id</th>\n","      <th>url</th>\n","      <th>title</th>\n","      <th>text</th>\n","      <th>timestamp</th>\n","      <th>revid</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>2</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=2</td>\n","      <td>Trang Chính</td>\n","      <td>Trang Chính\\n\\n&lt;templatestyles src=\"Wiki2021/s...</td>\n","      <td>2022-05-12 12:46:53+00:00</td>\n","      <td>68591979</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>4</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=4</td>\n","      <td>Internet Society</td>\n","      <td>Internet Society\\n\\nInternet Society hay ISOC ...</td>\n","      <td>2022-01-20 07:59:10+00:00</td>\n","      <td>67988747</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>13</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=13</td>\n","      <td>Tiếng Việt</td>\n","      <td>Tiếng Việt\\n\\nTiếng Việt, cũng gọi là tiếng Vi...</td>\n","      <td>2022-05-29 03:42:42+00:00</td>\n","      <td>68660631</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>24</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=24</td>\n","      <td>Ohio</td>\n","      <td>Ohio\\n\\nOhio (viết tắt là OH, viết tắt cũ là O...</td>\n","      <td>2022-04-17 08:15:22+00:00</td>\n","      <td>68482118</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>26</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=26</td>\n","      <td>California</td>\n","      <td>California\\n\\nCalifornia (phát âm như \"Ca-li-p...</td>\n","      <td>2022-06-16 15:27:07+00:00</td>\n","      <td>68738039</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["   id                                     url             title  \\\n","0   2   https://vi.wikipedia.org/wiki?curid=2       Trang Chính   \n","1   4   https://vi.wikipedia.org/wiki?curid=4  Internet Society   \n","2  13  https://vi.wikipedia.org/wiki?curid=13        Tiếng Việt   \n","3  24  https://vi.wikipedia.org/wiki?curid=24              Ohio   \n","4  26  https://vi.wikipedia.org/wiki?curid=26        California   \n","\n","                                                text  \\\n","0  Trang Chính\\n\\n<templatestyles src=\"Wiki2021/s...   \n","1  Internet Society\\n\\nInternet Society hay ISOC ...   \n","2  Tiếng Việt\\n\\nTiếng Việt, cũng gọi là tiếng Vi...   \n","3  Ohio\\n\\nOhio (viết tắt là OH, viết tắt cũ là O...   \n","4  California\\n\\nCalifornia (phát âm như \"Ca-li-p...   \n","\n","                  timestamp     revid  \n","0 2022-05-12 12:46:53+00:00  68591979  \n","1 2022-01-20 07:59:10+00:00  67988747  \n","2 2022-05-29 03:42:42+00:00  68660631  \n","3 2022-04-17 08:15:22+00:00  68482118  \n","4 2022-06-16 15:27:07+00:00  68738039  "]},"execution_count":6,"metadata":{},"output_type":"execute_result"}],"source":["df_wiki.head()"]},{"cell_type":"code","execution_count":7,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:20:14.921978Z","iopub.status.busy":"2023-06-26T16:20:14.921498Z","iopub.status.idle":"2023-06-26T16:20:14.943942Z","shell.execute_reply":"2023-06-26T16:20:14.942444Z","shell.execute_reply.started":"2023-06-26T16:20:14.921945Z"},"trusted":true},"outputs":[],"source":["def post_process(x):\n","    x = \" \".join(word_tokenize(strip_context(x))).strip()\n","    x = x.replace(\"\\n\",\" \")\n","    x = \"\".join([i for i in x if i not in string.punctuation])\n","    x = \" \".join(x.split()[:128])\n","    return x\n","\n","dict_map = dict({})  \n","def word_tokenize(text): \n","    global dict_map \n","    words = text.split() \n","    words_norm = [] \n","    for w in words: \n","        if dict_map.get(w, None) is None: \n","            dict_map[w] = ' '.join(lib_tokenizer(w)).replace('``', '\"').replace(\"''\", '\"') \n","        words_norm.append(dict_map[w]) \n","    return words_norm \n"," \n","def strip_answer_string(text): \n","    text = text.strip() \n","    while text[-1] in '.,/><;:\\'\"[]{}+=-_)(*&^!~`': \n","        if text[0] != '(' and text[-1] == ')' and '(' in text: \n","            break \n","        if text[-1] == '\"' and text[0] != '\"' and text.count('\"') > 1: \n","            break \n","        text = text[:-1].strip() \n","    while text[0] in '.,/><;:\\'\"[]{}+=-_)(*&^!~`': \n","        if text[0] == '\"' and text[-1] != '\"' and text.count('\"') > 1: \n","            break \n","        text = text[1:].strip() \n","    text = text.strip() \n","    return text \n"," \n","def strip_context(text): \n","    text = text.replace('\\n', ' ') \n","    text = re.sub(r'\\s+', ' ', text) \n","    text = text.strip() \n","    return text\n","\n","def check_(x):\n","    x = str(x).lower()\n","    return (x.isnumeric() or \"ngày\" in x or \"tháng\" in x or \"năm\" in x)\n","\n","def find_candidate_ids(x, raw_answer=None, already_added=[], topk=50):\n","    x = str(x)\n","    query = post_process(x).lower().split()\n","    tfidf_query = tfidf_model[dictionary.doc2bow(query)]\n","    scores = bm25_index[tfidf_query]\n","    top_n = list(np.argsort(scores)[::-1][:topk])\n","    top_n = [i for i in top_n if i not in already_added]\n","    # scores = list(scores[top_n])\n","    if raw_answer is not None:\n","        raw_answer = raw_answer.strip()\n","        if raw_answer in entity_dict:\n","            title = entity_dict[raw_answer].replace(\"wiki/\",\"\").replace(\"_\",\" \")\n","            extra_id = title2idx.get(title, -1)\n","            # print((raw_answer,title,extra_id, extra_id not in top_n))\n","            if extra_id != -1 and extra_id not in top_n:\n","                print(f\"Add extra id {extra_id} for {raw_answer}\")\n","                top_n.append(extra_id)\n","                top_n = list(set(top_n))\n","    scores = scores[top_n]\n","    return list(top_n), np.array(scores)"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:20:18.394704Z","iopub.status.busy":"2023-06-26T16:20:18.394284Z","iopub.status.idle":"2023-06-26T16:30:31.484998Z","shell.execute_reply":"2023-06-26T16:30:31.483810Z","shell.execute_reply.started":"2023-06-26T16:20:18.394671Z"},"trusted":true},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"8cd018dfcf7e4ccc85f93f8bb319f26c","version_major":2,"version_minor":0},"text/plain":["VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=127347), Label(value='0 / 127347')…"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a13f08c5d7974e1087d598ca8b488840","version_major":2,"version_minor":0},"text/plain":["VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=127347), Label(value='0 / 127347')…"]},"metadata":{},"output_type":"display_data"}],"source":["df_wiki['title_lower'] = df_wiki['title'].apply(lambda x: x.lower()).parallel_apply(post_process)\n","df_wiki['text_lower'] = df_wiki['text'].apply(lambda x: x.lower()).parallel_apply(post_process)"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:33:42.344050Z","iopub.status.busy":"2023-06-26T16:33:42.342811Z","iopub.status.idle":"2023-06-26T16:33:42.362074Z","shell.execute_reply":"2023-06-26T16:33:42.360662Z","shell.execute_reply.started":"2023-06-26T16:33:42.344003Z"},"trusted":true},"outputs":[{"data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>id</th>\n","      <th>url</th>\n","      <th>title</th>\n","      <th>text</th>\n","      <th>timestamp</th>\n","      <th>revid</th>\n","      <th>title_lower</th>\n","      <th>text_lower</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>2</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=2</td>\n","      <td>Trang Chính</td>\n","      <td>Trang Chính\\n\\n&lt;templatestyles src=\"Wiki2021/s...</td>\n","      <td>2022-05-12 12:46:53+00:00</td>\n","      <td>68591979</td>\n","      <td>trang chính</td>\n","      <td>trang chính templatestyles src wiki2021stylesc...</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>4</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=4</td>\n","      <td>Internet Society</td>\n","      <td>Internet Society\\n\\nInternet Society hay ISOC ...</td>\n","      <td>2022-01-20 07:59:10+00:00</td>\n","      <td>67988747</td>\n","      <td>internet society</td>\n","      <td>internet society internet society hay isoc là ...</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>13</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=13</td>\n","      <td>Tiếng Việt</td>\n","      <td>Tiếng Việt\\n\\nTiếng Việt, cũng gọi là tiếng Vi...</td>\n","      <td>2022-05-29 03:42:42+00:00</td>\n","      <td>68660631</td>\n","      <td>tiếng việt</td>\n","      <td>tiếng việt tiếng việt cũng gọi là tiếng việt n...</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>24</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=24</td>\n","      <td>Ohio</td>\n","      <td>Ohio\\n\\nOhio (viết tắt là OH, viết tắt cũ là O...</td>\n","      <td>2022-04-17 08:15:22+00:00</td>\n","      <td>68482118</td>\n","      <td>ohio</td>\n","      <td>ohio ohio viết tắt là oh viết tắt cũ là o là m...</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>26</td>\n","      <td>https://vi.wikipedia.org/wiki?curid=26</td>\n","      <td>California</td>\n","      <td>California\\n\\nCalifornia (phát âm như \"Ca-li-p...</td>\n","      <td>2022-06-16 15:27:07+00:00</td>\n","      <td>68738039</td>\n","      <td>california</td>\n","      <td>california california phát âm như caliphótnia ...</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["   id                                     url             title  \\\n","0   2   https://vi.wikipedia.org/wiki?curid=2       Trang Chính   \n","1   4   https://vi.wikipedia.org/wiki?curid=4  Internet Society   \n","2  13  https://vi.wikipedia.org/wiki?curid=13        Tiếng Việt   \n","3  24  https://vi.wikipedia.org/wiki?curid=24              Ohio   \n","4  26  https://vi.wikipedia.org/wiki?curid=26        California   \n","\n","                                                text  \\\n","0  Trang Chính\\n\\n<templatestyles src=\"Wiki2021/s...   \n","1  Internet Society\\n\\nInternet Society hay ISOC ...   \n","2  Tiếng Việt\\n\\nTiếng Việt, cũng gọi là tiếng Vi...   \n","3  Ohio\\n\\nOhio (viết tắt là OH, viết tắt cũ là O...   \n","4  California\\n\\nCalifornia (phát âm như \"Ca-li-p...   \n","\n","                  timestamp     revid       title_lower  \\\n","0 2022-05-12 12:46:53+00:00  68591979       trang chính   \n","1 2022-01-20 07:59:10+00:00  67988747  internet society   \n","2 2022-05-29 03:42:42+00:00  68660631        tiếng việt   \n","3 2022-04-17 08:15:22+00:00  68482118              ohio   \n","4 2022-06-16 15:27:07+00:00  68738039        california   \n","\n","                                          text_lower  \n","0  trang chính templatestyles src wiki2021stylesc...  \n","1  internet society internet society hay isoc là ...  \n","2  tiếng việt tiếng việt cũng gọi là tiếng việt n...  \n","3  ohio ohio viết tắt là oh viết tắt cũ là o là m...  \n","4  california california phát âm như caliphótnia ...  "]},"execution_count":9,"metadata":{},"output_type":"execute_result"}],"source":["df_wiki.head()"]},{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:46:35.539444Z","iopub.status.busy":"2023-06-26T16:46:35.538142Z","iopub.status.idle":"2023-06-26T16:46:38.074928Z","shell.execute_reply":"2023-06-26T16:46:38.073705Z","shell.execute_reply.started":"2023-06-26T16:46:35.539390Z"},"trusted":true},"outputs":[],"source":["title2idx = dict([(x.strip(),y) for x,y in zip(df_wiki.title, df_wiki.index.values)])\n","train = json.load(open(\"/kaggle/input/e2eqa-wiki-zalo-ai/processed/zac2022_train_merged_final.json\"))\n","entity_dict =  json.load(open(\"/kaggle/input/e2eqa-wiki-zalo-ai/processed/entities.json\"))"]},{"cell_type":"code","execution_count":19,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:55:20.687579Z","iopub.status.busy":"2023-06-26T16:55:20.686344Z","iopub.status.idle":"2023-06-26T16:55:25.068585Z","shell.execute_reply":"2023-06-26T16:55:25.067193Z","shell.execute_reply.started":"2023-06-26T16:55:20.687529Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["mkdir: cannot create directory ‘/kaggle/working/bm25_stage2’: File exists\n","mkdir: cannot create directory ‘/kaggle/working/bm25_stage2/full_text’: File exists\n"]}],"source":["!mkdir /kaggle/working/bm25_stage2\n","!mkdir /kaggle/working/bm25_stage2/full_text\n","!mkdir /kaggle/working/bm25_stage2/title"]},{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:54:49.241998Z","iopub.status.busy":"2023-06-26T16:54:49.240748Z","iopub.status.idle":"2023-06-26T16:54:52.346235Z","shell.execute_reply":"2023-06-26T16:54:52.345380Z","shell.execute_reply.started":"2023-06-26T16:54:49.241950Z"},"trusted":true},"outputs":[],"source":["corpus = [doc.split() for doc in df_wiki['text_lower']] #simple tokenier\n","dictionary = Dictionary(corpus)\n","bm25_model = OkapiBM25Model(dictionary=dictionary)\n","bm25_corpus = bm25_model[list(map(dictionary.doc2bow, corpus))]\n","bm25_index = SparseMatrixSimilarity(bm25_corpus, num_docs=len(corpus), num_terms=len(dictionary),normalize_queries=False, normalize_documents=False)\n","tfidf_model = TfidfModel(dictionary=dictionary, smartirs='bnn')  # Enforce binary weighting of queries\n","dictionary.save(\"/kaggle/working/bm25_stage2/full_text/dict\")\n","tfidf_model.save(\"/kaggle/working/bm25_stage2/full_text/tfidf\")\n","bm25_index.save(\"/kaggle/working/bm25_stage2/full_text/bm25_index\")"]},{"cell_type":"code","execution_count":22,"metadata":{"execution":{"iopub.execute_input":"2023-06-26T16:58:58.741331Z","iopub.status.busy":"2023-06-26T16:58:58.740930Z","iopub.status.idle":"2023-06-26T16:58:59.918024Z","shell.execute_reply":"2023-06-26T16:58:59.916804Z","shell.execute_reply.started":"2023-06-26T16:58:58.741301Z"},"trusted":true},"outputs":[],"source":["corpus = [doc.split() for doc in df_wiki['title_lower']] #simple tokenier\n","dictionary = Dictionary(corpus)\n","bm25_model = OkapiBM25Model(dictionary=dictionary)\n","bm25_corpus = bm25_model[list(map(dictionary.doc2bow, corpus))]\n","bm25_index = SparseMatrixSimilarity(bm25_corpus, num_docs=len(corpus), num_terms=len(dictionary),normalize_queries=False, normalize_documents=False)\n","tfidf_model = TfidfModel(dictionary=dictionary, smartirs='bnn')  # Enforce binary weighting of queries\n","dictionary.save(\"/kaggle/working/bm25_stage2/title/dict\")\n","tfidf_model.save(\"/kaggle/working/bm25_stage2/title/tfidf\")\n","bm25_index.save(\"/kaggle/working/bm25_stage2/title/bm25_index\")"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.10"}},"nbformat":4,"nbformat_minor":4}