sgoel30 committed on
Commit
9b7a4b0
·
verified ·
1 Parent(s): a43e004

Benchmarking pipeline. Predicts the specific type of the generated membrane protein and the subcellular localization of the generated protein

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ benchmarks/DeepLoc/cell_localization_train_val.csv filter=lfs diff=lfs merge=lfs -text
37
+ benchmarks/DeepLoc/membrane_type_train.csv filter=lfs diff=lfs merge=lfs -text
38
+ benchmarks/DeepLoc/OG_membrane_type_all.csv filter=lfs diff=lfs merge=lfs -text
benchmarks/DeepLoc/OG_membrane_type_all.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d878da32a06092f880262048e3c1eb692721c274b0a458fcc712a0dcbd80c71
3
+ size 15683507
benchmarks/DeepLoc/cell_localization_predictor.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
6
+
7
+ from tqdm import tqdm
8
+ from datetime import datetime
9
+ import pandas as pd
10
+ import numpy as np
11
+ import pickle
12
+ import os
13
+
14
# Hyperparameters dictionary
# Root of the project checkout; every data path below is built from it.
path = "/home/a03-sgoel/MDpLM"

# Training settings and data locations for the cell-localization benchmark.
hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,
    # The data file is named cell_localization_train_val.csv; the previous
    # value carried a doubled ".csv.csv" suffix and pointed at no file.
    "train_data": path + "/benchmarks/DeepLoc/cell_localization_train_val.csv",
    "test_data": path + "/benchmarks/DeepLoc/cell_localization_test.csv",
    "val_data": "",  # None
    "embeddings_pkl": "",  # Need to generate ESM embeddings and save as pkl file
}
27
+
28
# Dataset class can load pickle file
class LocalizationDataset(Dataset):
    """Pair each CSV row's pre-computed embedding with its label vector.

    Args:
        csv_file: Path to a CSV that has a ``Sequence`` column. The columns at
            positions 1 through 8 (0-based) are read as the label columns.
        embeddings_pkl: Path to a pickle file holding a mapping from sequence
            string to its embedding.
        max_length: Stored on the instance; this class does not otherwise use it.

    Raises:
        ValueError: If a sequence in the CSV has no entry in the embedding mapping.
    """

    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length

        # Map sequences to embeddings.
        # NOTE(review): pickle.load can execute code embedded in the file; only
        # point this at embedding files produced by a trusted pipeline.
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # Series.map leaves NaN for sequences absent from the mapping. The old
        # check compared the frame's length with its own column and could never
        # fail, so missing embeddings are detected explicitly here.
        missing = self.data['embedding'].isna()
        if missing.any():
            raise ValueError(
                f"{int(missing.sum())} sequence(s) in {csv_file} have no embedding in {embeddings_pkl}"
            )

        # Create multi-class label list (DataFrame exposes `.values`, not `.value`).
        self.data['label'] = self.data.iloc[:, 1:9].values.tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional access keeps idx meaning "n-th row" regardless of the index labels.
        embeddings = torch.tensor(self.data['embedding'].iloc[idx], dtype=torch.float)
        # NOTE(review): labels are returned as an integer vector with one entry
        # per label column; nn.CrossEntropyLoss takes class indices or
        # floating-point class probabilities - confirm the intended target format.
        labels = torch.tensor(self.data['label'].iloc[idx], dtype=torch.long)

        return embeddings, labels
53
+
54
# Multi-class localization predictor
class LocalizationPredictor(nn.Module):
    """Linear classifier applied to a mean-pooled embedding.

    Args:
        input_dim: Size of the embedding feature axis.
        num_classes: Number of logits produced per example.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)  # 1280 x 8

    def forward(self, embeddings):
        """Return one logit vector per example.

        ``embeddings`` must have at least two axes with the feature axis last.
        The second-to-last axis is averaged away - presumably the residue axis;
        TODO confirm against the embedding pickle.
        """
        # Averaging over dim=0 pooled across the batch axis once a DataLoader
        # added it, leaving one logit row per residue instead of per example.
        # dim=-2 is the same axis for 2-D input and the residue axis for
        # batched 3-D input.
        avg_embedding = torch.mean(embeddings, dim=-2)
        logits = self.classifier(avg_embedding)
        return logits  # shape (..., num_classes); passed to CE loss
64
+
65
# Training function
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, pushed through ``model``, scored with
    ``criterion`` and used for one optimizer step.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0
    for batch_inputs, batch_targets in tqdm(dataloader):
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(dataloader)
78
+
79
# Evaluation function
def evaluate(model, dataloader, device):
    """Collect model outputs and labels for every batch in ``dataloader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        A ``(preds, true_labels)`` pair of lists holding one NumPy array per
        batch: the raw model outputs and the matching labels.
    """
    model.eval()
    collected_outputs = []
    collected_targets = []
    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(dataloader):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_outputs = model(batch_inputs)
            # Move to host memory so the results can be handed to NumPy.
            collected_outputs.append(batch_outputs.cpu().numpy())
            collected_targets.append(batch_targets.cpu().numpy())
    return collected_outputs, collected_targets
90
+
91
# Metrics calculation
def calculate_metrics(preds, labels, threshold=0.5):
    """Binarise predictions and score them against the labels.

    Args:
        preds: Iterable of NumPy arrays of prediction scores.
        labels: Iterable of NumPy arrays of ground-truth labels, paired with
            ``preds`` element by element.
        threshold: Scores strictly greater than this become 1, others 0.
            NOTE(review): the threshold is applied to the values exactly as
            passed in - confirm callers supply scores on the intended scale.

    Returns:
        ``(accuracy, precision, recall, f1)``; the last three are macro-averaged.
    """
    paired = list(zip(preds, labels))

    # Flatten every per-batch array into one long element-wise sequence.
    binarised_chunks = [(scores > threshold).astype(int).flatten() for scores, _ in paired]
    label_chunks = [truth.flatten() for _, truth in paired]

    flat_binary_preds = np.array([value for chunk in binarised_chunks for value in chunk])
    flat_labels = np.array([value for chunk in label_chunks for value in chunk])

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds, average='macro')
    recall = recall_score(flat_labels, flat_binary_preds, average='macro')
    f1 = f1_score(flat_labels, flat_binary_preds, average='macro')

    return accuracy, precision, recall, f1
108
+
109
+
110
# Use the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both splits read their embeddings from the same pickle file.
train_dataset = LocalizationDataset(
    hyperparams["train_data"],
    hyperparams["embeddings_pkl"],
    max_length=hyperparams["max_length"],
)
test_dataset = LocalizationDataset(
    hyperparams["test_data"],
    hyperparams["embeddings_pkl"],
    max_length=hyperparams["max_length"],
)

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

model = LocalizationPredictor(input_dim=1280, num_classes=8).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
criterion = nn.CrossEntropyLoss()

# Train the model
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")

# Evaluate model on test dataset
print("Test set")
test_preds, test_labels = evaluate(model, test_dataloader, device)
test_metrics = calculate_metrics(test_preds, test_labels)
print("TEST METRICS:")
# test_metrics is ordered (accuracy, precision, recall, f1).
for metric_title, metric_value in zip(("Accuracy", "Precision", "Recall", "F1 Score"), test_metrics):
    print(f"{metric_title}: {metric_value:.4f}")
benchmarks/DeepLoc/cell_localization_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
benchmarks/DeepLoc/cell_localization_train_val.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29a07b293fed2994a966b70bdcd6bacc59915b8b01fa200cb2b07d8db18384a2
3
+ size 17724293
benchmarks/DeepLoc/membrane_localization_predictor.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
6
+
7
+ from tqdm import tqdm
8
+ from datetime import datetime
9
+ import pandas as pd
10
+ import numpy as np
11
+ import pickle
12
+ import os
13
+
14
# Hyperparameters dictionary
# Root of the project checkout; every data path below is built from it.
path = "/home/a03-sgoel/MDpLM"

# Training settings and data locations for the membrane-type benchmark.
hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,
    # The membrane_type_*.csv files live under benchmarks/DeepLoc/; the
    # previous paths omitted that directory.
    "train_data": path + "/benchmarks/DeepLoc/membrane_type_train.csv",
    "test_data": path + "/benchmarks/DeepLoc/membrane_type_test.csv",
    "val_data": "",  # none
    "embeddings_pkl": "",  # Need to generate ESM embeddings
}
27
+
28
# Dataset class can load pickle file
class LocalizationDataset(Dataset):
    """Pair each CSV row's pre-computed embedding with its membrane-type labels.

    Args:
        csv_file: Path to a CSV that has a ``Sequence`` column. The columns at
            positions 3 through 6 (0-based) are read as the label columns.
        embeddings_pkl: Path to a pickle file holding a mapping from sequence
            string to its embedding.
        max_length: Stored on the instance; this class does not otherwise use it.

    Raises:
        ValueError: If a sequence in the CSV has no entry in the embedding mapping.
    """

    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length

        # Map sequences to embeddings.
        # NOTE(review): pickle.load can execute code embedded in the file; only
        # point this at embedding files produced by a trusted pipeline.
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # Series.map leaves NaN for sequences absent from the mapping. The old
        # check compared the frame's length with its own column and could never
        # fail, so missing embeddings are detected explicitly here.
        missing = self.data['embedding'].isna()
        if missing.any():
            raise ValueError(
                f"{int(missing.sum())} sequence(s) in {csv_file} have no embedding in {embeddings_pkl}"
            )

        # Create multi-class label list. The prepared CSVs are laid out as
        # ACC, Kingdom, Partition, Peripheral, Transmembrane, LipidAnchor,
        # Soluble, Sequence, so the four label columns sit at positions 3..6.
        # The earlier 2:7 slice also pulled in the Partition fold id, yielding
        # five values for a four-class model. DataFrame exposes `.values`,
        # not `.value`.
        self.data['label'] = self.data.iloc[:, 3:7].values.tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional access keeps idx meaning "n-th row" regardless of the index labels.
        embeddings = torch.tensor(self.data['embedding'].iloc[idx], dtype=torch.float)
        # NOTE(review): labels are returned as an integer vector with one entry
        # per label column; nn.CrossEntropyLoss takes class indices or
        # floating-point class probabilities - confirm the intended target format.
        labels = torch.tensor(self.data['label'].iloc[idx], dtype=torch.long)

        return embeddings, labels
53
+
54
# Multi-class localization predictor
class LocalizationPredictor(nn.Module):
    """Linear classifier applied to a mean-pooled embedding.

    Args:
        input_dim: Size of the embedding feature axis.
        num_classes: Number of logits produced per example.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)  # 1280 x 4

    def forward(self, embeddings):
        """Return one logit vector per example.

        ``embeddings`` must have at least two axes with the feature axis last.
        The second-to-last axis is averaged away - presumably the residue axis;
        TODO confirm against the embedding pickle.
        """
        # Averaging over dim=0 pooled across the batch axis once a DataLoader
        # added it, leaving one logit row per residue instead of per example.
        # dim=-2 is the same axis for 2-D input and the residue axis for
        # batched 3-D input.
        avg_embedding = torch.mean(embeddings, dim=-2)
        logits = self.classifier(avg_embedding)
        return logits  # shape (..., num_classes); passed to CE loss
64
+
65
# Training function
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, pushed through ``model``, scored with
    ``criterion`` and used for one optimizer step.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0
    for batch_inputs, batch_targets in tqdm(dataloader):
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(dataloader)
78
+
79
# Evaluation function
def evaluate(model, dataloader, device):
    """Collect model outputs and labels for every batch in ``dataloader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        A ``(preds, true_labels)`` pair of lists holding one NumPy array per
        batch: the raw model outputs and the matching labels.
    """
    model.eval()
    collected_outputs = []
    collected_targets = []
    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(dataloader):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_outputs = model(batch_inputs)
            # Move to host memory so the results can be handed to NumPy.
            collected_outputs.append(batch_outputs.cpu().numpy())
            collected_targets.append(batch_targets.cpu().numpy())
    return collected_outputs, collected_targets
90
+
91
# Metrics calculation
def calculate_metrics(preds, labels, threshold=0.5):
    """Binarise predictions and score them against the labels.

    Args:
        preds: Iterable of NumPy arrays of prediction scores.
        labels: Iterable of NumPy arrays of ground-truth labels, paired with
            ``preds`` element by element.
        threshold: Scores strictly greater than this become 1, others 0.
            NOTE(review): the threshold is applied to the values exactly as
            passed in - confirm callers supply scores on the intended scale.

    Returns:
        ``(accuracy, precision, recall, f1)``; the last three are macro-averaged.
    """
    paired = list(zip(preds, labels))

    # Flatten every per-batch array into one long element-wise sequence.
    binarised_chunks = [(scores > threshold).astype(int).flatten() for scores, _ in paired]
    label_chunks = [truth.flatten() for _, truth in paired]

    flat_binary_preds = np.array([value for chunk in binarised_chunks for value in chunk])
    flat_labels = np.array([value for chunk in label_chunks for value in chunk])

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds, average='macro')
    recall = recall_score(flat_labels, flat_binary_preds, average='macro')
    f1 = f1_score(flat_labels, flat_binary_preds, average='macro')

    return accuracy, precision, recall, f1
108
+
109
+
110
# Use the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both splits read their embeddings from the same pickle file.
train_dataset = LocalizationDataset(
    hyperparams["train_data"],
    hyperparams["embeddings_pkl"],
    max_length=hyperparams["max_length"],
)
test_dataset = LocalizationDataset(
    hyperparams["test_data"],
    hyperparams["embeddings_pkl"],
    max_length=hyperparams["max_length"],
)

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

model = LocalizationPredictor(input_dim=1280, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
criterion = nn.CrossEntropyLoss()

# Train the model
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")

# Evaluate model on test dataset
print("Test set")
test_preds, test_labels = evaluate(model, test_dataloader, device)
test_metrics = calculate_metrics(test_preds, test_labels)
print("TEST METRICS:")
# test_metrics is ordered (accuracy, precision, recall, f1).
for metric_title, metric_value in zip(("Accuracy", "Precision", "Recall", "F1 Score"), test_metrics):
    print(f"{metric_title}: {metric_value:.4f}")
benchmarks/DeepLoc/membrane_type_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
benchmarks/DeepLoc/membrane_type_train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16b8eec677afa2de578d04ee1a0fc9582b2f8cfc47622cbd6374309cd6ab96f3
3
+ size 12335695
benchmarks/DeepLoc/prep_deeploc_benchmark_data.ipynb ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "path = \"/home/a03-sgoel/mESMerize/benchmarks/DeepLoc\""
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 7,
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "text/html": [
29
+ "<div>\n",
30
+ "<style scoped>\n",
31
+ " .dataframe tbody tr th:only-of-type {\n",
32
+ " vertical-align: middle;\n",
33
+ " }\n",
34
+ "\n",
35
+ " .dataframe tbody tr th {\n",
36
+ " vertical-align: top;\n",
37
+ " }\n",
38
+ "\n",
39
+ " .dataframe thead th {\n",
40
+ " text-align: right;\n",
41
+ " }\n",
42
+ "</style>\n",
43
+ "<table border=\"1\" class=\"dataframe\">\n",
44
+ " <thead>\n",
45
+ " <tr style=\"text-align: right;\">\n",
46
+ " <th></th>\n",
47
+ " <th>Unnamed: 0</th>\n",
48
+ " <th>ACC</th>\n",
49
+ " <th>Kingdom</th>\n",
50
+ " <th>Partition</th>\n",
51
+ " <th>Peripheral</th>\n",
52
+ " <th>Transmembrane</th>\n",
53
+ " <th>LipidAnchor</th>\n",
54
+ " <th>Soluble</th>\n",
55
+ " <th>Sequence</th>\n",
56
+ " </tr>\n",
57
+ " </thead>\n",
58
+ " <tbody>\n",
59
+ " <tr>\n",
60
+ " <th>0</th>\n",
61
+ " <td>0</td>\n",
62
+ " <td>I3R9M8</td>\n",
63
+ " <td>Archaea</td>\n",
64
+ " <td>0</td>\n",
65
+ " <td>1</td>\n",
66
+ " <td>0</td>\n",
67
+ " <td>0</td>\n",
68
+ " <td>0</td>\n",
69
+ " <td>MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG...</td>\n",
70
+ " </tr>\n",
71
+ " <tr>\n",
72
+ " <th>1</th>\n",
73
+ " <td>1</td>\n",
74
+ " <td>I3R9M9</td>\n",
75
+ " <td>Archaea</td>\n",
76
+ " <td>1</td>\n",
77
+ " <td>1</td>\n",
78
+ " <td>0</td>\n",
79
+ " <td>0</td>\n",
80
+ " <td>0</td>\n",
81
+ " <td>MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR...</td>\n",
82
+ " </tr>\n",
83
+ " <tr>\n",
84
+ " <th>2</th>\n",
85
+ " <td>2</td>\n",
86
+ " <td>Q7ZAG8</td>\n",
87
+ " <td>Archaea</td>\n",
88
+ " <td>2</td>\n",
89
+ " <td>1</td>\n",
90
+ " <td>0</td>\n",
91
+ " <td>0</td>\n",
92
+ " <td>0</td>\n",
93
+ " <td>MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP...</td>\n",
94
+ " </tr>\n",
95
+ " <tr>\n",
96
+ " <th>3</th>\n",
97
+ " <td>3</td>\n",
98
+ " <td>Q8PZ67</td>\n",
99
+ " <td>Archaea</td>\n",
100
+ " <td>0</td>\n",
101
+ " <td>1</td>\n",
102
+ " <td>0</td>\n",
103
+ " <td>0</td>\n",
104
+ " <td>1</td>\n",
105
+ " <td>MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY...</td>\n",
106
+ " </tr>\n",
107
+ " <tr>\n",
108
+ " <th>4</th>\n",
109
+ " <td>4</td>\n",
110
+ " <td>Q9YGA6</td>\n",
111
+ " <td>Archaea</td>\n",
112
+ " <td>0</td>\n",
113
+ " <td>1</td>\n",
114
+ " <td>0</td>\n",
115
+ " <td>0</td>\n",
116
+ " <td>0</td>\n",
117
+ " <td>MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL...</td>\n",
118
+ " </tr>\n",
119
+ " <tr>\n",
120
+ " <th>...</th>\n",
121
+ " <td>...</td>\n",
122
+ " <td>...</td>\n",
123
+ " <td>...</td>\n",
124
+ " <td>...</td>\n",
125
+ " <td>...</td>\n",
126
+ " <td>...</td>\n",
127
+ " <td>...</td>\n",
128
+ " <td>...</td>\n",
129
+ " <td>...</td>\n",
130
+ " </tr>\n",
131
+ " <tr>\n",
132
+ " <th>28021</th>\n",
133
+ " <td>28021</td>\n",
134
+ " <td>P86949</td>\n",
135
+ " <td>Eukaryota</td>\n",
136
+ " <td>0</td>\n",
137
+ " <td>0</td>\n",
138
+ " <td>0</td>\n",
139
+ " <td>0</td>\n",
140
+ " <td>1</td>\n",
141
+ " <td>MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY...</td>\n",
142
+ " </tr>\n",
143
+ " <tr>\n",
144
+ " <th>28022</th>\n",
145
+ " <td>28022</td>\n",
146
+ " <td>P86950</td>\n",
147
+ " <td>Eukaryota</td>\n",
148
+ " <td>0</td>\n",
149
+ " <td>0</td>\n",
150
+ " <td>0</td>\n",
151
+ " <td>0</td>\n",
152
+ " <td>1</td>\n",
153
+ " <td>MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG...</td>\n",
154
+ " </tr>\n",
155
+ " <tr>\n",
156
+ " <th>28023</th>\n",
157
+ " <td>28023</td>\n",
158
+ " <td>P86951</td>\n",
159
+ " <td>Eukaryota</td>\n",
160
+ " <td>0</td>\n",
161
+ " <td>0</td>\n",
162
+ " <td>0</td>\n",
163
+ " <td>0</td>\n",
164
+ " <td>1</td>\n",
165
+ " <td>MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG...</td>\n",
166
+ " </tr>\n",
167
+ " <tr>\n",
168
+ " <th>28024</th>\n",
169
+ " <td>28024</td>\n",
170
+ " <td>P86983</td>\n",
171
+ " <td>Eukaryota</td>\n",
172
+ " <td>3</td>\n",
173
+ " <td>0</td>\n",
174
+ " <td>0</td>\n",
175
+ " <td>0</td>\n",
176
+ " <td>1</td>\n",
177
+ " <td>MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL...</td>\n",
178
+ " </tr>\n",
179
+ " <tr>\n",
180
+ " <th>28025</th>\n",
181
+ " <td>28025</td>\n",
182
+ " <td>P86984</td>\n",
183
+ " <td>Eukaryota</td>\n",
184
+ " <td>4</td>\n",
185
+ " <td>0</td>\n",
186
+ " <td>0</td>\n",
187
+ " <td>0</td>\n",
188
+ " <td>1</td>\n",
189
+ " <td>MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL...</td>\n",
190
+ " </tr>\n",
191
+ " </tbody>\n",
192
+ "</table>\n",
193
+ "<p>28026 rows × 9 columns</p>\n",
194
+ "</div>"
195
+ ],
196
+ "text/plain": [
197
+ " Unnamed: 0 ACC Kingdom Partition Peripheral Transmembrane \\\n",
198
+ "0 0 I3R9M8 Archaea 0 1 0 \n",
199
+ "1 1 I3R9M9 Archaea 1 1 0 \n",
200
+ "2 2 Q7ZAG8 Archaea 2 1 0 \n",
201
+ "3 3 Q8PZ67 Archaea 0 1 0 \n",
202
+ "4 4 Q9YGA6 Archaea 0 1 0 \n",
203
+ "... ... ... ... ... ... ... \n",
204
+ "28021 28021 P86949 Eukaryota 0 0 0 \n",
205
+ "28022 28022 P86950 Eukaryota 0 0 0 \n",
206
+ "28023 28023 P86951 Eukaryota 0 0 0 \n",
207
+ "28024 28024 P86983 Eukaryota 3 0 0 \n",
208
+ "28025 28025 P86984 Eukaryota 4 0 0 \n",
209
+ "\n",
210
+ " LipidAnchor Soluble Sequence \n",
211
+ "0 0 0 MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG... \n",
212
+ "1 0 0 MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR... \n",
213
+ "2 0 0 MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP... \n",
214
+ "3 0 1 MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY... \n",
215
+ "4 0 0 MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL... \n",
216
+ "... ... ... ... \n",
217
+ "28021 0 1 MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY... \n",
218
+ "28022 0 1 MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG... \n",
219
+ "28023 0 1 MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG... \n",
220
+ "28024 0 1 MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL... \n",
221
+ "28025 0 1 MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL... \n",
222
+ "\n",
223
+ "[28026 rows x 9 columns]"
224
+ ]
225
+ },
226
+ "execution_count": 7,
227
+ "metadata": {},
228
+ "output_type": "execute_result"
229
+ }
230
+ ],
231
+ "source": [
232
+ "df"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 9,
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "data": {
242
+ "text/html": [
243
+ "<div>\n",
244
+ "<style scoped>\n",
245
+ " .dataframe tbody tr th:only-of-type {\n",
246
+ " vertical-align: middle;\n",
247
+ " }\n",
248
+ "\n",
249
+ " .dataframe tbody tr th {\n",
250
+ " vertical-align: top;\n",
251
+ " }\n",
252
+ "\n",
253
+ " .dataframe thead th {\n",
254
+ " text-align: right;\n",
255
+ " }\n",
256
+ "</style>\n",
257
+ "<table border=\"1\" class=\"dataframe\">\n",
258
+ " <thead>\n",
259
+ " <tr style=\"text-align: right;\">\n",
260
+ " <th></th>\n",
261
+ " <th>ACC</th>\n",
262
+ " <th>Kingdom</th>\n",
263
+ " <th>Partition</th>\n",
264
+ " <th>Peripheral</th>\n",
265
+ " <th>Transmembrane</th>\n",
266
+ " <th>LipidAnchor</th>\n",
267
+ " <th>Soluble</th>\n",
268
+ " <th>Sequence</th>\n",
269
+ " </tr>\n",
270
+ " </thead>\n",
271
+ " <tbody>\n",
272
+ " <tr>\n",
273
+ " <th>0</th>\n",
274
+ " <td>I3R9M8</td>\n",
275
+ " <td>Archaea</td>\n",
276
+ " <td>0</td>\n",
277
+ " <td>1</td>\n",
278
+ " <td>0</td>\n",
279
+ " <td>0</td>\n",
280
+ " <td>0</td>\n",
281
+ " <td>MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG...</td>\n",
282
+ " </tr>\n",
283
+ " <tr>\n",
284
+ " <th>1</th>\n",
285
+ " <td>I3R9M9</td>\n",
286
+ " <td>Archaea</td>\n",
287
+ " <td>1</td>\n",
288
+ " <td>1</td>\n",
289
+ " <td>0</td>\n",
290
+ " <td>0</td>\n",
291
+ " <td>0</td>\n",
292
+ " <td>MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR...</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <th>2</th>\n",
296
+ " <td>Q7ZAG8</td>\n",
297
+ " <td>Archaea</td>\n",
298
+ " <td>2</td>\n",
299
+ " <td>1</td>\n",
300
+ " <td>0</td>\n",
301
+ " <td>0</td>\n",
302
+ " <td>0</td>\n",
303
+ " <td>MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP...</td>\n",
304
+ " </tr>\n",
305
+ " <tr>\n",
306
+ " <th>3</th>\n",
307
+ " <td>Q8PZ67</td>\n",
308
+ " <td>Archaea</td>\n",
309
+ " <td>0</td>\n",
310
+ " <td>1</td>\n",
311
+ " <td>0</td>\n",
312
+ " <td>0</td>\n",
313
+ " <td>1</td>\n",
314
+ " <td>MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY...</td>\n",
315
+ " </tr>\n",
316
+ " <tr>\n",
317
+ " <th>4</th>\n",
318
+ " <td>Q9YGA6</td>\n",
319
+ " <td>Archaea</td>\n",
320
+ " <td>0</td>\n",
321
+ " <td>1</td>\n",
322
+ " <td>0</td>\n",
323
+ " <td>0</td>\n",
324
+ " <td>0</td>\n",
325
+ " <td>MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL...</td>\n",
326
+ " </tr>\n",
327
+ " <tr>\n",
328
+ " <th>...</th>\n",
329
+ " <td>...</td>\n",
330
+ " <td>...</td>\n",
331
+ " <td>...</td>\n",
332
+ " <td>...</td>\n",
333
+ " <td>...</td>\n",
334
+ " <td>...</td>\n",
335
+ " <td>...</td>\n",
336
+ " <td>...</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <th>28021</th>\n",
340
+ " <td>P86949</td>\n",
341
+ " <td>Eukaryota</td>\n",
342
+ " <td>0</td>\n",
343
+ " <td>0</td>\n",
344
+ " <td>0</td>\n",
345
+ " <td>0</td>\n",
346
+ " <td>1</td>\n",
347
+ " <td>MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY...</td>\n",
348
+ " </tr>\n",
349
+ " <tr>\n",
350
+ " <th>28022</th>\n",
351
+ " <td>P86950</td>\n",
352
+ " <td>Eukaryota</td>\n",
353
+ " <td>0</td>\n",
354
+ " <td>0</td>\n",
355
+ " <td>0</td>\n",
356
+ " <td>0</td>\n",
357
+ " <td>1</td>\n",
358
+ " <td>MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG...</td>\n",
359
+ " </tr>\n",
360
+ " <tr>\n",
361
+ " <th>28023</th>\n",
362
+ " <td>P86951</td>\n",
363
+ " <td>Eukaryota</td>\n",
364
+ " <td>0</td>\n",
365
+ " <td>0</td>\n",
366
+ " <td>0</td>\n",
367
+ " <td>0</td>\n",
368
+ " <td>1</td>\n",
369
+ " <td>MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG...</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <th>28024</th>\n",
373
+ " <td>P86983</td>\n",
374
+ " <td>Eukaryota</td>\n",
375
+ " <td>3</td>\n",
376
+ " <td>0</td>\n",
377
+ " <td>0</td>\n",
378
+ " <td>0</td>\n",
379
+ " <td>1</td>\n",
380
+ " <td>MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL...</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <th>28025</th>\n",
384
+ " <td>P86984</td>\n",
385
+ " <td>Eukaryota</td>\n",
386
+ " <td>4</td>\n",
387
+ " <td>0</td>\n",
388
+ " <td>0</td>\n",
389
+ " <td>0</td>\n",
390
+ " <td>1</td>\n",
391
+ " <td>MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL...</td>\n",
392
+ " </tr>\n",
393
+ " </tbody>\n",
394
+ "</table>\n",
395
+ "<p>28026 rows × 8 columns</p>\n",
396
+ "</div>"
397
+ ],
398
+ "text/plain": [
399
+ " ACC Kingdom Partition Peripheral Transmembrane LipidAnchor \\\n",
400
+ "0 I3R9M8 Archaea 0 1 0 0 \n",
401
+ "1 I3R9M9 Archaea 1 1 0 0 \n",
402
+ "2 Q7ZAG8 Archaea 2 1 0 0 \n",
403
+ "3 Q8PZ67 Archaea 0 1 0 0 \n",
404
+ "4 Q9YGA6 Archaea 0 1 0 0 \n",
405
+ "... ... ... ... ... ... ... \n",
406
+ "28021 P86949 Eukaryota 0 0 0 0 \n",
407
+ "28022 P86950 Eukaryota 0 0 0 0 \n",
408
+ "28023 P86951 Eukaryota 0 0 0 0 \n",
409
+ "28024 P86983 Eukaryota 3 0 0 0 \n",
410
+ "28025 P86984 Eukaryota 4 0 0 0 \n",
411
+ "\n",
412
+ " Soluble Sequence \n",
413
+ "0 0 MSTDSDAETVDLADGVDHQVAMVMDLNKCIGCQTCTVACKSLWTEG... \n",
414
+ "1 0 MSRNDASQLDDGETTAESPPDDQANDAPEVGDPPGDPVDADSGVSR... \n",
415
+ "2 0 MTKVLVLGGRFGALTAAYTLKRLVGSKADVKVINKSRFSYFRPALP... \n",
416
+ "3 1 MPPKIAEVIQHDVCAACGACEAVCPIGAVTVKKAAEIRDPNDLSLY... \n",
417
+ "4 0 MAGVRLVDVWKVFGEVTAVREMSLEVKDGEFMILLGPSGCGKTTTL... \n",
418
+ "... ... ... \n",
419
+ "28021 1 MLRFIAIVALIATVNAKGGTYGIGVLPSVTYVSGGGGGYPGIYGTY... \n",
420
+ "28022 1 MKPFISLASLIVLIASASAGGDDDYGKYGYGSYGPGIGGIGGGGGG... \n",
421
+ "28023 1 MLKLVCAVVLIATVNAKGSSPGFGIGQLPGITVVSGGVSGGSLSGG... \n",
422
+ "28024 1 MHQSSLGVLVLFSLIYLCISVHVPFDLNGWKALRLDNNRVQDSTNL... \n",
423
+ "28025 1 MLMLLCIIATVIPFSLVEGRKGCWADPTPPGKECLYGKEIHGGRNL... \n",
424
+ "\n",
425
+ "[28026 rows x 8 columns]"
426
+ ]
427
+ },
428
+ "execution_count": 9,
429
+ "metadata": {},
430
+ "output_type": "execute_result"
431
+ }
432
+ ],
433
+ "source": [
434
+ "df = pd.read_csv(path + \"/OG_membrane_type_all.csv\")\n",
435
+ "df = df.drop(columns=['Unnamed: 0'])\n",
436
+ "df"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 14,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": [
445
+ "train = df[df['Partition'] != 4]\n",
446
+ "test = df[df['Partition'] == 4]"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": 17,
452
+ "metadata": {},
453
+ "outputs": [],
454
+ "source": [
455
+ "train.to_csv(path + \"/membrane_type_train.csv\", index=False)\n",
456
+ "test.to_csv(path + \"/membrane_type_test.csv\", index=False)"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": []
465
+ }
466
+ ],
467
+ "metadata": {
468
+ "kernelspec": {
469
+ "display_name": "Python 3",
470
+ "language": "python",
471
+ "name": "python3"
472
+ },
473
+ "language_info": {
474
+ "codemirror_mode": {
475
+ "name": "ipython",
476
+ "version": 3
477
+ },
478
+ "file_extension": ".py",
479
+ "mimetype": "text/x-python",
480
+ "name": "python",
481
+ "nbconvert_exporter": "python",
482
+ "pygments_lexer": "ipython3",
483
+ "version": "3.10.12"
484
+ }
485
+ },
486
+ "nbformat": 4,
487
+ "nbformat_minor": 2
488
+ }