AmelieSchreiber commited on
Commit
63a538f
1 Parent(s): 2248c60

Upload data_processing_v1.ipynb

Browse files
Files changed (1) hide show
  1. data_processing_v1.ipynb +692 -0
data_processing_v1.ipynb ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "91af3f42-063e-4d5c-ae16-c7c54599d582",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Number of entries with angle brackets: 35\n",
14
+ "Number of remaining rows: 16460737\n",
15
+ "Number of distinct protein families: 10258\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "100%|██████████| 10258/10258 [00:04<00:00, 2232.02family/s]\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Number of distinct protein families in the test set: 2076\n",
30
+ "Number of distinct protein families in the train set: 8182\n",
31
+ "Percentage of families in test set: 0.20237863131214662\n"
32
+ ]
33
+ },
34
+ {
35
+ "data": {
36
+ "text/plain": [
37
+ "(3307395, 13153342)"
38
+ ]
39
+ },
40
+ "execution_count": 2,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "import pandas as pd\n",
47
+ "import numpy as np\n",
48
+ "from tqdm import tqdm\n",
49
+ "\n",
50
+ "# Load the dataset\n",
51
+ "file_path = 'binding_sites_uniprot_16M.tsv'\n",
52
+ "data = pd.read_csv(file_path, sep='\\t')\n",
53
+ "\n",
54
+ "# Display the first few rows of the dataframe\n",
55
+ "#data.head()\n",
56
+ "\n",
57
+ "# Filter out rows with NaN values in the 'Protein families' column\n",
58
+ "data = data[pd.notna(data['Protein families'])]\n",
59
+ "\n",
60
+ "# Display the first few rows of the modified dataframe\n",
61
+ "#data.head()\n",
62
+ "\n",
63
+ "# Group the data by 'Protein families' and get the size of each group\n",
64
+ "family_sizes = data.groupby('Protein families').size()\n",
65
+ "\n",
66
+ "# Create a new column with the size of each family\n",
67
+ "data['Family size'] = data['Protein families'].map(family_sizes)\n",
68
+ "\n",
69
+ "# Sort the data by 'Family size' in descending order and then by 'Protein families'\n",
70
+ "data_sorted = data.sort_values(by=['Family size', 'Protein families'], ascending=[False, True])\n",
71
+ "\n",
72
+ "# Drop the 'Family size' column as it is no longer needed\n",
73
+ "data_sorted.drop(columns='Family size', inplace=True)\n",
74
+ "\n",
75
+ "# Define a function to extract the location from the binding and active site columns\n",
76
+ "def extract_location(site_info):\n",
77
+ " if pd.isnull(site_info):\n",
78
+ " return None\n",
79
+ " locations = []\n",
80
+ " for info in site_info.split(';'):\n",
81
+ " if 'BINDING' in info or 'ACT_SITE' in info:\n",
82
+ " locations.append(info.split()[1])\n",
83
+ " return '; '.join(locations)\n",
84
+ "\n",
85
+ "# Apply the function to the 'Binding site' and 'Active site' columns to extract the locations\n",
86
+ "data_sorted['Binding site'] = data_sorted['Binding site'].apply(extract_location)\n",
87
+ "data_sorted['Active site'] = data_sorted['Active site'].apply(extract_location)\n",
88
+ "\n",
89
+ "# Display the first few rows of the modified dataframe\n",
90
+ "#data_sorted.head()\n",
91
+ "\n",
92
+ "# Create a new column that combines the 'Binding site' and 'Active site' columns\n",
93
+ "data_sorted['Binding-Active site'] = data_sorted['Binding site'].astype(str) + '; ' + data_sorted['Active site'].astype(str)\n",
94
+ "\n",
95
+ "# Replace 'nan' values with None\n",
96
+ "data_sorted['Binding-Active site'] = data_sorted['Binding-Active site'].replace('nan; nan', None)\n",
97
+ "\n",
98
+ "# Display the first few rows of the updated dataframe\n",
99
+ "#data_sorted.head()\n",
100
+ "\n",
101
+ "# Find entries in the \"Binding-Active site\" column containing '<' or '>'\n",
102
+ "entries_with_angle_brackets = data_sorted['Binding-Active site'].str.contains('<|>', na=False)\n",
103
+ "\n",
104
+ "# Get the number of such entries\n",
105
+ "num_entries_with_angle_brackets = entries_with_angle_brackets.sum()\n",
106
+ "\n",
107
+ "# Display the number of entries containing '<' or '>'\n",
108
+ "print(f\"Number of entries with angle brackets: {num_entries_with_angle_brackets}\")\n",
109
+ "\n",
110
+ "# Remove all rows where the \"Binding-Active site\" column contains '<' or '>'\n",
111
+ "data_filtered = data_sorted[~entries_with_angle_brackets]\n",
112
+ "\n",
113
+ "# Get the number of remaining rows\n",
114
+ "num_remaining_rows = data_filtered.shape[0]\n",
115
+ "\n",
116
+ "# Display the number of remaining rows\n",
117
+ "print(f\"Number of remaining rows: {num_remaining_rows}\")\n",
118
+ "\n",
119
+ "# Get the number of distinct protein families\n",
120
+ "num_distinct_families = data_filtered['Protein families'].nunique()\n",
121
+ "\n",
122
+ "# Display the number of distinct protein families\n",
123
+ "print(f\"Number of distinct protein families: {num_distinct_families}\")\n",
124
+ "\n",
125
+ "# Define the target number of rows for the test set (approximately 20% of the data)\n",
126
+ "target_test_rows = int(0.20 * num_remaining_rows)\n",
127
+ "\n",
128
+ "# Get unique protein families\n",
129
+ "unique_families = data_filtered['Protein families'].unique()\n",
130
+ "\n",
131
+ "# Shuffle the unique families to randomize the selection\n",
132
+ "np.random.shuffle(unique_families)\n",
133
+ "\n",
134
+ "# Group the data by 'Protein families' to facilitate faster family-wise selection\n",
135
+ "grouped_data = data_filtered.groupby('Protein families')\n",
136
+ "\n",
137
+ "# Initialize variables to keep track of the selected rows for the test and train sets\n",
138
+ "test_rows = []\n",
139
+ "current_test_rows = 0\n",
140
+ "\n",
141
+ "# Initialize a flag to indicate whether the threshold has been crossed\n",
142
+ "threshold_crossed = False\n",
143
+ "\n",
144
+ "# Initialize a variable to keep track of the previous family\n",
145
+ "previous_family = None\n",
146
+ "\n",
147
+ "# Loop through the shuffled families and add rows to the test set until we reach the target number of rows\n",
148
+ "for family in tqdm(unique_families, unit=\"family\"):\n",
149
+ " family_rows = grouped_data.get_group(family).index.tolist()\n",
150
+ " # If the threshold is not yet crossed, or the family is the same as the previous family, add the family to the test set\n",
151
+ " if not threshold_crossed or (previous_family == family):\n",
152
+ " test_rows.extend(family_rows)\n",
153
+ " current_test_rows += len(family_rows)\n",
154
+ " previous_family = family # Keep track of the previous family\n",
155
+ " # Check if the threshold is crossed with the addition of the current family\n",
156
+ " if current_test_rows >= target_test_rows:\n",
157
+ " threshold_crossed = True # Set the flag to True once the threshold is crossed\n",
158
+ "\n",
159
+ "# Get the indices of the rows for the train set (all rows not in the test set) using set operations for efficiency\n",
160
+ "train_rows = set(data_filtered.index) - set(test_rows)\n",
161
+ "\n",
162
+ "# Create the test and train datasets using loc indexer with list of indices\n",
163
+ "test_df = data_filtered.loc[list(test_rows)]\n",
164
+ "train_df = data_filtered.loc[list(train_rows)]\n",
165
+ "\n",
166
+ "# Print the number of distinct protein families in the test and train sets\n",
167
+ "num_test_families = test_df['Protein families'].nunique()\n",
168
+ "num_train_families = train_df['Protein families'].nunique()\n",
169
+ "print(f\"Number of distinct protein families in the test set: {num_test_families}\")\n",
170
+ "print(f\"Number of distinct protein families in the train set: {num_train_families}\")\n",
171
+ "percentage = num_test_families/(num_test_families+num_train_families)\n",
172
+ "print(f\"Percentage of families in test set: {percentage}\")\n",
173
+ "\n",
174
+ "test_df.shape[0], train_df.shape[0]\n"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": 3,
180
+ "id": "772edd92-5137-486a-8a81-8ab7bf51568f",
181
+ "metadata": {},
182
+ "outputs": [
183
+ {
184
+ "name": "stdout",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "Number of common families: 0\n",
188
+ "No common families between test and train datasets.\n"
189
+ ]
190
+ }
191
+ ],
192
+ "source": [
193
+ "# Get the unique families in the test and train datasets\n",
194
+ "unique_test_families = set(test_df['Protein families'].unique())\n",
195
+ "unique_train_families = set(train_df['Protein families'].unique())\n",
196
+ "\n",
197
+ "# Find the common families between the test and train datasets\n",
198
+ "common_families = unique_test_families.intersection(unique_train_families)\n",
199
+ "\n",
200
+ "# Output the common families and their count\n",
201
+ "print(f\"Number of common families: {len(common_families)}\")\n",
202
+ "if len(common_families) > 0:\n",
203
+ " print(f\"Common families: {common_families}\")\n",
204
+ "else:\n",
205
+ " print(\"No common families between test and train datasets.\")\n"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": 4,
211
+ "id": "bc8825d6-60f8-4029-a4ab-a2317b170d09",
212
+ "metadata": {},
213
+ "outputs": [
214
+ {
215
+ "name": "stdout",
216
+ "output_type": "stream",
217
+ "text": [
218
+ "Number of test rows with question mark: 0\n",
219
+ "Number of train rows with question mark: 2\n",
220
+ "Number of remaining test rows: 3307395\n",
221
+ "Number of remaining train rows: 13153340\n"
222
+ ]
223
+ }
224
+ ],
225
+ "source": [
226
+ "import re\n",
227
+ "\n",
228
+ "# Find rows where the \"Binding-Active site\" column contains the character \"?\", treating \"?\" as a literal character\n",
229
+ "test_rows_with_question_mark = test_df[test_df['Binding-Active site'].str.contains('\\?', na=False, regex=True)]\n",
230
+ "train_rows_with_question_mark = train_df[train_df['Binding-Active site'].str.contains('\\?', na=False, regex=True)]\n",
231
+ "\n",
232
+ "# Get the number of such rows in both datasets\n",
233
+ "num_test_rows_with_question_mark = len(test_rows_with_question_mark)\n",
234
+ "num_train_rows_with_question_mark = len(train_rows_with_question_mark)\n",
235
+ "\n",
236
+ "print(f\"Number of test rows with question mark: {num_test_rows_with_question_mark}\")\n",
237
+ "print(f\"Number of train rows with question mark: {num_train_rows_with_question_mark}\")\n",
238
+ "\n",
239
+ "# Delete the rows containing '?' in the \"Binding-Active site\" column\n",
240
+ "test_df = test_df.drop(test_rows_with_question_mark.index)\n",
241
+ "train_df = train_df.drop(train_rows_with_question_mark.index)\n",
242
+ "\n",
243
+ "# Check the number of remaining rows in both datasets\n",
244
+ "remaining_test_rows = test_df.shape[0]\n",
245
+ "remaining_train_rows = train_df.shape[0]\n",
246
+ "\n",
247
+ "print(f\"Number of remaining test rows: {remaining_test_rows}\")\n",
248
+ "print(f\"Number of remaining train rows: {remaining_train_rows}\")\n",
249
+ "\n",
250
+ "def expand_ranges(s):\n",
251
+ " \"\"\"Expand ranges in a string.\"\"\"\n",
252
+ " return re.sub(r'(\\d+)\\.\\.(\\d+)', lambda m: ', '.join(map(str, range(int(m.group(1)), int(m.group(2))+1))), str(s))\n",
253
+ "\n",
254
+ "# Apply the function to expand ranges in the \"Binding-Active site\" column in both datasets\n",
255
+ "test_df['Binding-Active site'] = test_df['Binding-Active site'].apply(expand_ranges)\n",
256
+ "train_df['Binding-Active site'] = train_df['Binding-Active site'].apply(expand_ranges)\n"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 5,
262
+ "id": "d91da865-495a-4c1d-91ed-0ebeff1ecd50",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "def convert_to_binary_list(binding_active_str, sequence_len):\n",
267
+ " \"\"\"Convert a Binding-Active site string to a binary list based on the sequence length.\"\"\"\n",
268
+ " # Step 2: Create a list of 0s with length equal to the sequence length\n",
269
+ " binary_list = [0] * sequence_len\n",
270
+ " \n",
271
+ " # Step 3: Retrieve the indices and set the corresponding positions to 1\n",
272
+ " if pd.notna(binding_active_str):\n",
273
+ " # Get the indices from the binding-active site string\n",
274
+ " indices = [int(x) - 1 for segment in binding_active_str.split(';') for x in segment.split(',') if x.strip().isdigit()]\n",
275
+ " for idx in indices:\n",
276
+ " # Ensure the index is within the valid range\n",
277
+ " if 0 <= idx < sequence_len:\n",
278
+ " binary_list[idx] = 1\n",
279
+ " \n",
280
+ " # Step 4: Return the binary list\n",
281
+ " return binary_list\n",
282
+ "\n",
283
+ "# Apply the function to both datasets\n",
284
+ "test_df['Binding-Active site'] = test_df.apply(lambda row: convert_to_binary_list(row['Binding-Active site'], len(row['Sequence'])), axis=1)\n",
285
+ "train_df['Binding-Active site'] = train_df.apply(lambda row: convert_to_binary_list(row['Binding-Active site'], len(row['Sequence'])), axis=1)\n"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 6,
291
+ "id": "4cea2656-75eb-4350-b1a6-704c18793473",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>Entry</th>\n",
316
+ " <th>Protein families</th>\n",
317
+ " <th>Binding site</th>\n",
318
+ " <th>Active site</th>\n",
319
+ " <th>Sequence</th>\n",
320
+ " <th>Binding-Active site</th>\n",
321
+ " </tr>\n",
322
+ " </thead>\n",
323
+ " <tbody>\n",
324
+ " <tr>\n",
325
+ " <th>791321</th>\n",
326
+ " <td>A0A0C2CBT0</td>\n",
327
+ " <td>TDD superfamily, TSR3 family; Protein kinase s...</td>\n",
328
+ " <td>275; 323; 346</td>\n",
329
+ " <td>None</td>\n",
330
+ " <td>MFDVFSGHNDAVLCVQYRDQESLAVSGSADNSIKCWDTRTGRPEMT...</td>\n",
331
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <th>1008964</th>\n",
335
+ " <td>A0A0N4V212</td>\n",
336
+ " <td>TDD superfamily, TSR3 family; Protein kinase s...</td>\n",
337
+ " <td>131; 179; 202</td>\n",
338
+ " <td>None</td>\n",
339
+ " <td>MVGYGVRARASGYHGRSKFRVKNKRKADKSYAENVSELAADSSRAI...</td>\n",
340
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <th>1009019</th>\n",
344
+ " <td>A0A0N4XGU1</td>\n",
345
+ " <td>TDD superfamily, TSR3 family; Protein kinase s...</td>\n",
346
+ " <td>73; 121; 178</td>\n",
347
+ " <td>None</td>\n",
348
+ " <td>MGKKGREQHGNKRTNKSRHADAGDAEPLSSHGEEDSESLDESRDDH...</td>\n",
349
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
350
+ " </tr>\n",
351
+ " <tr>\n",
352
+ " <th>1837901</th>\n",
353
+ " <td>A0A1I8B1G5</td>\n",
354
+ " <td>TDD superfamily, TSR3 family; Protein kinase s...</td>\n",
355
+ " <td>40; 88; 111</td>\n",
356
+ " <td>None</td>\n",
357
+ " <td>MASTDSSQSSDEDAKVEKAKKMPCILAMFDFGQCDPKRCSGRKLCR...</td>\n",
358
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
359
+ " </tr>\n",
360
+ " <tr>\n",
361
+ " <th>5447097</th>\n",
362
+ " <td>A0A6V7USP8</td>\n",
363
+ " <td>TDD superfamily, TSR3 family; Protein kinase s...</td>\n",
364
+ " <td>61; 109; 132</td>\n",
365
+ " <td>None</td>\n",
366
+ " <td>MLFMVVPVLIMMQVDVVAIKKMTNTDSSESSGDDAVDDKSKKMPCI...</td>\n",
367
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
368
+ " </tr>\n",
369
+ " </tbody>\n",
370
+ "</table>\n",
371
+ "</div>"
372
+ ],
373
+ "text/plain": [
374
+ " Entry Protein families \\\n",
375
+ "791321 A0A0C2CBT0 TDD superfamily, TSR3 family; Protein kinase s... \n",
376
+ "1008964 A0A0N4V212 TDD superfamily, TSR3 family; Protein kinase s... \n",
377
+ "1009019 A0A0N4XGU1 TDD superfamily, TSR3 family; Protein kinase s... \n",
378
+ "1837901 A0A1I8B1G5 TDD superfamily, TSR3 family; Protein kinase s... \n",
379
+ "5447097 A0A6V7USP8 TDD superfamily, TSR3 family; Protein kinase s... \n",
380
+ "\n",
381
+ " Binding site Active site \\\n",
382
+ "791321 275; 323; 346 None \n",
383
+ "1008964 131; 179; 202 None \n",
384
+ "1009019 73; 121; 178 None \n",
385
+ "1837901 40; 88; 111 None \n",
386
+ "5447097 61; 109; 132 None \n",
387
+ "\n",
388
+ " Sequence \\\n",
389
+ "791321 MFDVFSGHNDAVLCVQYRDQESLAVSGSADNSIKCWDTRTGRPEMT... \n",
390
+ "1008964 MVGYGVRARASGYHGRSKFRVKNKRKADKSYAENVSELAADSSRAI... \n",
391
+ "1009019 MGKKGREQHGNKRTNKSRHADAGDAEPLSSHGEEDSESLDESRDDH... \n",
392
+ "1837901 MASTDSSQSSDEDAKVEKAKKMPCILAMFDFGQCDPKRCSGRKLCR... \n",
393
+ "5447097 MLFMVVPVLIMMQVDVVAIKKMTNTDSSESSGDDAVDDKSKKMPCI... \n",
394
+ "\n",
395
+ " Binding-Active site \n",
396
+ "791321 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
397
+ "1008964 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
398
+ "1009019 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
399
+ "1837901 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
400
+ "5447097 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... "
401
+ ]
402
+ },
403
+ "execution_count": 6,
404
+ "metadata": {},
405
+ "output_type": "execute_result"
406
+ }
407
+ ],
408
+ "source": [
409
+ "test_df.head()"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": 7,
415
+ "id": "bf55ec46-3685-41a3-a382-46273940ed79",
416
+ "metadata": {},
417
+ "outputs": [
418
+ {
419
+ "data": {
420
+ "text/html": [
421
+ "<div>\n",
422
+ "<style scoped>\n",
423
+ " .dataframe tbody tr th:only-of-type {\n",
424
+ " vertical-align: middle;\n",
425
+ " }\n",
426
+ "\n",
427
+ " .dataframe tbody tr th {\n",
428
+ " vertical-align: top;\n",
429
+ " }\n",
430
+ "\n",
431
+ " .dataframe thead th {\n",
432
+ " text-align: right;\n",
433
+ " }\n",
434
+ "</style>\n",
435
+ "<table border=\"1\" class=\"dataframe\">\n",
436
+ " <thead>\n",
437
+ " <tr style=\"text-align: right;\">\n",
438
+ " <th></th>\n",
439
+ " <th>Entry</th>\n",
440
+ " <th>Protein families</th>\n",
441
+ " <th>Binding site</th>\n",
442
+ " <th>Active site</th>\n",
443
+ " <th>Sequence</th>\n",
444
+ " <th>Binding-Active site</th>\n",
445
+ " </tr>\n",
446
+ " </thead>\n",
447
+ " <tbody>\n",
448
+ " <tr>\n",
449
+ " <th>1</th>\n",
450
+ " <td>A0A009GI32</td>\n",
451
+ " <td>3-hydroxyacyl-CoA dehydrogenase family; Enoyl-...</td>\n",
452
+ " <td>298; 326; 345; 402..404; 409; 431; 455; 502</td>\n",
453
+ " <td>452</td>\n",
454
+ " <td>MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA...</td>\n",
455
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
456
+ " </tr>\n",
457
+ " <tr>\n",
458
+ " <th>3</th>\n",
459
+ " <td>A0A009HWM5</td>\n",
460
+ " <td>3-hydroxyacyl-CoA dehydrogenase family; Enoyl-...</td>\n",
461
+ " <td>298; 326; 345; 402..404; 409; 431; 455; 502</td>\n",
462
+ " <td>452</td>\n",
463
+ " <td>MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA...</td>\n",
464
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <th>4</th>\n",
468
+ " <td>A0A009I6Q1</td>\n",
469
+ " <td>3-hydroxyacyl-CoA dehydrogenase family; Enoyl-...</td>\n",
470
+ " <td>298; 326; 345; 402..404; 409; 431; 455; 502</td>\n",
471
+ " <td>452</td>\n",
472
+ " <td>MIHAGNAITVQMLSDGIAEFRFDLQGESVNKFNRATIEDFQAAIAA...</td>\n",
473
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
474
+ " </tr>\n",
475
+ " <tr>\n",
476
+ " <th>7</th>\n",
477
+ " <td>A0A009NCR4</td>\n",
478
+ " <td>3-hydroxyacyl-CoA dehydrogenase family; Enoyl-...</td>\n",
479
+ " <td>298; 326; 345; 402..404; 409; 431; 455; 502</td>\n",
480
+ " <td>452</td>\n",
481
+ " <td>MIHAGNAITVQMLSDGIAEFRFDLQGESVNKFNRATIEDFQAAIAA...</td>\n",
482
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
483
+ " </tr>\n",
484
+ " <tr>\n",
485
+ " <th>9</th>\n",
486
+ " <td>A0A009QK39</td>\n",
487
+ " <td>3-hydroxyacyl-CoA dehydrogenase family; Enoyl-...</td>\n",
488
+ " <td>298; 326; 345; 402..404; 409; 431; 455; 502</td>\n",
489
+ " <td>452</td>\n",
490
+ " <td>MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA...</td>\n",
491
+ " <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
492
+ " </tr>\n",
493
+ " </tbody>\n",
494
+ "</table>\n",
495
+ "</div>"
496
+ ],
497
+ "text/plain": [
498
+ " Entry Protein families \\\n",
499
+ "1 A0A009GI32 3-hydroxyacyl-CoA dehydrogenase family; Enoyl-... \n",
500
+ "3 A0A009HWM5 3-hydroxyacyl-CoA dehydrogenase family; Enoyl-... \n",
501
+ "4 A0A009I6Q1 3-hydroxyacyl-CoA dehydrogenase family; Enoyl-... \n",
502
+ "7 A0A009NCR4 3-hydroxyacyl-CoA dehydrogenase family; Enoyl-... \n",
503
+ "9 A0A009QK39 3-hydroxyacyl-CoA dehydrogenase family; Enoyl-... \n",
504
+ "\n",
505
+ " Binding site Active site \\\n",
506
+ "1 298; 326; 345; 402..404; 409; 431; 455; 502 452 \n",
507
+ "3 298; 326; 345; 402..404; 409; 431; 455; 502 452 \n",
508
+ "4 298; 326; 345; 402..404; 409; 431; 455; 502 452 \n",
509
+ "7 298; 326; 345; 402..404; 409; 431; 455; 502 452 \n",
510
+ "9 298; 326; 345; 402..404; 409; 431; 455; 502 452 \n",
511
+ "\n",
512
+ " Sequence \\\n",
513
+ "1 MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA... \n",
514
+ "3 MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA... \n",
515
+ "4 MIHAGNAITVQMLSDGIAEFRFDLQGESVNKFNRATIEDFQAAIAA... \n",
516
+ "7 MIHAGNAITVQMLSDGIAEFRFDLQGESVNKFNRATIEDFQAAIAA... \n",
517
+ "9 MIHAGNAITVQMLADGIAEFRFDLQGESVNKFNRATIEDFKAAIAA... \n",
518
+ "\n",
519
+ " Binding-Active site \n",
520
+ "1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
521
+ "3 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
522
+ "4 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
523
+ "7 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n",
524
+ "9 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... "
525
+ ]
526
+ },
527
+ "execution_count": 7,
528
+ "metadata": {},
529
+ "output_type": "execute_result"
530
+ }
531
+ ],
532
+ "source": [
533
+ "train_df.head()"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": 8,
539
+ "id": "1a997e94-2bea-4c56-89f2-f10737c96447",
540
+ "metadata": {},
541
+ "outputs": [
542
+ {
543
+ "data": {
544
+ "text/plain": [
545
+ "('065_data/test_labels_chunked_by_family.pkl',\n",
546
+ " '065_data/test_sequences_chunked_by_family.pkl',\n",
547
+ " '065_data/train_labels_chunked_by_family.pkl',\n",
548
+ " '065_data/train_sequences_chunked_by_family.pkl')"
549
+ ]
550
+ },
551
+ "execution_count": 8,
552
+ "metadata": {},
553
+ "output_type": "execute_result"
554
+ }
555
+ ],
556
+ "source": [
557
+ "import pickle\n",
558
+ "import random\n",
559
+ "\n",
560
+ "def split_into_chunks(sequences, labels):\n",
561
+ " \"\"\"Split sequences and labels into chunks of size 1000 or less.\"\"\"\n",
562
+ " chunk_size = 1000\n",
563
+ " new_sequences = []\n",
564
+ " new_labels = []\n",
565
+ " \n",
566
+ " for seq, lbl in zip(sequences, labels):\n",
567
+ " if len(seq) > chunk_size:\n",
568
+ " # Split the sequence and labels into chunks of size 1000 or less\n",
569
+ " for i in range(0, len(seq), chunk_size):\n",
570
+ " new_sequences.append(seq[i:i+chunk_size])\n",
571
+ " new_labels.append(lbl[i:i+chunk_size])\n",
572
+ " else:\n",
573
+ " new_sequences.append(seq)\n",
574
+ " new_labels.append(lbl)\n",
575
+ " \n",
576
+ " return new_sequences, new_labels\n",
577
+ "\n",
578
+ "# Extract the necessary columns to create lists of sequences and labels\n",
579
+ "test_sequences_by_family = test_df['Sequence'].tolist()\n",
580
+ "test_labels_by_family = test_df['Binding-Active site'].tolist()\n",
581
+ "train_sequences_by_family = train_df['Sequence'].tolist()\n",
582
+ "train_labels_by_family = train_df['Binding-Active site'].tolist()\n",
583
+ "\n",
584
+ "# Get the number of samples in each dataset\n",
585
+ "num_test_samples = len(test_sequences_by_family)\n",
586
+ "num_train_samples = len(train_sequences_by_family)\n",
587
+ "\n",
588
+ "# Define the percentage of data you want to keep\n",
589
+ "percentage_to_keep = 100 # for keeping 6.00% of the data\n",
590
+ "\n",
591
+ "# Generate random indices representing a percentage of each dataset\n",
592
+ "random_test_indices = random.sample(range(num_test_samples), int(num_test_samples * (percentage_to_keep / 100)))\n",
593
+ "random_train_indices = random.sample(range(num_train_samples), int(num_train_samples * (percentage_to_keep / 100)))\n",
594
+ "\n",
595
+ "# Create smaller datasets using the random indices\n",
596
+ "test_sequences_small = [test_sequences_by_family[i] for i in random_test_indices]\n",
597
+ "test_labels_small = [test_labels_by_family[i] for i in random_test_indices]\n",
598
+ "train_sequences_small = [train_sequences_by_family[i] for i in random_train_indices]\n",
599
+ "train_labels_small = [train_labels_by_family[i] for i in random_train_indices]\n",
600
+ "\n",
601
+ "# Apply the function to create new datasets with chunks of size 1000 or less\n",
602
+ "test_sequences_chunked, test_labels_chunked = split_into_chunks(test_sequences_small, test_labels_small)\n",
603
+ "train_sequences_chunked, train_labels_chunked = split_into_chunks(train_sequences_small, train_labels_small)\n",
604
+ "\n",
605
+ "# Paths to save the new chunked pickle files\n",
606
+ "test_labels_chunked_path = '16M_data/test_labels_chunked_by_family.pkl'\n",
607
+ "test_sequences_chunked_path = '16M_data/test_sequences_chunked_by_family.pkl'\n",
608
+ "train_labels_chunked_path = '16M_data/train_labels_chunked_by_family.pkl'\n",
609
+ "train_sequences_chunked_path = '16M_data/train_sequences_chunked_by_family.pkl'\n",
610
+ "\n",
611
+ "# Save the chunked datasets as new pickle files\n",
612
+ "with open(test_labels_chunked_path, 'wb') as file:\n",
613
+ " pickle.dump(test_labels_chunked, file)\n",
614
+ "with open(test_sequences_chunked_path, 'wb') as file:\n",
615
+ " pickle.dump(test_sequences_chunked, file)\n",
616
+ "with open(train_labels_chunked_path, 'wb') as file:\n",
617
+ " pickle.dump(train_labels_chunked, file)\n",
618
+ "with open(train_sequences_chunked_path, 'wb') as file:\n",
619
+ " pickle.dump(train_sequences_chunked, file)\n",
620
+ "\n",
621
+ "test_labels_chunked_path, test_sequences_chunked_path, train_labels_chunked_path, train_sequences_chunked_path\n"
622
+ ]
623
+ },
624
+ {
625
+ "cell_type": "code",
626
+ "execution_count": 9,
627
+ "id": "6479ec75-c1a2-403c-8139-43e9754cc137",
628
+ "metadata": {},
629
+ "outputs": [
630
+ {
631
+ "data": {
632
+ "text/plain": [
633
+ "(220620, 220620, 890637, 890637)"
634
+ ]
635
+ },
636
+ "execution_count": 9,
637
+ "metadata": {},
638
+ "output_type": "execute_result"
639
+ }
640
+ ],
641
+ "source": [
642
+ "# Load each pickle file and get the number of entries in each\n",
643
+ "with open(test_labels_chunked_path, 'rb') as file:\n",
644
+ " test_labels_chunked = pickle.load(file)\n",
645
+ " num_test_labels_chunked = len(test_labels_chunked)\n",
646
+ "\n",
647
+ "with open(test_sequences_chunked_path, 'rb') as file:\n",
648
+ " test_sequences_chunked = pickle.load(file)\n",
649
+ " num_test_sequences_chunked = len(test_sequences_chunked)\n",
650
+ "\n",
651
+ "with open(train_labels_chunked_path, 'rb') as file:\n",
652
+ " train_labels_chunked = pickle.load(file)\n",
653
+ " num_train_labels_chunked = len(train_labels_chunked)\n",
654
+ "\n",
655
+ "with open(train_sequences_chunked_path, 'rb') as file:\n",
656
+ " train_sequences_chunked = pickle.load(file)\n",
657
+ " num_train_sequences_chunked = len(train_sequences_chunked)\n",
658
+ "\n",
659
+ "num_test_labels_chunked, num_test_sequences_chunked, num_train_labels_chunked, num_train_sequences_chunked\n"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "execution_count": null,
665
+ "id": "da7df429-62ab-4b8e-b3dd-7c5a9eb14921",
666
+ "metadata": {},
667
+ "outputs": [],
668
+ "source": []
669
+ }
670
+ ],
671
+ "metadata": {
672
+ "kernelspec": {
673
+ "display_name": "esm2_binding_py38b",
674
+ "language": "python",
675
+ "name": "esm2_binding_py38b"
676
+ },
677
+ "language_info": {
678
+ "codemirror_mode": {
679
+ "name": "ipython",
680
+ "version": 3
681
+ },
682
+ "file_extension": ".py",
683
+ "mimetype": "text/x-python",
684
+ "name": "python",
685
+ "nbconvert_exporter": "python",
686
+ "pygments_lexer": "ipython3",
687
+ "version": "3.8.17"
688
+ }
689
+ },
690
+ "nbformat": 4,
691
+ "nbformat_minor": 5
692
+ }