File size: 12,334 Bytes
ba14b13 |
1 |
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30919,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Fine-tuning `bert-base-uncased` on PubMedQA (`pqa_artificial`) with the GEM pipeline\n\nThis notebook loads the `pqa_artificial` subset of PubMedQA, keeps the first 20,000 rows, splits them 80/20 into train/test, and fine-tunes `bert-base-uncased` for binary `yes`/`no` classification of `final_decision` using `run_gem_pipeline` from the [GEM repository](https://github.com/Firojpaudel/GEM).\n\nThe recorded run reached 92.58% test accuracy; compare that with a majority-class baseline before drawing conclusions.","metadata":{}},
{"cell_type":"code","source":"#@ First, let's check the number of classes in PubMedQA (pqa_artificial subset)\nfrom datasets import load_dataset\n\nds = load_dataset(\"qiaojin/PubMedQA\", \"pqa_artificial\")\n\nprint(ds)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:55:44.411227Z","iopub.execute_input":"2025-03-16T09:55:44.411537Z","iopub.status.idle":"2025-03-16T09:55:52.472201Z","shell.execute_reply.started":"2025-03-16T09:55:44.411511Z","shell.execute_reply":"2025-03-16T09:55:52.471540Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/5.19k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"87a06377d77e4ff2a4892193569e5057"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"train-00000-of-00001.parquet: 0%| | 0.00/233M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c3e6a041041449799c0885dd82afae1b"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Generating train split: 0%| | 0/211269 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b127ad68ebf24b48868922c098cc3c9e"}},"metadata":{}},{"name":"stdout","text":"DatasetDict({\n train: Dataset({\n features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],\n num_rows: 211269\n })\n})\n","output_type":"stream"}],"execution_count":1},{"cell_type":"code","source":"labels = set(ds['train']['final_decision'])\nlabels","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:55:52.473261Z","iopub.execute_input":"2025-03-16T09:55:52.473769Z","iopub.status.idle":"2025-03-16T09:55:52.655190Z","shell.execute_reply.started":"2025-03-16T09:55:52.473745Z","shell.execute_reply":"2025-03-16T09:55:52.653968Z"}},"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"{'no', 'yes'}"},"metadata":{}}],"execution_count":2},
{"cell_type":"markdown","source":"The `final_decision` column contains exactly two labels, `yes` and `no`, so the number of classes is $2$.\n\nA set shows only *which* labels occur, not *how often*. Check the class balance (e.g. `collections.Counter(ds['train']['final_decision'])`) before interpreting accuracy: with a skewed label distribution, always predicting the majority class already scores highly.","metadata":{}},
{"cell_type":"code","source":"#@ Step 1: Clone the GitHub repo once (skipped if it already exists, so re-running the cell is safe)\n#@ TODO: pin a commit or tag so reruns use the same pipeline code\n![ -d /kaggle/working/GEM ] || git clone https://github.com/Firojpaudel/GEM.git /kaggle/working/GEM\n\n#@ Step 2: Install all requirements into this kernel's environment (%pip targets the running kernel; !pip may not)\n%pip install -r /kaggle/working/GEM/requirements.txt -qq","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:55:52.656885Z","iopub.execute_input":"2025-03-16T09:55:52.657123Z","iopub.status.idle":"2025-03-16T09:55:57.925038Z","shell.execute_reply.started":"2025-03-16T09:55:52.657103Z","shell.execute_reply":"2025-03-16T09:55:57.924103Z"}},"outputs":[{"name":"stdout","text":"Cloning into 'GEM'...\nremote: Enumerating objects: 44, done.\u001b[K\nremote: Counting objects: 100% (44/44), done.\u001b[K\nremote: Compressing objects: 100% (33/33), done.\u001b[K\nremote: Total 44 (delta 22), reused 25 (delta 9), pack-reused 0 (from 0)\u001b[K\nReceiving objects: 100% (44/44), 15.42 KiB | 3.85 MiB/s, done.\nResolving deltas: 100% (22/22), 
done.\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import warnings\n\n# Hide only library deprecation notices; a blanket 'ignore' would also hide\n# warnings about real problems (e.g. tokenization or numerical issues).\nwarnings.filterwarnings('ignore', category=FutureWarning)\nwarnings.filterwarnings('ignore', category=DeprecationWarning)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:55:57.926619Z","iopub.execute_input":"2025-03-16T09:55:57.926858Z","iopub.status.idle":"2025-03-16T09:55:57.930417Z","shell.execute_reply.started":"2025-03-16T09:55:57.926838Z","shell.execute_reply":"2025-03-16T09:55:57.929640Z"}},"outputs":[],"execution_count":4},
{"cell_type":"code","source":"#@ Step 3: Make the cloned GEM repo importable (guarded so re-running the cell does not append duplicates)\nimport sys\n\nGEM_REPO_DIR = '/kaggle/working/GEM'\nif GEM_REPO_DIR not in sys.path:\n    sys.path.append(GEM_REPO_DIR)\n\n#@ Step 4: Import the pipeline entry point (it is called further below)\nfrom gem_trainer import run_gem_pipeline\n\n#@ Keep the first 20,000 rows of the train split (a prefix, not a random sample)\ndataset_subset = ds['train'].select(range(20000))\nprint(dataset_subset)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:55:57.931412Z","iopub.execute_input":"2025-03-16T09:55:57.931731Z","iopub.status.idle":"2025-03-16T09:56:08.914987Z","shell.execute_reply.started":"2025-03-16T09:55:57.931703Z","shell.execute_reply":"2025-03-16T09:56:08.914244Z"}},"outputs":[{"name":"stderr","text":"The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. 
You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"0it [00:00, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"4f86a58a7b0840e98a8721989340ac94"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/48.0 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f7196fc1c8b7496990cc7c762f0f27c8"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"config.json: 0%| | 0.00/570 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2f333f457a624965968714e380dca829"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"423bddc0c7724ca7a934752ffd18c8da"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json: 0%| | 0.00/466k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"711cdb3dba9b49eebc59d297048b0cc8"}},"metadata":{}},{"name":"stdout","text":"Dataset({\n features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],\n num_rows: 20000\n})\n","output_type":"stream"}],"execution_count":5},{"cell_type":"code","source":"#@ 80/20 train/test split; the fixed seed makes the split reproducible\n#@ TODO: consider stratifying on final_decision so both splits keep the same label ratio\ndataset = dataset_subset.train_test_split(test_size=0.2, seed=42)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:56:08.915690Z","iopub.execute_input":"2025-03-16T09:56:08.916092Z","iopub.status.idle":"2025-03-16T09:56:08.931307Z","shell.execute_reply.started":"2025-03-16T09:56:08.916062Z","shell.execute_reply":"2025-03-16T09:56:08.930445Z"}},"outputs":[],"execution_count":6},
{"cell_type":"code","source":"#@ Map the string labels in final_decision to integer class ids\nlabel_map = {'yes': 0, 'no': 1}\n\n# Fail fast if the data contains a label that has no class id\nassert set(dataset_subset.unique('final_decision')) <= set(label_map), 'final_decision contains a label with no entry in label_map'","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:56:08.932136Z","iopub.execute_input":"2025-03-16T09:56:08.932373Z","iopub.status.idle":"2025-03-16T09:56:08.943733Z","shell.execute_reply.started":"2025-03-16T09:56:08.932342Z","shell.execute_reply":"2025-03-16T09:56:08.942853Z"}},"outputs":[],"execution_count":7},
{"cell_type":"code","source":"from transformers import AutoTokenizer\n\n#@ Shared settings: the tokenizer must match the model fine-tuned below,\n#@ and tokenize_fn must pad/truncate to the pipeline's max_seq_length\nMODEL_NAME = \"bert-base-uncased\"\nMAX_SEQ_LENGTH = 128\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:56:08.945189Z","iopub.execute_input":"2025-03-16T09:56:08.945433Z","iopub.status.idle":"2025-03-16T09:56:09.085097Z","shell.execute_reply.started":"2025-03-16T09:56:08.945414Z","shell.execute_reply":"2025-03-16T09:56:09.084527Z"}},"outputs":[],"execution_count":8},
{"cell_type":"code","source":"#@ Tokenization handed to the pipeline via `tokenize_fn`\ndef tokenize_fn(examples, max_length=MAX_SEQ_LENGTH):\n    \"\"\"Tokenize a batch of PubMedQA rows as (question, context) sentence pairs.\n\n    Args:\n        examples: batch of columns; 'question' holds strings, each 'context' is a\n            dict whose 'contexts' entry is a list of passage strings, and\n            'final_decision' holds labels that must be keys of `label_map`.\n        max_length: every encoded pair is padded/truncated to this many tokens.\n\n    Returns:\n        The tokenizer output (for BERT: input_ids, token_type_ids,\n        attention_mask) plus a 'labels' list of integer class ids.\n    \"\"\"\n    questions = examples['question']\n    # Join each row's context passages into one passage string.\n    contexts = [\" \".join(c['contexts']) for c in examples['context']]\n\n    # Map 'final_decision' to integer labels (KeyError on an unmapped label).\n    labels = [label_map[label] for label in examples['final_decision']]\n\n    # Encode as a sentence pair instead of gluing in a literal [SEP] marker: the\n    # tokenizer adds [CLS]/[SEP] itself and gives the context segment\n    # token_type_ids of 1, as in BERT pre-training. truncation=True means\n    # 'longest_first', which trims the longer sequence (in practice the context)\n    # before touching the question.\n    tokenized = tokenizer(\n        questions,\n        contexts,\n        padding='max_length',\n        truncation=True,\n        max_length=max_length,\n    )\n    tokenized['labels'] = labels\n    return tokenized","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:56:09.085917Z","iopub.execute_input":"2025-03-16T09:56:09.086113Z","iopub.status.idle":"2025-03-16T09:56:09.090427Z","shell.execute_reply.started":"2025-03-16T09:56:09.086095Z","shell.execute_reply":"2025-03-16T09:56:09.089552Z"}},"outputs":[],"execution_count":9},
{"cell_type":"code","source":"#@ Fine-tune and evaluate with the GEM pipeline (the recorded run took ~1h48m on a Tesla T4)\nresults = run_gem_pipeline(\n    dataset=dataset,\n    model_name=MODEL_NAME,\n    num_classes=len(label_map),\n    num_epochs=5,\n    batch_size=128,\n    learning_rate=2e-5,\n    max_seq_length=MAX_SEQ_LENGTH,  # same length tokenize_fn pads/truncates to\n    gradient_accum_steps=2,\n    cluster_size=256,\n    threshold=0.65,\n    tokenize_fn=tokenize_fn,\n    save_path=\"gem_model_pubmedqa_final.pth\",\n    checkpoint_dir=\"checkpoints\",\n    checkpoint_interval=2,\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-03-16T09:56:09.091271Z","iopub.execute_input":"2025-03-16T09:56:09.091578Z","iopub.status.idle":"2025-03-16T11:43:46.059634Z","shell.execute_reply.started":"2025-03-16T09:56:09.091550Z","shell.execute_reply":"2025-03-16T11:43:46.058721Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/16000 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"35dd0f7cf7494644863e18d35b88bd7c"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/4000 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5d49b153ff0d4011a29ecf35a85c1fdc"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors: 0%| | 0.00/440M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9761fba4f0974e179d83b94026049837"}},"metadata":{}},{"name":"stderr","text":"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n100%|ββββββββββ| 125/125 [20:39<00:00, 9.91s/it]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 1/5 | Avg Loss: 0.4772\n","output_type":"stream"},{"name":"stderr","text":"100%|ββββββββββ| 125/125 [20:48<00:00, 9.99s/it]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 2/5 | Avg Loss: 0.4117\nSaved checkpoint to checkpoints/checkpoint_epoch_2.pth\n","output_type":"stream"},{"name":"stderr","text":"100%|ββββββββββ| 125/125 [20:42<00:00, 9.94s/it]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 3/5 | Avg Loss: 0.4077\n","output_type":"stream"},{"name":"stderr","text":"100%|ββββββββββ| 125/125 [20:27<00:00, 9.82s/it]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 4/5 | Avg Loss: 0.3999\nSaved checkpoint to checkpoints/checkpoint_epoch_4.pth\n","output_type":"stream"},{"name":"stderr","text":"100%|ββββββββββ| 125/125 [20:29<00:00, 9.83s/it]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 5/5 | Avg Loss: 0.3862\n","output_type":"stream"},{"name":"stderr","text":"100%|ββββββββββ| 32/32 [03:56<00:00, 7.40s/it]\n","output_type":"stream"},{"name":"stdout","text":"Final Accuracy: 92.58%\nSaved final model to 
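gem_model_pubmedqa_final.pth\n","output_type":"stream"}],"execution_count":10},{"cell_type":"markdown","source":"## Results\n\nFinal test accuracy: **92.58%** on the 4,000-example test split, with the average training loss falling from 0.4772 (epoch 1) to 0.3862 (epoch 5).\n\n**Caveat:** judge this number against the majority-class baseline. The `pqa_artificial` labels are reported to be heavily skewed toward `yes` (about 93% in the PubMedQA paper; verify on `dataset['test']`), so always answering `yes` would score about as well.\n\n**TODO:** report the test-set label counts and per-class precision/recall/F1, especially for the minority `no` class.","metadata":{}}]} |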