sgoel30 commited on
Commit
faaee28
·
verified ·
1 Parent(s): b6a71c9

Delete benchmarks/DeepLoc

Browse files
benchmarks/DeepLoc/OG_membrane_type_all.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2d878da32a06092f880262048e3c1eb692721c274b0a458fcc712a0dcbd80c71
3
- size 15683507
 
 
 
 
benchmarks/DeepLoc/cell_localization_predictor.py DELETED
@@ -1,137 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- from torch.utils.data import DataLoader, Dataset
5
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
6
-
7
- from tqdm import tqdm
8
- from datetime import datetime
9
- import pandas as pd
10
- import numpy as np
11
- import pickle
12
- import os
13
-
14
- # Hyperparameters dictionary
15
- path = "/home/a03-sgoel/MDpLM"
16
-
17
- hyperparams = {
18
- "batch_size": 1,
19
- "learning_rate": 4e-5,
20
- "num_epochs": 5,
21
- "max_length": 2000,
22
- "train_data": path + "/benchmarks/DeepLoc/cell_localization_train_val.csv.csv",
23
- "test_data" : path + "/benchmarks/DeepLoc/cell_localization_test.csv",
24
- "val_data": "", # None
25
- "embeddings_pkl": "", # Need to generate ESM embeddings and save as pkl file
26
- }
27
-
28
- # Dataset class can load pickle file
29
- class LocalizationDataset(Dataset):
30
- def __init__(self, csv_file, embeddings_pkl, max_length=2000):
31
- self.data = pd.read_csv(csv_file)
32
- self.max_length = max_length
33
-
34
- # Map sequences to embeddings
35
- with open(embeddings_pkl, 'rb') as f:
36
- self.embeddings_dict = pickle.load(f)
37
- self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)
38
-
39
- # Ensure sequences and embeddings are of the same length
40
- assert len(self.data) == len(self.data['embedding']), "CSV data and embeddings length mismatch"
41
-
42
- # Create multi-class label list
43
- self.data['label'] = self.data.iloc[:, 1:9].value.tolist()
44
-
45
- def __len__(self):
46
- return len(self.data)
47
-
48
- def __getitem__(self, idx):
49
- embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
50
- labels = torch.tensor(self.data['label'][idx], dtype=torch.long)
51
-
52
- return embeddings, labels
53
-
54
- # Multi-class localization predictor
55
- class LocalizationPredictor(nn.Module):
56
- def __init__(self, input_dim, num_classes):
57
- super(LocalizationPredictor, self).__init__()
58
- self.classifier = nn.Linear(input_dim, num_classes) # 1280 x 8
59
-
60
- def forward(self, embeddings):
61
- avg_embedding = torch.mean(embeddings, dim=0) # Average embedding dimension: 1280
62
- logits = self.classifier(avg_embedding)
63
- return logits # pass logits of dimension 1x8 (8-class distribution) to CE loss
64
-
65
- # Training function
66
- def train(model, dataloader, optimizer, criterion, device):
67
- model.train()
68
- total_loss = 0
69
- for embeddings, labels in tqdm(dataloader):
70
- embeddings, labels = embeddings.to(device), labels.to(device)
71
- optimizer.zero_grad()
72
- outputs = model(embeddings)
73
- loss = criterion(outputs, labels)
74
- loss.backward()
75
- optimizer.step()
76
- total_loss += loss.item()
77
- return total_loss / len(dataloader)
78
-
79
- # Evaluation function
80
- def evaluate(model, dataloader, device):
81
- model.eval()
82
- preds, true_labels = [], []
83
- with torch.no_grad():
84
- for embeddings, labels in tqdm(dataloader):
85
- embeddings, labels = embeddings.to(device), labels.to(device)
86
- outputs = model(embeddings)
87
- preds.append(outputs.cpu().numpy())
88
- true_labels.append(labels.cpu().numpy())
89
- return preds, true_labels
90
-
91
- # Metrics calculation
92
- def calculate_metrics(preds, labels, threshold=0.5):
93
- flat_binary_preds, flat_labels = [], []
94
-
95
- for pred, label in zip(preds, labels):
96
- flat_binary_preds.extend((pred > threshold).astype(int).flatten())
97
- flat_labels.extend(label.flatten())
98
-
99
- flat_binary_preds = np.array(flat_binary_preds)
100
- flat_labels = np.array(flat_labels)
101
-
102
- accuracy = accuracy_score(flat_labels, flat_binary_preds)
103
- precision = precision_score(flat_labels, flat_binary_preds, average='macro')
104
- recall = recall_score(flat_labels, flat_binary_preds, average='macro')
105
- f1 = f1_score(flat_labels, flat_binary_preds, average='macro')
106
-
107
- return accuracy, precision, recall, f1
108
-
109
-
110
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
-
112
- train_dataset = LocalizationDataset(hyperparams["train_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
113
- test_dataset = LocalizationDataset(hyperparams["test_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
114
-
115
- train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
116
- test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)
117
-
118
- model = LocalizationPredictor(input_dim=1280, num_classes=8).to(device)
119
- optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
120
- criterion = nn.CrossEntropyLoss()
121
-
122
- # Train the model
123
- for epoch in range(hyperparams["num_epochs"]):
124
- train_loss = train(model, train_dataloader, optimizer, criterion, device)
125
- print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
126
- print(f"TRAIN LOSS: {train_loss:.4f}")
127
- print("\n")
128
-
129
- # Evaluate model on test dataset
130
- print("Test set")
131
- test_preds, test_labels = evaluate(model, test_dataloader, device)
132
- test_metrics = calculate_metrics(test_preds, test_labels)
133
- print("TEST METRICS:")
134
- print(f"Accuracy: {test_metrics[0]:.4f}")
135
- print(f"Precision: {test_metrics[1]:.4f}")
136
- print(f"Recall: {test_metrics[2]:.4f}")
137
- print(f"F1 Score: {test_metrics[3]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/DeepLoc/cell_localization_test.csv DELETED
The diff for this file is too large to render. See raw diff
 
benchmarks/DeepLoc/cell_localization_train_val.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:29a07b293fed2994a966b70bdcd6bacc59915b8b01fa200cb2b07d8db18384a2
3
- size 17724293
 
 
 
 
benchmarks/DeepLoc/membrane_localization_predictor.py DELETED
@@ -1,137 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- from torch.utils.data import DataLoader, Dataset
5
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
6
-
7
- from tqdm import tqdm
8
- from datetime import datetime
9
- import pandas as pd
10
- import numpy as np
11
- import pickle
12
- import os
13
-
14
- # Hyperparameters dictionary
15
- path = "/home/a03-sgoel/MDpLM"
16
-
17
- hyperparams = {
18
- "batch_size": 1,
19
- "learning_rate": 4e-5,
20
- "num_epochs": 5,
21
- "max_length": 2000,
22
- "train_data": path + "/benchmarks/membrane_type_train.csv",
23
- "test_data" : path + "/benchmarks/membrane_type_test.csv",
24
- "val_data": "", # none
25
- "embeddings_pkl": "" # Need to generate ESM embeddings
26
- }
27
-
28
- # Dataset class can load pickle file
29
- class LocalizationDataset(Dataset):
30
- def __init__(self, csv_file, embeddings_pkl, max_length=2000):
31
- self.data = pd.read_csv(csv_file)
32
- self.max_length = max_length
33
-
34
- # Map sequences to embeddings
35
- with open(embeddings_pkl, 'rb') as f:
36
- self.embeddings_dict = pickle.load(f)
37
- self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)
38
-
39
- # Ensure sequences and embeddings are of the same length
40
- assert len(self.data) == len(self.data['embedding']), "CSV data and embeddings length mismatch"
41
-
42
- # Create multi-class label list
43
- self.data['label'] = self.data.iloc[:, 2:7].value.tolist()
44
-
45
- def __len__(self):
46
- return len(self.data)
47
-
48
- def __getitem__(self, idx):
49
- embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
50
- labels = torch.tensor(self.data['label'][idx], dtype=torch.long)
51
-
52
- return embeddings, labels
53
-
54
- # Multi-class localization predictor
55
- class LocalizationPredictor(nn.Module):
56
- def __init__(self, input_dim, num_classes):
57
- super(LocalizationPredictor, self).__init__()
58
- self.classifier = nn.Linear(input_dim, num_classes) # 1280 x 4
59
-
60
- def forward(self, embeddings):
61
- avg_embedding = torch.mean(embeddings, dim=0) # Average embedding dimension: 1280
62
- logits = self.classifier(avg_embedding)
63
- return logits # pass logits of dimension 1x4 (4-class distribution) to CE loss
64
-
65
- # Training function
66
- def train(model, dataloader, optimizer, criterion, device):
67
- model.train()
68
- total_loss = 0
69
- for embeddings, labels in tqdm(dataloader):
70
- embeddings, labels = embeddings.to(device), labels.to(device)
71
- optimizer.zero_grad()
72
- outputs = model(embeddings)
73
- loss = criterion(outputs, labels)
74
- loss.backward()
75
- optimizer.step()
76
- total_loss += loss.item()
77
- return total_loss / len(dataloader)
78
-
79
- # Evaluation function
80
- def evaluate(model, dataloader, device):
81
- model.eval()
82
- preds, true_labels = [], []
83
- with torch.no_grad():
84
- for embeddings, labels in tqdm(dataloader):
85
- embeddings, labels = embeddings.to(device), labels.to(device)
86
- outputs = model(embeddings)
87
- preds.append(outputs.cpu().numpy())
88
- true_labels.append(labels.cpu().numpy())
89
- return preds, true_labels
90
-
91
- # Metrics calculation
92
- def calculate_metrics(preds, labels, threshold=0.5):
93
- flat_binary_preds, flat_labels = [], []
94
-
95
- for pred, label in zip(preds, labels):
96
- flat_binary_preds.extend((pred > threshold).astype(int).flatten())
97
- flat_labels.extend(label.flatten())
98
-
99
- flat_binary_preds = np.array(flat_binary_preds)
100
- flat_labels = np.array(flat_labels)
101
-
102
- accuracy = accuracy_score(flat_labels, flat_binary_preds)
103
- precision = precision_score(flat_labels, flat_binary_preds, average='macro')
104
- recall = recall_score(flat_labels, flat_binary_preds, average='macro')
105
- f1 = f1_score(flat_labels, flat_binary_preds, average='macro')
106
-
107
- return accuracy, precision, recall, f1
108
-
109
-
110
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
-
112
- train_dataset = LocalizationDataset(hyperparams["train_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
113
- test_dataset = LocalizationDataset(hyperparams["test_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
114
-
115
- train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
116
- test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)
117
-
118
- model = LocalizationPredictor(input_dim=1280, num_classes=4).to(device)
119
- optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
120
- criterion = nn.CrossEntropyLoss()
121
-
122
- # Train the model
123
- for epoch in range(hyperparams["num_epochs"]):
124
- train_loss = train(model, train_dataloader, optimizer, criterion, device)
125
- print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
126
- print(f"TRAIN LOSS: {train_loss:.4f}")
127
- print("\n")
128
-
129
- # Evaluate model on test dataset
130
- print("Test set")
131
- test_preds, test_labels = evaluate(model, test_dataloader, device)
132
- test_metrics = calculate_metrics(test_preds, test_labels)
133
- print("TEST METRICS:")
134
- print(f"Accuracy: {test_metrics[0]:.4f}")
135
- print(f"Precision: {test_metrics[1]:.4f}")
136
- print(f"Recall: {test_metrics[2]:.4f}")
137
- print(f"F1 Score: {test_metrics[3]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/DeepLoc/membrane_type_test.csv DELETED
The diff for this file is too large to render. See raw diff
 
benchmarks/DeepLoc/membrane_type_train.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:16b8eec677afa2de578d04ee1a0fc9582b2f8cfc47622cbd6374309cd6ab96f3
3
- size 12335695
 
 
 
 
benchmarks/DeepLoc/prep_deeploc_benchmark_data.ipynb DELETED
@@ -1,488 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 3,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import pandas as pd"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": 1,
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "path = \"/home/a03-sgoel/mESMerize/benchmarks/DeepLoc\""
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 7,
24
- "metadata": {},
25
- "outputs": [
26
- {
27
- "data": {
28
- "text/html": [
29
- "<div>\n",
30
- "<style scoped>\n",
31
- " .dataframe tbody tr th:only-of-type {\n",
32
- " vertical-align: middle;\n",
33
- " }\n",
34
- "\n",
35
- " .dataframe tbody tr th {\n",
36
- " vertical-align: top;\n",
37
- " }\n",
38
- "\n",
39
- " .dataframe thead th {\n",
40
- " text-align: right;\n",
41
- " }\n",
42
- "</style>\n",
43
- "<table border=\"1\" class=\"dataframe\">\n",
44
- " <thead>\n",
45
- " <tr style=\"text-align: right;\">\n",
46
- " <th></th>\n",
47
- " <th>Unnamed: 0</th>\n",
48
- " <th>ACC</th>\n",
49
- " <th>Kingdom</th>\n",
50
- " <th>Partition</th>\n",
51
- " <th>Peripheral</th>\n",
52
- " <th>Transmembrane</th>\n",
53
- " <th>LipidAnchor</th>\n",
54
- " <th>Soluble</th>\n",
55
- " <th>Sequence</th>\n",
56
- " </tr>\n",
57
- " </thead>\n",
58
- " <tbody>\n",
59
- " <tr>\n",
60
- " <th>0</th>\n",
61
- " <td>0</td>\n",
62
- " <td>I3R9M8</td>\n",
63
- " <td>Archaea</td>\n",
64
- " <td>0</td>\n",
65
- " <td>1</td>\n",
66
- " <td>0</td>\n",
67
- " <td>0</td>\n",
68
- " <td>0</td>\n",
69
- " <td>MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG...</td>\n",
70
- " </tr>\n",
71
- " <tr>\n",
72
- " <th>1</th>\n",
73
- " <td>1</td>\n",
74
- " <td>I3R9M9</td>\n",
75
- " <td>Archaea</td>\n",
76
- " <td>1</td>\n",
77
- " <td>1</td>\n",
78
- " <td>0</td>\n",
79
- " <td>0</td>\n",
80
- " <td>0</td>\n",
81
- " <td>MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR...</td>\n",
82
- " </tr>\n",
83
- " <tr>\n",
84
- " <th>2</th>\n",
85
- " <td>2</td>\n",
86
- " <td>Q7ZAG8</td>\n",
87
- " <td>Archaea</td>\n",
88
- " <td>2</td>\n",
89
- " <td>1</td>\n",
90
- " <td>0</td>\n",
91
- " <td>0</td>\n",
92
- " <td>0</td>\n",
93
- " <td>MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP...</td>\n",
94
- " </tr>\n",
95
- " <tr>\n",
96
- " <th>3</th>\n",
97
- " <td>3</td>\n",
98
- " <td>Q8PZ67</td>\n",
99
- " <td>Archaea</td>\n",
100
- " <td>0</td>\n",
101
- " <td>1</td>\n",
102
- " <td>0</td>\n",
103
- " <td>0</td>\n",
104
- " <td>1</td>\n",
105
- " <td>MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY...</td>\n",
106
- " </tr>\n",
107
- " <tr>\n",
108
- " <th>4</th>\n",
109
- " <td>4</td>\n",
110
- " <td>Q9YGA6</td>\n",
111
- " <td>Archaea</td>\n",
112
- " <td>0</td>\n",
113
- " <td>1</td>\n",
114
- " <td>0</td>\n",
115
- " <td>0</td>\n",
116
- " <td>0</td>\n",
117
- " <td>MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL...</td>\n",
118
- " </tr>\n",
119
- " <tr>\n",
120
- " <th>...</th>\n",
121
- " <td>...</td>\n",
122
- " <td>...</td>\n",
123
- " <td>...</td>\n",
124
- " <td>...</td>\n",
125
- " <td>...</td>\n",
126
- " <td>...</td>\n",
127
- " <td>...</td>\n",
128
- " <td>...</td>\n",
129
- " <td>...</td>\n",
130
- " </tr>\n",
131
- " <tr>\n",
132
- " <th>28021</th>\n",
133
- " <td>28021</td>\n",
134
- " <td>P86949</td>\n",
135
- " <td>Eukaryota</td>\n",
136
- " <td>0</td>\n",
137
- " <td>0</td>\n",
138
- " <td>0</td>\n",
139
- " <td>0</td>\n",
140
- " <td>1</td>\n",
141
- " <td>MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY...</td>\n",
142
- " </tr>\n",
143
- " <tr>\n",
144
- " <th>28022</th>\n",
145
- " <td>28022</td>\n",
146
- " <td>P86950</td>\n",
147
- " <td>Eukaryota</td>\n",
148
- " <td>0</td>\n",
149
- " <td>0</td>\n",
150
- " <td>0</td>\n",
151
- " <td>0</td>\n",
152
- " <td>1</td>\n",
153
- " <td>MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG...</td>\n",
154
- " </tr>\n",
155
- " <tr>\n",
156
- " <th>28023</th>\n",
157
- " <td>28023</td>\n",
158
- " <td>P86951</td>\n",
159
- " <td>Eukaryota</td>\n",
160
- " <td>0</td>\n",
161
- " <td>0</td>\n",
162
- " <td>0</td>\n",
163
- " <td>0</td>\n",
164
- " <td>1</td>\n",
165
- " <td>MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG...</td>\n",
166
- " </tr>\n",
167
- " <tr>\n",
168
- " <th>28024</th>\n",
169
- " <td>28024</td>\n",
170
- " <td>P86983</td>\n",
171
- " <td>Eukaryota</td>\n",
172
- " <td>3</td>\n",
173
- " <td>0</td>\n",
174
- " <td>0</td>\n",
175
- " <td>0</td>\n",
176
- " <td>1</td>\n",
177
- " <td>MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL...</td>\n",
178
- " </tr>\n",
179
- " <tr>\n",
180
- " <th>28025</th>\n",
181
- " <td>28025</td>\n",
182
- " <td>P86984</td>\n",
183
- " <td>Eukaryota</td>\n",
184
- " <td>4</td>\n",
185
- " <td>0</td>\n",
186
- " <td>0</td>\n",
187
- " <td>0</td>\n",
188
- " <td>1</td>\n",
189
- " <td>MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL...</td>\n",
190
- " </tr>\n",
191
- " </tbody>\n",
192
- "</table>\n",
193
- "<p>28026 rows × 9 columns</p>\n",
194
- "</div>"
195
- ],
196
- "text/plain": [
197
- " Unnamed: 0 ACC Kingdom Partition Peripheral Transmembrane \\\n",
198
- "0 0 I3R9M8 Archaea 0 1 0 \n",
199
- "1 1 I3R9M9 Archaea 1 1 0 \n",
200
- "2 2 Q7ZAG8 Archaea 2 1 0 \n",
201
- "3 3 Q8PZ67 Archaea 0 1 0 \n",
202
- "4 4 Q9YGA6 Archaea 0 1 0 \n",
203
- "... ... ... ... ... ... ... \n",
204
- "28021 28021 P86949 Eukaryota 0 0 0 \n",
205
- "28022 28022 P86950 Eukaryota 0 0 0 \n",
206
- "28023 28023 P86951 Eukaryota 0 0 0 \n",
207
- "28024 28024 P86983 Eukaryota 3 0 0 \n",
208
- "28025 28025 P86984 Eukaryota 4 0 0 \n",
209
- "\n",
210
- " LipidAnchor Soluble Sequence \n",
211
- "0 0 0 MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG... \n",
212
- "1 0 0 MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR... \n",
213
- "2 0 0 MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP... \n",
214
- "3 0 1 MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY... \n",
215
- "4 0 0 MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL... \n",
216
- "... ... ... ... \n",
217
- "28021 0 1 MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY... \n",
218
- "28022 0 1 MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG... \n",
219
- "28023 0 1 MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG... \n",
220
- "28024 0 1 MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL... \n",
221
- "28025 0 1 MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL... \n",
222
- "\n",
223
- "[28026 rows x 9 columns]"
224
- ]
225
- },
226
- "execution_count": 7,
227
- "metadata": {},
228
- "output_type": "execute_result"
229
- }
230
- ],
231
- "source": [
232
- "df"
233
- ]
234
- },
235
- {
236
- "cell_type": "code",
237
- "execution_count": 9,
238
- "metadata": {},
239
- "outputs": [
240
- {
241
- "data": {
242
- "text/html": [
243
- "<div>\n",
244
- "<style scoped>\n",
245
- " .dataframe tbody tr th:only-of-type {\n",
246
- " vertical-align: middle;\n",
247
- " }\n",
248
- "\n",
249
- " .dataframe tbody tr th {\n",
250
- " vertical-align: top;\n",
251
- " }\n",
252
- "\n",
253
- " .dataframe thead th {\n",
254
- " text-align: right;\n",
255
- " }\n",
256
- "</style>\n",
257
- "<table border=\"1\" class=\"dataframe\">\n",
258
- " <thead>\n",
259
- " <tr style=\"text-align: right;\">\n",
260
- " <th></th>\n",
261
- " <th>ACC</th>\n",
262
- " <th>Kingdom</th>\n",
263
- " <th>Partition</th>\n",
264
- " <th>Peripheral</th>\n",
265
- " <th>Transmembrane</th>\n",
266
- " <th>LipidAnchor</th>\n",
267
- " <th>Soluble</th>\n",
268
- " <th>Sequence</th>\n",
269
- " </tr>\n",
270
- " </thead>\n",
271
- " <tbody>\n",
272
- " <tr>\n",
273
- " <th>0</th>\n",
274
- " <td>I3R9M8</td>\n",
275
- " <td>Archaea</td>\n",
276
- " <td>0</td>\n",
277
- " <td>1</td>\n",
278
- " <td>0</td>\n",
279
- " <td>0</td>\n",
280
- " <td>0</td>\n",
281
- " <td>MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG...</td>\n",
282
- " </tr>\n",
283
- " <tr>\n",
284
- " <th>1</th>\n",
285
- " <td>I3R9M9</td>\n",
286
- " <td>Archaea</td>\n",
287
- " <td>1</td>\n",
288
- " <td>1</td>\n",
289
- " <td>0</td>\n",
290
- " <td>0</td>\n",
291
- " <td>0</td>\n",
292
- " <td>MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR...</td>\n",
293
- " </tr>\n",
294
- " <tr>\n",
295
- " <th>2</th>\n",
296
- " <td>Q7ZAG8</td>\n",
297
- " <td>Archaea</td>\n",
298
- " <td>2</td>\n",
299
- " <td>1</td>\n",
300
- " <td>0</td>\n",
301
- " <td>0</td>\n",
302
- " <td>0</td>\n",
303
- " <td>MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP...</td>\n",
304
- " </tr>\n",
305
- " <tr>\n",
306
- " <th>3</th>\n",
307
- " <td>Q8PZ67</td>\n",
308
- " <td>Archaea</td>\n",
309
- " <td>0</td>\n",
310
- " <td>1</td>\n",
311
- " <td>0</td>\n",
312
- " <td>0</td>\n",
313
- " <td>1</td>\n",
314
- " <td>MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY...</td>\n",
315
- " </tr>\n",
316
- " <tr>\n",
317
- " <th>4</th>\n",
318
- " <td>Q9YGA6</td>\n",
319
- " <td>Archaea</td>\n",
320
- " <td>0</td>\n",
321
- " <td>1</td>\n",
322
- " <td>0</td>\n",
323
- " <td>0</td>\n",
324
- " <td>0</td>\n",
325
- " <td>MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL...</td>\n",
326
- " </tr>\n",
327
- " <tr>\n",
328
- " <th>...</th>\n",
329
- " <td>...</td>\n",
330
- " <td>...</td>\n",
331
- " <td>...</td>\n",
332
- " <td>...</td>\n",
333
- " <td>...</td>\n",
334
- " <td>...</td>\n",
335
- " <td>...</td>\n",
336
- " <td>...</td>\n",
337
- " </tr>\n",
338
- " <tr>\n",
339
- " <th>28021</th>\n",
340
- " <td>P86949</td>\n",
341
- " <td>Eukaryota</td>\n",
342
- " <td>0</td>\n",
343
- " <td>0</td>\n",
344
- " <td>0</td>\n",
345
- " <td>0</td>\n",
346
- " <td>1</td>\n",
347
- " <td>MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY...</td>\n",
348
- " </tr>\n",
349
- " <tr>\n",
350
- " <th>28022</th>\n",
351
- " <td>P86950</td>\n",
352
- " <td>Eukaryota</td>\n",
353
- " <td>0</td>\n",
354
- " <td>0</td>\n",
355
- " <td>0</td>\n",
356
- " <td>0</td>\n",
357
- " <td>1</td>\n",
358
- " <td>MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG...</td>\n",
359
- " </tr>\n",
360
- " <tr>\n",
361
- " <th>28023</th>\n",
362
- " <td>P86951</td>\n",
363
- " <td>Eukaryota</td>\n",
364
- " <td>0</td>\n",
365
- " <td>0</td>\n",
366
- " <td>0</td>\n",
367
- " <td>0</td>\n",
368
- " <td>1</td>\n",
369
- " <td>MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG...</td>\n",
370
- " </tr>\n",
371
- " <tr>\n",
372
- " <th>28024</th>\n",
373
- " <td>P86983</td>\n",
374
- " <td>Eukaryota</td>\n",
375
- " <td>3</td>\n",
376
- " <td>0</td>\n",
377
- " <td>0</td>\n",
378
- " <td>0</td>\n",
379
- " <td>1</td>\n",
380
- " <td>MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL...</td>\n",
381
- " </tr>\n",
382
- " <tr>\n",
383
- " <th>28025</th>\n",
384
- " <td>P86984</td>\n",
385
- " <td>Eukaryota</td>\n",
386
- " <td>4</td>\n",
387
- " <td>0</td>\n",
388
- " <td>0</td>\n",
389
- " <td>0</td>\n",
390
- " <td>1</td>\n",
391
- " <td>MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL...</td>\n",
392
- " </tr>\n",
393
- " </tbody>\n",
394
- "</table>\n",
395
- "<p>28026 rows × 8 columns</p>\n",
396
- "</div>"
397
- ],
398
- "text/plain": [
399
- " ACC Kingdom Partition Peripheral Transmembrane LipidAnchor \\\n",
400
- "0 I3R9M8 Archaea 0 1 0 0 \n",
401
- "1 I3R9M9 Archaea 1 1 0 0 \n",
402
- "2 Q7ZAG8 Archaea 2 1 0 0 \n",
403
- "3 Q8PZ67 Archaea 0 1 0 0 \n",
404
- "4 Q9YGA6 Archaea 0 1 0 0 \n",
405
- "... ... ... ... ... ... ... \n",
406
- "28021 P86949 Eukaryota 0 0 0 0 \n",
407
- "28022 P86950 Eukaryota 0 0 0 0 \n",
408
- "28023 P86951 Eukaryota 0 0 0 0 \n",
409
- "28024 P86983 Eukaryota 3 0 0 0 \n",
410
- "28025 P86984 Eukaryota 4 0 0 0 \n",
411
- "\n",
412
- " Soluble Sequence \n",
413
- "0 0 MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG... \n",
414
- "1 0 MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR... \n",
415
- "2 0 MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP... \n",
416
- "3 1 MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY... \n",
417
- "4 0 MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL... \n",
418
- "... ... ... \n",
419
- "28021 1 MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY... \n",
420
- "28022 1 MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG... \n",
421
- "28023 1 MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG... \n",
422
- "28024 1 MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL... \n",
423
- "28025 1 MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL... \n",
424
- "\n",
425
- "[28026 rows x 8 columns]"
426
- ]
427
- },
428
- "execution_count": 9,
429
- "metadata": {},
430
- "output_type": "execute_result"
431
- }
432
- ],
433
- "source": [
434
- "df = pd.read_csv(path + \"/OG_membrane_type_all.csv\")\n",
435
- "df = df.drop(columns=['Unnamed: 0'])\n",
436
- "df"
437
- ]
438
- },
439
- {
440
- "cell_type": "code",
441
- "execution_count": 14,
442
- "metadata": {},
443
- "outputs": [],
444
- "source": [
445
- "train = df[df['Partition'] != 4]\n",
446
- "test = df[df['Partition'] == 4]"
447
- ]
448
- },
449
- {
450
- "cell_type": "code",
451
- "execution_count": 17,
452
- "metadata": {},
453
- "outputs": [],
454
- "source": [
455
- "train.to_csv(path + \"/membrane_type_train.csv\", index=False)\n",
456
- "test.to_csv(path + \"/membrane_type_test.csv\", index=False)"
457
- ]
458
- },
459
- {
460
- "cell_type": "code",
461
- "execution_count": null,
462
- "metadata": {},
463
- "outputs": [],
464
- "source": []
465
- }
466
- ],
467
- "metadata": {
468
- "kernelspec": {
469
- "display_name": "Python 3",
470
- "language": "python",
471
- "name": "python3"
472
- },
473
- "language_info": {
474
- "codemirror_mode": {
475
- "name": "ipython",
476
- "version": 3
477
- },
478
- "file_extension": ".py",
479
- "mimetype": "text/x-python",
480
- "name": "python",
481
- "nbconvert_exporter": "python",
482
- "pygments_lexer": "ipython3",
483
- "version": "3.10.12"
484
- }
485
- },
486
- "nbformat": 4,
487
- "nbformat_minor": 2
488
- }