aaronday3 committed
Commit 0db33af · verified · 1 Parent(s): 891b27f

Upload 2 files

convert_into_distilbert_dataset.py ADDED
@@ -0,0 +1,120 @@
+ # The purpose of this file is to take the given texts,
+ # put the AI-generated ones into the negative class and the human ones into the positive class,
+ # and split every text word by word
+ # so that a search can run before the text has finished streaming.
+
+ # Example: "The dog walked over the pavement." will be turned into:
+ # The
+ # The dog
+ # The dog walked
+ # The dog walked over
+ # The dog walked over the
+ # The dog walked over the pavement
+ # The dog walked over the pavement.
+
+ # Example data row:
+ # {"query": "Write a story about dogs", "pos": ["lorem ipsum..."], "neg": ["lorem ipsum..."]}
+
+ import re
+ import ujson as json
+ import random
+ from tqdm import tqdm
+
+ # Label convention: 1 means AI-generated, 0 means human-written
+ HUMAN_LABEL = 0
+ AI_LABEL = 1
+
+ def split_string(text):
+     """Split a given text by spaces (punctuation splitting is currently disabled)."""
+     # Split the text by spaces
+     words = text.split()
+
+     # Further splitting on punctuation is disabled for now because of issues
+     # # Further split each word by punctuation using regex
+     # split_words = []
+     # for word in words:
+     #     # Find all substrings that match the pattern: either a word or a punctuation mark
+     #     split_words.extend(re.findall(r'\w+|[^\w\s]', word))
+
+     return words
+
+ reddit_vs_synth_writing_prompts = []
+ with open("writing_prompts/reddit_vs_synth_writing_prompts.jsonl", "r") as f:
+     for line in f.read().splitlines():
+         loaded_object = json.loads(line)
+         if "story_human" not in loaded_object:  # Skip entries where we don't have human data
+             continue
+
+         reddit_vs_synth_writing_prompts.append(loaded_object)
+
+ dataset_entries = []
+
+ SAVE_FILE_NAME = "bert_reddit_vs_synth_writing_prompts.jsonl"
+
+ def add_streamed_data(data, label):
+     """Expand a text into all of its word prefixes, each carrying the given label."""
+     entries = []
+     data_parts = split_string(data)
+
+     for i in range(len(data_parts)):
+         streamed_so_far = " ".join(data_parts[:i + 1])  # Python slicing excludes the end index
+         entries.append({"text": streamed_so_far, "label": label})
+
+     return entries
+
+ # Truncate the output file so the appends below start from scratch
+ with open(SAVE_FILE_NAME, "w") as f:
+     f.write("")
+
+ NUM_OF_TURNS_TO_DUMP = 200
+ i = 0
+ for data in tqdm(reddit_vs_synth_writing_prompts):
+     # Example entry: {"text": "AI-generated text example 1", "label": 1}
+     i += 1
+
+     # Write the dataset out part by part
+     if i == NUM_OF_TURNS_TO_DUMP:
+         i = 0
+         dumped_entries = []
+         for entry in dataset_entries:
+             dumped_entries.append(json.dumps(entry))
+
+         dumped_string = "\n".join(dumped_entries) + "\n"
+
+         with open(SAVE_FILE_NAME, "a") as f:
+             f.write(dumped_string)
+
+         dataset_entries = []
+
+     if False:  # Streaming (prefix expansion) is disabled for now
+         # Add streamed data
+         human_entries = add_streamed_data(data["story_human"], HUMAN_LABEL)
+         dataset_entries.extend(human_entries)
+
+         ai_data = []
+         if data.get("story_opus"):
+             ai_data.extend(add_streamed_data(data["story_opus"], AI_LABEL))
+         if data.get("story_gpt_3_5"):
+             ai_data.extend(add_streamed_data(data["story_gpt_3_5"], AI_LABEL))
+
+         dataset_entries.extend(ai_data)
+
+     else:
+         # Add without streaming
+         dataset_entries.append({"text": data["story_human"], "label": HUMAN_LABEL})
+
+         if data.get("story_opus"):
+             dataset_entries.append({"text": data["story_opus"], "label": AI_LABEL})
+         if data.get("story_gpt_3_5"):
+             dataset_entries.append({"text": data["story_gpt_3_5"], "label": AI_LABEL})
+
+ # Dump the remaining entries as JSONL
+ dumped_entries = []
+ for entry in dataset_entries:
+     dumped_entries.append(json.dumps(entry))
+
+ dumped_string = "\n".join(dumped_entries) + "\n"
+
+ with open(SAVE_FILE_NAME, "a") as f:
+     f.write(dumped_string)
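
Not part of the commit, but a quick sanity check of the converted file can be useful before training: a minimal sketch, assuming only the output file name used above, that counts how many rows carry each label.

from collections import Counter
import json

label_counts = Counter()
with open("bert_reddit_vs_synth_writing_prompts.jsonl", "r") as f:
    for line in f:
        label_counts[json.loads(line)["label"]] += 1

print(label_counts)  # expected keys: 0 (human) and 1 (AI-generated)
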
fine-tune-distil-bert.ipynb ADDED
@@ -0,0 +1,199 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "!pip install torch transformers scikit-learn wandb accelerate tqdm\n",
+     "from IPython.display import clear_output\n",
+     "clear_output(wait=True)\n",
+     "print(\".\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "!apt-get update\n",
+     "!apt-get install zstd\n",
+     "!tar --use-compress-program=unzstd -xvf bert_streamed_dataset.tar.zst\n",
+     "clear_output(wait=True)\n",
+     "print(\".\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments\n",
+     "from sklearn.model_selection import train_test_split\n",
+     "from tqdm import tqdm\n",
+     "import wandb\n",
+     "import json\n",
+     "\n",
+     "# Initialize W&B\n",
+     "wandb.init(project=\"distilbert-ai-text-classification\")\n",
+     "\n",
+     "# Use CUDA if it is available, otherwise fall back to the CPU\n",
+     "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
+     "print(device)\n",
+     "\n",
+     "# Load pre-trained DistilBERT tokenizer and model\n",
+     "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
+     "model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)\n",
+     "model.to(device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Load the JSONL dataset\n",
+     "data = []\n",
+     "total_num_of_lines = 0\n",
+     "with open('bert_reddit_vs_synth_writing_prompts.jsonl', 'r') as infile:\n",
+     "    for line in tqdm(infile, desc=\"Checking dataset size\"):\n",
+     "        total_num_of_lines += 1\n",
+     "\n",
+     "with open('bert_reddit_vs_synth_writing_prompts.jsonl', 'r') as infile:\n",
+     "    for line in tqdm(infile, desc=\"Loading dataset\", total=total_num_of_lines):\n",
+     "        data.append(json.loads(line))\n",
+     "\n",
+     "# Extract texts and labels\n",
+     "print(\"Extracting texts and labels\")\n",
+     "texts = [entry['text'] for entry in data]\n",
+     "labels = [entry['label'] for entry in data]\n",
+     "\n",
+     "# Tokenize the text\n",
+     "print(\"Tokenizing text\")\n",
+     "inputs = tokenizer(texts, padding=True, truncation=True, return_tensors=\"pt\")\n",
+     "\n",
+     "# The Trainer moves each batch to the device during training, so the full tensors stay on the CPU here\n",
+     "print(\"Preparing input tensors\")\n",
+     "inputs = {key: val for key, val in inputs.items()}\n",
+     "\n",
+     "# Split the data into training and validation sets\n",
+     "print(\"Splitting data into train and validation\")\n",
+     "train_inputs, val_inputs, train_labels, val_labels = train_test_split(\n",
+     "    inputs['input_ids'], labels, test_size=0.2, random_state=42)\n",
+     "\n",
+     "train_attention_masks, val_attention_masks, _, _ = train_test_split(\n",
+     "    inputs['attention_mask'], labels, test_size=0.2, random_state=42)\n",
+     "\n",
+     "# Create a PyTorch dataset\n",
+     "class TextDataset(torch.utils.data.Dataset):\n",
+     "    def __init__(self, input_ids, attention_masks, labels):\n",
+     "        self.input_ids = input_ids\n",
+     "        self.attention_masks = attention_masks\n",
+     "        self.labels = labels\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.labels)\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        return {\n",
+     "            'input_ids': self.input_ids[idx],\n",
+     "            'attention_mask': self.attention_masks[idx],\n",
+     "            'labels': torch.tensor(self.labels[idx])\n",
+     "        }\n",
+     "\n",
+     "print(\"Creating pytorch datasets\")\n",
+     "train_dataset = TextDataset(train_inputs, train_attention_masks, train_labels)\n",
+     "val_dataset = TextDataset(val_inputs, val_attention_masks, val_labels)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Reduce the eval set to NUM_OF_EVAL_EXAMPLES examples to speed up evaluation during training\n",
+     "NUM_OF_EVAL_EXAMPLES = 1000\n",
+     "val_inputs_subset = val_inputs[:NUM_OF_EVAL_EXAMPLES]\n",
+     "val_attention_masks_subset = val_attention_masks[:NUM_OF_EVAL_EXAMPLES]\n",
+     "val_labels_subset = val_labels[:NUM_OF_EVAL_EXAMPLES]\n",
+     "\n",
+     "# Create a TextDataset with only that many examples\n",
+     "val_dataset = TextDataset(val_inputs_subset, val_attention_masks_subset, val_labels_subset)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Define the training arguments\n",
+     "training_args = TrainingArguments(\n",
+     "    output_dir='./distil-bert-train-results',\n",
+     "    num_train_epochs=3,\n",
+     "    per_device_train_batch_size=16,\n",
+     "    per_device_eval_batch_size=16,\n",
+     "    warmup_steps=500,  # number of warmup steps for the learning rate scheduler\n",
+     "    weight_decay=0.01,\n",
+     "    logging_dir='./logs',\n",
+     "    logging_steps=10,\n",
+     "    report_to=\"wandb\",\n",
+     "    evaluation_strategy=\"steps\",  # Evaluate on a fixed step interval\n",
+     "    eval_steps=100,  # Evaluate every 100 steps\n",
+     "    fp16=True,\n",
+     ")\n",
+     "\n",
+     "# Create the Trainer\n",
+     "trainer = Trainer(\n",
+     "    model=model,                  # the instantiated 🤗 Transformers model to be trained\n",
+     "    args=training_args,           # training arguments, defined above\n",
+     "    train_dataset=train_dataset,  # training dataset\n",
+     "    eval_dataset=val_dataset      # evaluation dataset\n",
+     ")\n",
+     "\n",
+     "# Train the model\n",
+     "trainer.train()\n",
+     "\n",
+     "# Save the model\n",
+     "model.save_pretrained('./distil-bert-train-final-result')\n",
+     "\n",
+     "# Finish the W&B run\n",
+     "wandb.finish()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 4
+ }
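
For completeness, a minimal inference sketch (not part of the commit) showing how the fine-tuned classifier could score a partially streamed prefix, the situation the conversion script's prefix expansion anticipates; it assumes the save path './distil-bert-train-final-result' and the 0 = human / 1 = AI label convention used above.

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("./distil-bert-train-final-result")
model.eval()

text_so_far = "The dog walked over"  # a prefix of a text that is still streaming
inputs = tokenizer(text_so_far, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
ai_probability = torch.softmax(logits, dim=-1)[0, 1].item()  # index 1 = AI-generated label
print(ai_probability)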